Deprecate reserved_label_len arg
This commit is contained in:
hiyouga 2024-07-01 01:19:27 +08:00
parent d4e2af1fa4
commit 1771251ce3
13 changed files with 329 additions and 223 deletions

View File

@ -13,7 +13,7 @@
# limitations under the License.
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Sequence, Set, Union
from datasets import concatenate_datasets, interleave_datasets
@ -30,6 +30,9 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
@unique
class Role(str, Enum):
USER = "user"
@ -39,13 +42,6 @@ class Role(str, Enum):
OBSERVATION = "observation"
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
max_target_len = int(max_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, reserved_label_len)
max_source_len = max_len - min(max_target_len, target_len)
return max_source_len, max_target_len
def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]],
data_args: "DataArguments",

View File

@ -16,97 +16,10 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
DEFAULT_TOOL_PROMPT = (
"You have access to the following tools:\n{tool_text}"
"Use the following format if using a tool:\n"
"```\n"
"Action: tool name (one of [{tool_names}]).\n"
"Action Input: the input to the tool, in a JSON format representing the kwargs "
"""(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n"""
"```\n"
)
GLM4_TOOL_PROMPT = (
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
)
def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
tool_names = []
for tool in tools:
param_text = ""
for name, param in tool["parameters"]["properties"].items():
required = ", required" if name in tool["parameters"].get("required", []) else ""
enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
items = (
", where each item should be {}".format(param["items"].get("type", "")) if param.get("items") else ""
)
param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
name=name,
type=param.get("type", ""),
required=required,
desc=param.get("description", ""),
enum=enum,
items=items,
)
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
name=tool["name"], desc=tool.get("description", ""), args=param_text
)
tool_names.append(tool["name"])
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
action_match: List[Tuple[str, str]] = re.findall(regex, content)
if not action_match:
return content
results = []
for match in action_match:
tool_name = match[0].strip()
tool_input = match[1].strip().strip('"').strip("```")
try:
arguments = json.loads(tool_input)
results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
except json.JSONDecodeError:
return content
return results
def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
)
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
def glm4_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
if "\n" not in content:
return content
tool_name, tool_input = content.split("\n", maxsplit=1)
try:
arguments = json.loads(tool_input)
except json.JSONDecodeError:
return content
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
from .data_utils import SLOTS
from .tool_utils import DefaultToolUtils, GLM4ToolUtils
@dataclass
@ -168,15 +81,12 @@ class StringFormatter(Formatter):
@dataclass
class FunctionFormatter(Formatter):
def __post_init__(self):
has_name, has_args = False, False
for slot in filter(lambda s: isinstance(s, str), self.slots):
if "{{name}}" in slot:
has_name = True
if "{{arguments}}" in slot:
has_args = True
if not has_name or not has_args:
raise ValueError("Name and arguments placeholders are required in the function formatter.")
if self.tool_format == "default":
self.slots = DefaultToolUtils.get_function_slots() + self.slots
elif self.tool_format == "glm4":
self.slots = GLM4ToolUtils.get_function_slots() + self.slots
else:
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
@ -210,11 +120,11 @@ class FunctionFormatter(Formatter):
class ToolFormatter(Formatter):
def __post_init__(self):
if self.tool_format == "default":
self._tool_formatter = default_tool_formatter
self._tool_extractor = default_tool_extractor
self._tool_formatter = DefaultToolUtils.tool_formatter
self._tool_extractor = DefaultToolUtils.tool_extractor
elif self.tool_format == "glm4":
self._tool_formatter = glm4_tool_formatter
self._tool_extractor = glm4_tool_extractor
self._tool_formatter = GLM4ToolUtils.tool_formatter
self._tool_extractor = GLM4ToolUtils.tool_extractor
else:
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))

View File

@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
if TYPE_CHECKING:
@ -55,12 +55,8 @@ def _encode_feedback_example(
else:
kl_messages = prompt + [kl_response[1]]
prompt_ids, response_ids = template.encode_oneturn(
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
_, kl_response_ids = template.encode_oneturn(
tokenizer, kl_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
_, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
if template.efficient_eos:
response_ids += [tokenizer.eos_token_id]
@ -70,6 +66,12 @@ def _encode_feedback_example(
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
# do not consider the kl_response
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len)
prompt_ids = prompt_ids[:source_len]
response_ids = response_ids[:target_len]
kl_response_ids = kl_response_ids[:target_len]
input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
kl_input_ids = prompt_ids + kl_response_ids

View File

@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
if TYPE_CHECKING:
@ -44,12 +44,8 @@ def _encode_pairwise_example(
chosen_messages = prompt + [response[0]]
rejected_messages = prompt + [response[1]]
prompt_ids, chosen_ids = template.encode_oneturn(
tokenizer, chosen_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
_, rejected_ids = template.encode_oneturn(
tokenizer, rejected_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
_, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
if template.efficient_eos:
chosen_ids += [tokenizer.eos_token_id]
@ -59,6 +55,13 @@ def _encode_pairwise_example(
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
source_len, target_len = infer_seqlen(
len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), data_args.cutoff_len
) # consider the response is more important
prompt_ids = prompt_ids[:source_len]
chosen_ids = chosen_ids[:target_len]
rejected_ids = rejected_ids[:target_len]
chosen_input_ids = prompt_ids + chosen_ids
chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids

View File

@ -13,7 +13,7 @@
# limitations under the License.
import bisect
from typing import TYPE_CHECKING, List, Sequence
from typing import TYPE_CHECKING, List, Sequence, Tuple
from ...extras.packages import is_pillow_available
@ -76,3 +76,16 @@ def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") ->
"""
image_seq_length = getattr(processor, "image_seq_length")
return [0] * image_seq_length + [1] * (input_len - image_seq_length)
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
if target_len * 2 < cutoff_len: # truncate source
max_target_len = cutoff_len
elif source_len * 2 < cutoff_len: # truncate target
max_target_len = cutoff_len - source_len
else: # truncate both
max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
new_target_len = min(max_target_len, target_len)
new_source_len = max(cutoff_len - new_target_len, 0)
return new_source_len, new_target_len

View File

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack, infer_seqlen
if TYPE_CHECKING:
@ -51,10 +51,17 @@ def _encode_supervised_example(
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
encoded_pairs = template.encode_multiturn(
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
total_length = 1 if template.efficient_eos else 0
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
if total_length >= data_args.cutoff_len:
break
source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), data_args.cutoff_len - total_length)
source_ids = source_ids[:source_len]
target_ids = target_ids[:target_len]
total_length += source_len + target_len
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:

View File

@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.logging import get_logger
from ..data_utils import Role
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
if TYPE_CHECKING:
@ -47,9 +47,7 @@ def _encode_unsupervised_example(
else:
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
input_ids, labels = template.encode_oneturn(
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
@ -57,6 +55,9 @@ def _encode_unsupervised_example(
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
source_len, target_len = infer_seqlen(len(input_ids), len(labels), data_args.cutoff_len)
input_ids = input_ids[:source_len]
labels = labels[:target_len]
return input_ids, labels

View File

@ -16,7 +16,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from ..extras.logging import get_logger
from .data_utils import Role, infer_max_len
from .data_utils import Role
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
@ -48,36 +48,33 @@ class Template:
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: int = 1_000_000,
reserved_label_len: int = 1,
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
encoded_messages = self._encode(tokenizer, messages, system, tools)
prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids += query_ids + resp_ids
prompt_ids = prompt_ids + encoded_pairs[-1][0]
answer_ids = encoded_pairs[-1][1]
for encoded_ids in encoded_messages[:-1]:
prompt_ids += encoded_ids
answer_ids = encoded_messages[-1]
return prompt_ids, answer_ids
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: int = 1_000_000,
reserved_label_len: int = 1,
) -> Sequence[Tuple[List[int], List[int]]]:
) -> List[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
encoded_messages = self._encode(tokenizer, messages, system, tools)
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
r"""
@ -88,16 +85,14 @@ class Template:
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
messages: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
cutoff_len: int,
reserved_label_len: int,
) -> Sequence[Tuple[List[int], List[int]]]:
) -> List[List[int]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: system + query resp
Turn t: sep + query resp
Turn 0: prefix + system + query resp
Turn t: sep + query resp
"""
system = system or self.default_system
encoded_messages = []
@ -106,10 +101,9 @@ class Template:
if i == 0:
elements += self.format_prefix.apply()
if i == 0 and (system or tools):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text))
if system or tools:
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text))
if i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
@ -127,11 +121,9 @@ class Template:
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
return encoded_messages
def _convert_elements_to_ids(
self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
) -> List[int]:
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
r"""
Converts elements to token ids.
"""
@ -152,60 +144,32 @@ class Template:
return token_ids
def _make_pairs(
self,
encoded_messages: Sequence[List[int]],
cutoff_len: int,
reserved_label_len: int,
) -> Sequence[Tuple[List[int], List[int]]]:
encoded_pairs = []
total_length = 0
for i in range(0, len(encoded_messages), 2):
if total_length >= cutoff_len:
break
max_source_len, max_target_len = infer_max_len(
source_len=len(encoded_messages[i]),
target_len=len(encoded_messages[i + 1]),
max_len=(cutoff_len - total_length),
reserved_label_len=reserved_label_len,
)
source_ids = encoded_messages[i][:max_source_len]
target_ids = encoded_messages[i + 1][:max_target_len]
total_length += len(source_ids) + len(target_ids)
encoded_pairs.append((source_ids, target_ids))
return encoded_pairs
@dataclass
class Llama2Template(Template):
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
messages: Sequence[Dict[str, str]],
system: str,
tools: str,
cutoff_len: int,
reserved_label_len: int,
) -> Sequence[Tuple[List[int], List[int]]]:
) -> List[List[int]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: system + query resp
Turn t: sep + query resp
Turn 0: prefix + system + query resp
Turn t: sep + query resp
"""
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
elements = []
system_text = ""
if i == 0:
elements += self.format_prefix.apply()
system_text = ""
if i == 0 and (system or tools):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
system_text = self.format_system.apply(content=(system + tool_text))[0]
if system or tools:
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
system_text = self.format_system.apply(content=(system + tool_text))[0]
if i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
@ -223,7 +187,7 @@ class Llama2Template(Template):
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
return encoded_messages
TEMPLATES: Dict[str, Template] = {}
@ -240,7 +204,7 @@ def _register_template(
format_separator: Optional["Formatter"] = None,
format_prefix: Optional["Formatter"] = None,
default_system: str = "",
stop_words: List[str] = [],
stop_words: Sequence[str] = [],
image_token: str = "<image>",
efficient_eos: bool = False,
replace_eos: bool = False,
@ -275,9 +239,7 @@ def _register_template(
template_class = Llama2Template if name.startswith("llama2") else Template
default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
default_function_formatter = FunctionFormatter(
slots=["Action: {{name}}\nAction Input: {{arguments}}\n"] + eos_slots
)
default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default")
default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter()
default_prefix_formatter = EmptyFormatter()
@ -390,7 +352,9 @@ def get_template_and_fix_tokenizer(
if tool_format is not None:
logger.info("Using tool format: {}.".format(tool_format))
eos_slots = [] if template.efficient_eos else [{"eos_token"}]
template.format_tools = ToolFormatter(tool_format=tool_format)
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
stop_words = template.stop_words
if template.replace_eos:
@ -506,10 +470,11 @@ _register_template(
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_function=FunctionFormatter(slots=[], tool_format="glm4"),
format_observation=StringFormatter(
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
),
format_tools=ToolFormatter(tool_format="glm4"),
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
@ -603,16 +568,15 @@ _register_template(
_register_template(
name="deepseekcoder",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_separator=EmptyFormatter(slots=["\n<|EOT|>\n"]),
format_assistant=StringFormatter(slots=["\n{{content}}\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
default_system=(
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
"developed by Deepseek Company, and you only answer questions related to computer science. "
"For politically sensitive questions, security and privacy issues, "
"and other non-computer science questions, you will refuse to answer\n"
),
stop_words=["<|EOT|>"],
efficient_eos=True,
)
@ -662,7 +626,7 @@ _register_template(
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_function=FunctionFormatter(slots=[], tool_format="glm4"),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
format_tools=ToolFormatter(tool_format="glm4"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),

View File

@ -0,0 +1,140 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union
from .data_utils import SLOTS
DEFAULT_TOOL_PROMPT = (
"You have access to the following tools:\n{tool_text}"
"Use the following format if using a tool:\n"
"```\n"
"Action: tool name (one of [{tool_names}]).\n"
"Action Input: the input to the tool, in a JSON format representing the kwargs "
"""(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n"""
"```\n"
)
GLM4_TOOL_PROMPT = (
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
)
@dataclass
class ToolUtils(ABC):
@staticmethod
@abstractmethod
def get_function_slots() -> SLOTS: ...
@staticmethod
@abstractmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...
@staticmethod
@abstractmethod
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
class DefaultToolUtils(ToolUtils):
@staticmethod
def get_function_slots() -> SLOTS:
return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
tool_names = []
for tool in tools:
param_text = ""
for name, param in tool["parameters"]["properties"].items():
required, enum, items = "", "", ""
if name in tool["parameters"].get("required", []):
required = ", required"
if param.get("enum", None):
enum = ", should be one of [{}]".format(", ".join(param["enum"]))
if param.get("items", None):
items = ", where each item should be {}".format(param["items"].get("type", ""))
param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
name=name,
type=param.get("type", ""),
required=required,
desc=param.get("description", ""),
enum=enum,
items=items,
)
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
name=tool["name"], desc=tool.get("description", ""), args=param_text
)
tool_names.append(tool["name"])
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
@staticmethod
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
action_match: List[Tuple[str, str]] = re.findall(regex, content)
if not action_match:
return content
results = []
for match in action_match:
tool_name = match[0].strip()
tool_input = match[1].strip().strip('"').strip("```")
try:
arguments = json.loads(tool_input)
results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
except json.JSONDecodeError:
return content
return results
class GLM4ToolUtils(ToolUtils):
@staticmethod
def get_function_slots() -> SLOTS:
return ["{{name}}\n{{arguments}}"]
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
)
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
@staticmethod
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
if "\n" not in content:
return content
tool_name, tool_input = content.split("\n", maxsplit=1)
try:
arguments = json.loads(tool_input)
except json.JSONDecodeError:
return content
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]

View File

@ -45,10 +45,6 @@ class DataArguments:
default=1024,
metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
)
reserved_label_len: int = field(
default=1,
metadata={"help": "The minimum cutoff length reserved for the tokenized labels in the dataset."},
)
train_on_prompt: bool = field(
default=False,
metadata={"help": "Whether to disable the mask on the prompt or not."},
@ -111,9 +107,6 @@ class DataArguments:
)
def __post_init__(self):
if self.reserved_label_len >= self.cutoff_len:
raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.")
if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
raise ValueError("Streaming mode should have an integer val size.")

View File

@ -28,7 +28,7 @@ def test_string_formatter():
def test_function_formatter():
formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}\n"])
formatter = FunctionFormatter(slots=[], tool_format="default")
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
assert formatter.apply(content=tool_calls) == [
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n"""
@ -36,7 +36,7 @@ def test_function_formatter():
def test_multi_function_formatter():
formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}\n"])
formatter = FunctionFormatter(slots=[], tool_format="default")
tool_calls = json.dumps([{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}] * 2)
assert formatter.apply(content=tool_calls) == [
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""",

View File

@ -0,0 +1,32 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import pytest
from llamafactory.data.processors.processor_utils import infer_seqlen
@pytest.mark.parametrize(
"test_input,test_output",
[
((3000, 2000, 1000), (600, 400)),
((2000, 3000, 1000), (400, 600)),
((1000, 100, 1000), (900, 100)),
((100, 1000, 1000), (100, 900)),
],
)
def test_infer_seqlen(test_input: Tuple[int, int, int], test_output: Tuple[int, int]):
assert test_output == infer_seqlen(*test_input)

View File

@ -21,15 +21,60 @@ from llamafactory.data import get_template_and_fix_tokenizer
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
MESSAGES = [
{"role": "user", "content": "How are you"},
{"role": "assistant", "content": "I am fine!"},
{"role": "user", "content": "你好"},
{"role": "assistant", "content": "很高兴认识你!"},
]
def test_encode_oneturn():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
assert tokenizer.decode(prompt_ids) == (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
assert tokenizer.decode(answer_ids) == "很高兴认识你!<|eot_id|>"
def test_encode_multiturn():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
assert tokenizer.decode(encoded_pairs[0][0]) == (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
assert tokenizer.decode(encoded_pairs[0][1]) == "I am fine!<|eot_id|>"
assert tokenizer.decode(encoded_pairs[1][0]) == (
"<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
assert tokenizer.decode(encoded_pairs[1][1]) == "很高兴认识你!<|eot_id|>"
def test_jinja_template():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
get_template_and_fix_tokenizer(tokenizer, name="llama3")
assert tokenizer.chat_template != ref_tokenizer.chat_template
assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
messages = [
{"role": "user", "content": "hi!"},
{"role": "assistant", "content": "hello there"},
]
assert tokenizer.apply_chat_template(messages) == ref_tokenizer.apply_chat_template(messages)
def test_qwen_template():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
template = get_template_and_fix_tokenizer(tokenizer, name="qwen")
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
assert tokenizer.decode(prompt_ids) == (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\nHow are you<|im_end|>\n"
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
"<|im_start|>user\n你好<|im_end|>\n"
"<|im_start|>assistant\n"
)
assert tokenizer.decode(answer_ids) == "很高兴认识你!<|im_end|>"