Deprecate reserved_label_len arg
hiyouga 2024-07-01 01:19:27 +08:00
parent d4e2af1fa4
commit 1771251ce3
13 changed files with 329 additions and 223 deletions

View File

@@ -13,7 +13,7 @@
 # limitations under the License.

 from enum import Enum, unique
-from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Sequence, Set, Union

 from datasets import concatenate_datasets, interleave_datasets
@@ -30,6 +30,9 @@ if TYPE_CHECKING:

 logger = get_logger(__name__)


+SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
+
+
 @unique
 class Role(str, Enum):
     USER = "user"
@@ -39,13 +42,6 @@ class Role(str, Enum):
     OBSERVATION = "observation"


-def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
-    max_target_len = int(max_len * (target_len / (source_len + target_len)))
-    max_target_len = max(max_target_len, reserved_label_len)
-    max_source_len = max_len - min(max_target_len, target_len)
-    return max_source_len, max_target_len
-
-
 def merge_dataset(
     all_datasets: List[Union["Dataset", "IterableDataset"]],
     data_args: "DataArguments",

View File

@@ -16,97 +16,10 @@ import json
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union

+from .data_utils import SLOTS
+from .tool_utils import DefaultToolUtils, GLM4ToolUtils

-SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
-
-
-DEFAULT_TOOL_PROMPT = (
-    "You have access to the following tools:\n{tool_text}"
-    "Use the following format if using a tool:\n"
-    "```\n"
-    "Action: tool name (one of [{tool_names}]).\n"
-    "Action Input: the input to the tool, in a JSON format representing the kwargs "
-    """(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n"""
-    "```\n"
-)
-
-
-GLM4_TOOL_PROMPT = (
-    "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
-    "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
-)
-
-
-def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
-    tool_text = ""
-    tool_names = []
-    for tool in tools:
-        param_text = ""
-        for name, param in tool["parameters"]["properties"].items():
-            required = ", required" if name in tool["parameters"].get("required", []) else ""
-            enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
-            items = (
-                ", where each item should be {}".format(param["items"].get("type", "")) if param.get("items") else ""
-            )
-            param_text += "  - {name} ({type}{required}): {desc}{enum}{items}\n".format(
-                name=name,
-                type=param.get("type", ""),
-                required=required,
-                desc=param.get("description", ""),
-                enum=enum,
-                items=items,
-            )
-
-        tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
-            name=tool["name"], desc=tool.get("description", ""), args=param_text
-        )
-        tool_names.append(tool["name"])
-
-    return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
-
-
-def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
-    regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
-    action_match: List[Tuple[str, str]] = re.findall(regex, content)
-    if not action_match:
-        return content
-
-    results = []
-    for match in action_match:
-        tool_name = match[0].strip()
-        tool_input = match[1].strip().strip('"').strip("```")
-        try:
-            arguments = json.loads(tool_input)
-            results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
-        except json.JSONDecodeError:
-            return content
-
-    return results
-
-
-def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
-    tool_text = ""
-    for tool in tools:
-        tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
-            name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
-        )
-
-    return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
-
-
-def glm4_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
-    if "\n" not in content:
-        return content
-
-    tool_name, tool_input = content.split("\n", maxsplit=1)
-    try:
-        arguments = json.loads(tool_input)
-    except json.JSONDecodeError:
-        return content
-
-    return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
-
-
 @dataclass
@@ -168,15 +81,12 @@ class StringFormatter(Formatter):

 @dataclass
 class FunctionFormatter(Formatter):
     def __post_init__(self):
-        has_name, has_args = False, False
-        for slot in filter(lambda s: isinstance(s, str), self.slots):
-            if "{{name}}" in slot:
-                has_name = True
-            if "{{arguments}}" in slot:
-                has_args = True
-
-        if not has_name or not has_args:
-            raise ValueError("Name and arguments placeholders are required in the function formatter.")
+        if self.tool_format == "default":
+            self.slots = DefaultToolUtils.get_function_slots() + self.slots
+        elif self.tool_format == "glm4":
+            self.slots = GLM4ToolUtils.get_function_slots() + self.slots
+        else:
+            raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))

     def apply(self, **kwargs) -> SLOTS:
         content = kwargs.pop("content")
@@ -210,11 +120,11 @@ class FunctionFormatter(Formatter):
 class ToolFormatter(Formatter):
     def __post_init__(self):
         if self.tool_format == "default":
-            self._tool_formatter = default_tool_formatter
-            self._tool_extractor = default_tool_extractor
+            self._tool_formatter = DefaultToolUtils.tool_formatter
+            self._tool_extractor = DefaultToolUtils.tool_extractor
         elif self.tool_format == "glm4":
-            self._tool_formatter = glm4_tool_formatter
-            self._tool_extractor = glm4_tool_extractor
+            self._tool_formatter = GLM4ToolUtils.tool_formatter
+            self._tool_extractor = GLM4ToolUtils.tool_extractor
         else:
             raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
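
Note: a FunctionFormatter now derives its function-call slots from the selected tool_format instead of requiring explicit {{name}}/{{arguments}} placeholders. A minimal usage sketch, matching the updated tests further down:

    import json

    from llamafactory.data.formatter import FunctionFormatter

    # tool_format="default" prepends the slot "Action: {{name}}\nAction Input: {{arguments}}\n"
    formatter = FunctionFormatter(slots=[], tool_format="default")
    tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
    print(formatter.apply(content=tool_calls))
    # ['Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n']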

View File

@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen


 if TYPE_CHECKING:
@@ -55,12 +55,8 @@ def _encode_feedback_example(
     else:
         kl_messages = prompt + [kl_response[1]]

-    prompt_ids, response_ids = template.encode_oneturn(
-        tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
-    _, kl_response_ids = template.encode_oneturn(
-        tokenizer, kl_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
+    prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
+    _, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)

     if template.efficient_eos:
         response_ids += [tokenizer.eos_token_id]
@@ -70,6 +66,12 @@ def _encode_feedback_example(
         image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
         prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids

+    # do not consider the kl_response
+    source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len)
+    prompt_ids = prompt_ids[:source_len]
+    response_ids = response_ids[:target_len]
+    kl_response_ids = kl_response_ids[:target_len]
+
     input_ids = prompt_ids + response_ids
     labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
     kl_input_ids = prompt_ids + kl_response_ids

View File

@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen


 if TYPE_CHECKING:
@@ -44,12 +44,8 @@ def _encode_pairwise_example(
     chosen_messages = prompt + [response[0]]
     rejected_messages = prompt + [response[1]]

-    prompt_ids, chosen_ids = template.encode_oneturn(
-        tokenizer, chosen_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
-    _, rejected_ids = template.encode_oneturn(
-        tokenizer, rejected_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
+    prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
+    _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)

     if template.efficient_eos:
         chosen_ids += [tokenizer.eos_token_id]
@@ -59,6 +55,13 @@ def _encode_pairwise_example(
         image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
         prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids

+    source_len, target_len = infer_seqlen(
+        len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), data_args.cutoff_len
+    )  # consider the response is more important
+    prompt_ids = prompt_ids[:source_len]
+    chosen_ids = chosen_ids[:target_len]
+    rejected_ids = rejected_ids[:target_len]
+
     chosen_input_ids = prompt_ids + chosen_ids
     chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
     rejected_input_ids = prompt_ids + rejected_ids
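
Note: the chosen and rejected responses now share one target budget, computed from the longer of the two so the pair stays aligned. A small sketch with made-up token counts (infer_seqlen is defined in processor_utils below):

    from llamafactory.data.processors.processor_utils import infer_seqlen

    prompt_len, chosen_len, rejected_len, cutoff_len = 80, 60, 30, 100  # hypothetical lengths
    source_len, target_len = infer_seqlen(prompt_len, max(chosen_len, rejected_len), cutoff_len)
    # source_len == 58, target_len == 42: the prompt is cut to 58 tokens and both
    # responses get the same [:42] cap, keeping the pair comparable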

View File

@@ -13,7 +13,7 @@
 # limitations under the License.

 import bisect
-from typing import TYPE_CHECKING, List, Sequence
+from typing import TYPE_CHECKING, List, Sequence, Tuple

 from ...extras.packages import is_pillow_available
@@ -76,3 +76,16 @@ def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") ->
     """
     image_seq_length = getattr(processor, "image_seq_length")
     return [0] * image_seq_length + [1] * (input_len - image_seq_length)
+
+
+def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
+    if target_len * 2 < cutoff_len:  # truncate source
+        max_target_len = cutoff_len
+    elif source_len * 2 < cutoff_len:  # truncate target
+        max_target_len = cutoff_len - source_len
+    else:  # truncate both
+        max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
+
+    new_target_len = min(max_target_len, target_len)
+    new_source_len = max(cutoff_len - new_target_len, 0)
+    return new_source_len, new_target_len
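
Note: infer_seqlen replaces the old infer_max_len: when either side needs less than half of the budget, only the other side is truncated; otherwise the budget is split proportionally. The four cases from the new test file, for reference:

    from llamafactory.data.processors.processor_utils import infer_seqlen

    assert infer_seqlen(1000, 100, 1000) == (900, 100)   # short target: truncate the source only
    assert infer_seqlen(100, 1000, 1000) == (100, 900)   # short source: truncate the target only
    assert infer_seqlen(3000, 2000, 1000) == (600, 400)  # both long: split 3:2, matching the input ratio
    assert infer_seqlen(2000, 3000, 1000) == (400, 600)  # both long: split 2:3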

View File

@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack, infer_seqlen


 if TYPE_CHECKING:
@@ -51,10 +51,17 @@ def _encode_supervised_example(
         input_ids += [image_token_id] * getattr(processor, "image_seq_length")
         labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")

-    encoded_pairs = template.encode_multiturn(
-        tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
+    encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
+    total_length = 1 if template.efficient_eos else 0
     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
+        if total_length >= data_args.cutoff_len:
+            break
+
+        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), data_args.cutoff_len - total_length)
+        source_ids = source_ids[:source_len]
+        target_ids = target_ids[:target_len]
+        total_length += source_len + target_len
+
         if data_args.train_on_prompt:
             source_mask = source_ids
         elif turn_idx != 0 and template.efficient_eos:
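
Note: truncation now happens in the processor rather than inside the template: each turn draws on whatever budget earlier turns have left, and once it is spent the remaining turns are dropped. A toy walk-through with made-up per-turn lengths:

    from llamafactory.data.processors.processor_utils import infer_seqlen

    cutoff_len, total_length = 100, 0
    encoded_pairs = [([0] * 80, [1] * 80), ([0] * 80, [1] * 80)]  # hypothetical (source, target) turns
    for source_ids, target_ids in encoded_pairs:
        if total_length >= cutoff_len:
            break  # budget exhausted: the second turn is dropped entirely

        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
        total_length += source_len + target_len
        print(source_len, target_len)  # turn 0 is squeezed to (50, 50)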

View File

@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

 from ...extras.logging import get_logger
 from ..data_utils import Role
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen


 if TYPE_CHECKING:
@@ -47,9 +47,7 @@ def _encode_unsupervised_example(
     else:
         messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]

-    input_ids, labels = template.encode_oneturn(
-        tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
+    input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)

     if template.efficient_eos:
         labels += [tokenizer.eos_token_id]
@@ -57,6 +55,9 @@ def _encode_unsupervised_example(
         image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
         input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids

+    source_len, target_len = infer_seqlen(len(input_ids), len(labels), data_args.cutoff_len)
+    input_ids = input_ids[:source_len]
+    labels = labels[:target_len]
+
     return input_ids, labels

View File

@@ -16,7 +16,7 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union

 from ..extras.logging import get_logger
-from .data_utils import Role, infer_max_len
+from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
@@ -48,36 +48,33 @@ class Template:
     def encode_oneturn(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: List[Dict[str, str]],
+        messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        cutoff_len: int = 1_000_000,
-        reserved_label_len: int = 1,
     ) -> Tuple[List[int], List[int]]:
         r"""
         Returns a single pair of token ids representing prompt and response respectively.
         """
-        encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
         prompt_ids = []
-        for query_ids, resp_ids in encoded_pairs[:-1]:
-            prompt_ids += query_ids + resp_ids
-
-        prompt_ids = prompt_ids + encoded_pairs[-1][0]
-        answer_ids = encoded_pairs[-1][1]
+        for encoded_ids in encoded_messages[:-1]:
+            prompt_ids += encoded_ids
+
+        answer_ids = encoded_messages[-1]
         return prompt_ids, answer_ids

     def encode_multiturn(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: List[Dict[str, str]],
+        messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        cutoff_len: int = 1_000_000,
-        reserved_label_len: int = 1,
-    ) -> Sequence[Tuple[List[int], List[int]]]:
+    ) -> List[Tuple[List[int], List[int]]]:
         r"""
         Returns multiple pairs of token ids representing prompts and responses respectively.
         """
-        return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]

     def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
         r"""
@@ -88,16 +85,14 @@ class Template:
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: List[Dict[str, str]],
+        messages: Sequence[Dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        cutoff_len: int,
-        reserved_label_len: int,
-    ) -> Sequence[Tuple[List[int], List[int]]]:
+    ) -> List[List[int]]:
         r"""
         Encodes formatted inputs to pairs of token ids.
-        Turn 0: system + query        resp
+        Turn 0: prefix + system + query        resp
         Turn t: sep + query           resp
         """
         system = system or self.default_system
         encoded_messages = []
@@ -106,10 +101,9 @@ class Template:

             if i == 0:
                 elements += self.format_prefix.apply()
-
-            if i == 0 and (system or tools):
-                tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
-                elements += self.format_system.apply(content=(system + tool_text))
+                if system or tools:
+                    tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
+                    elements += self.format_system.apply(content=(system + tool_text))

             if i > 0 and i % 2 == 0:
                 elements += self.format_separator.apply()
@@ -127,11 +121,9 @@ class Template:

             encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))

-        return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
+        return encoded_messages

-    def _convert_elements_to_ids(
-        self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
-    ) -> List[int]:
+    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
         r"""
         Converts elements to token ids.
         """
@@ -152,60 +144,32 @@ class Template:

         return token_ids

-    def _make_pairs(
-        self,
-        encoded_messages: Sequence[List[int]],
-        cutoff_len: int,
-        reserved_label_len: int,
-    ) -> Sequence[Tuple[List[int], List[int]]]:
-        encoded_pairs = []
-        total_length = 0
-        for i in range(0, len(encoded_messages), 2):
-            if total_length >= cutoff_len:
-                break
-
-            max_source_len, max_target_len = infer_max_len(
-                source_len=len(encoded_messages[i]),
-                target_len=len(encoded_messages[i + 1]),
-                max_len=(cutoff_len - total_length),
-                reserved_label_len=reserved_label_len,
-            )
-            source_ids = encoded_messages[i][:max_source_len]
-            target_ids = encoded_messages[i + 1][:max_target_len]
-            total_length += len(source_ids) + len(target_ids)
-            encoded_pairs.append((source_ids, target_ids))
-
-        return encoded_pairs
-

 @dataclass
 class Llama2Template(Template):
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: List[Dict[str, str]],
+        messages: Sequence[Dict[str, str]],
         system: str,
         tools: str,
-        cutoff_len: int,
-        reserved_label_len: int,
-    ) -> Sequence[Tuple[List[int], List[int]]]:
+    ) -> List[List[int]]:
         r"""
         Encodes formatted inputs to pairs of token ids.
-        Turn 0: system + query        resp
+        Turn 0: prefix + system + query        resp
         Turn t: sep + query           resp
         """
         system = system or self.default_system
         encoded_messages = []
         for i, message in enumerate(messages):
             elements = []

+            system_text = ""
             if i == 0:
                 elements += self.format_prefix.apply()
-
-            system_text = ""
-            if i == 0 and (system or tools):
-                tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
-                system_text = self.format_system.apply(content=(system + tool_text))[0]
+                if system or tools:
+                    tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
+                    system_text = self.format_system.apply(content=(system + tool_text))[0]

             if i > 0 and i % 2 == 0:
                 elements += self.format_separator.apply()
@@ -223,7 +187,7 @@ class Llama2Template(Template):

             encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))

-        return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
+        return encoded_messages


 TEMPLATES: Dict[str, Template] = {}
@@ -240,7 +204,7 @@ def _register_template(
     format_separator: Optional["Formatter"] = None,
     format_prefix: Optional["Formatter"] = None,
     default_system: str = "",
-    stop_words: List[str] = [],
+    stop_words: Sequence[str] = [],
     image_token: str = "<image>",
     efficient_eos: bool = False,
     replace_eos: bool = False,
@@ -275,9 +239,7 @@ def _register_template(
     template_class = Llama2Template if name.startswith("llama2") else Template
     default_user_formatter = StringFormatter(slots=["{{content}}"])
     default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
-    default_function_formatter = FunctionFormatter(
-        slots=["Action: {{name}}\nAction Input: {{arguments}}\n"] + eos_slots
-    )
+    default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default")
     default_tool_formatter = ToolFormatter(tool_format="default")
     default_separator_formatter = EmptyFormatter()
     default_prefix_formatter = EmptyFormatter()
@@ -390,7 +352,9 @@ def get_template_and_fix_tokenizer(
     if tool_format is not None:
         logger.info("Using tool format: {}.".format(tool_format))
+        eos_slots = [] if template.efficient_eos else [{"eos_token"}]
         template.format_tools = ToolFormatter(tool_format=tool_format)
+        template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)

     stop_words = template.stop_words
     if template.replace_eos:
@@ -506,10 +470,11 @@
     format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
     format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
     format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
-    format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
+    format_function=FunctionFormatter(slots=[], tool_format="glm4"),
     format_observation=StringFormatter(
         slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
     ),
+    format_tools=ToolFormatter(tool_format="glm4"),
     format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
     stop_words=["<|user|>", "<|observation|>"],
     efficient_eos=True,
@@ -603,16 +568,15 @@

 _register_template(
     name="deepseekcoder",
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
-    format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
-    format_separator=EmptyFormatter(slots=["\n<|EOT|>\n"]),
+    format_assistant=StringFormatter(slots=["\n{{content}}\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     default_system=(
         "You are an AI programming assistant, utilizing the Deepseek Coder model, "
         "developed by Deepseek Company, and you only answer questions related to computer science. "
         "For politically sensitive questions, security and privacy issues, "
         "and other non-computer science questions, you will refuse to answer\n"
     ),
-    stop_words=["<|EOT|>"],
-    efficient_eos=True,
 )
@@ -662,7 +626,7 @@
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
     format_assistant=StringFormatter(slots=["\n{{content}}"]),
     format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
-    format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
+    format_function=FunctionFormatter(slots=[], tool_format="glm4"),
     format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
     format_tools=ToolFormatter(tool_format="glm4"),
     format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),

View File

@@ -0,0 +1,140 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import re
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Dict, List, Tuple, Union
+
+from .data_utils import SLOTS
+
+
+DEFAULT_TOOL_PROMPT = (
+    "You have access to the following tools:\n{tool_text}"
+    "Use the following format if using a tool:\n"
+    "```\n"
+    "Action: tool name (one of [{tool_names}]).\n"
+    "Action Input: the input to the tool, in a JSON format representing the kwargs "
+    """(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n"""
+    "```\n"
+)
+
+
+GLM4_TOOL_PROMPT = (
+    "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
+    "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
+)
+
+
+@dataclass
+class ToolUtils(ABC):
+    @staticmethod
+    @abstractmethod
+    def get_function_slots() -> SLOTS: ...
+
+    @staticmethod
+    @abstractmethod
+    def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...
+
+    @staticmethod
+    @abstractmethod
+    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
+
+
+class DefaultToolUtils(ToolUtils):
+    @staticmethod
+    def get_function_slots() -> SLOTS:
+        return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
+
+    @staticmethod
+    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+        tool_text = ""
+        tool_names = []
+        for tool in tools:
+            param_text = ""
+            for name, param in tool["parameters"]["properties"].items():
+                required, enum, items = "", "", ""
+                if name in tool["parameters"].get("required", []):
+                    required = ", required"
+
+                if param.get("enum", None):
+                    enum = ", should be one of [{}]".format(", ".join(param["enum"]))
+
+                if param.get("items", None):
+                    items = ", where each item should be {}".format(param["items"].get("type", ""))
+
+                param_text += "  - {name} ({type}{required}): {desc}{enum}{items}\n".format(
+                    name=name,
+                    type=param.get("type", ""),
+                    required=required,
+                    desc=param.get("description", ""),
+                    enum=enum,
+                    items=items,
+                )
+
+            tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
+                name=tool["name"], desc=tool.get("description", ""), args=param_text
+            )
+            tool_names.append(tool["name"])
+
+        return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
+
+    @staticmethod
+    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+        regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
+        action_match: List[Tuple[str, str]] = re.findall(regex, content)
+        if not action_match:
+            return content
+
+        results = []
+        for match in action_match:
+            tool_name = match[0].strip()
+            tool_input = match[1].strip().strip('"').strip("```")
+            try:
+                arguments = json.loads(tool_input)
+                results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
+            except json.JSONDecodeError:
+                return content
+
+        return results
+
+
+class GLM4ToolUtils(ToolUtils):
+    @staticmethod
+    def get_function_slots() -> SLOTS:
+        return ["{{name}}\n{{arguments}}"]
+
+    @staticmethod
+    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+        tool_text = ""
+        for tool in tools:
+            tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
+                name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
+            )
+
+        return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
+
+    @staticmethod
+    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+        if "\n" not in content:
+            return content
+
+        tool_name, tool_input = content.split("\n", maxsplit=1)
+        try:
+            arguments = json.loads(tool_input)
+        except json.JSONDecodeError:
+            return content
+
+        return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
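
Note: a round trip through the default utils, using a hypothetical tool schema (the module path assumes the new file lives next to data_utils.py, i.e. llamafactory.data.tool_utils):

    import json

    from llamafactory.data.tool_utils import DefaultToolUtils

    tools = [  # hypothetical schema in the OpenAI function-calling style
        {
            "name": "get_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string", "description": "City name"}},
                "required": ["location"],
            },
        }
    ]
    system_prompt = DefaultToolUtils.tool_formatter(tools)  # rendered into the system message

    # parse an "Action:/Action Input:" style model response back into (name, json_args) pairs
    content = 'Action: get_weather\nAction Input: {"location": "Paris"}'
    assert DefaultToolUtils.tool_extractor(content) == [("get_weather", '{"location": "Paris"}')]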

View File

@@ -45,10 +45,6 @@ class DataArguments:
         default=1024,
         metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
     )
-    reserved_label_len: int = field(
-        default=1,
-        metadata={"help": "The minimum cutoff length reserved for the tokenized labels in the dataset."},
-    )
     train_on_prompt: bool = field(
         default=False,
         metadata={"help": "Whether to disable the mask on the prompt or not."},
@@ -111,9 +107,6 @@ class DataArguments:
     )

     def __post_init__(self):
-        if self.reserved_label_len >= self.cutoff_len:
-            raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.")
-
         if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
             raise ValueError("Streaming mode should have an integer val size.")

View File

@@ -28,7 +28,7 @@ def test_string_formatter():


 def test_function_formatter():
-    formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}\n"])
+    formatter = FunctionFormatter(slots=[], tool_format="default")
     tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
     assert formatter.apply(content=tool_calls) == [
         """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n"""
@@ -36,7 +36,7 @@ def test_function_formatter():


 def test_multi_function_formatter():
-    formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}\n"])
+    formatter = FunctionFormatter(slots=[], tool_format="default")
     tool_calls = json.dumps([{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}] * 2)
     assert formatter.apply(content=tool_calls) == [
         """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",

View File

@@ -0,0 +1,32 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple
+
+import pytest
+
+from llamafactory.data.processors.processor_utils import infer_seqlen
+
+
+@pytest.mark.parametrize(
+    "test_input,test_output",
+    [
+        ((3000, 2000, 1000), (600, 400)),
+        ((2000, 3000, 1000), (400, 600)),
+        ((1000, 100, 1000), (900, 100)),
+        ((100, 1000, 1000), (100, 900)),
+    ],
+)
+def test_infer_seqlen(test_input: Tuple[int, int, int], test_output: Tuple[int, int]):
+    assert test_output == infer_seqlen(*test_input)

View File

@@ -21,15 +21,60 @@ from llamafactory.data import get_template_and_fix_tokenizer

 TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")

+MESSAGES = [
+    {"role": "user", "content": "How are you"},
+    {"role": "assistant", "content": "I am fine!"},
+    {"role": "user", "content": "你好"},
+    {"role": "assistant", "content": "很高兴认识你!"},
+]
+
+
+def test_encode_oneturn():
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+    template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
+    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
+    assert tokenizer.decode(prompt_ids) == (
+        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
+        "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    )
+    assert tokenizer.decode(answer_ids) == "很高兴认识你!<|eot_id|>"
+
+
+def test_encode_multiturn():
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+    template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
+    encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
+    assert tokenizer.decode(encoded_pairs[0][0]) == (
+        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    )
+    assert tokenizer.decode(encoded_pairs[0][1]) == "I am fine!<|eot_id|>"
+    assert tokenizer.decode(encoded_pairs[1][0]) == (
+        "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    )
+    assert tokenizer.decode(encoded_pairs[1][1]) == "很高兴认识你!<|eot_id|>"
+
+
 def test_jinja_template():
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
     ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
     get_template_and_fix_tokenizer(tokenizer, name="llama3")
     assert tokenizer.chat_template != ref_tokenizer.chat_template
-
-    messages = [
-        {"role": "user", "content": "hi!"},
-        {"role": "assistant", "content": "hello there"},
-    ]
-    assert tokenizer.apply_chat_template(messages) == ref_tokenizer.apply_chat_template(messages)
+    assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
+
+
+def test_qwen_template():
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
+    template = get_template_and_fix_tokenizer(tokenizer, name="qwen")
+    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
+    assert tokenizer.decode(prompt_ids) == (
+        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+        "<|im_start|>user\nHow are you<|im_end|>\n"
+        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
+        "<|im_start|>user\n你好<|im_end|>\n"
+        "<|im_start|>assistant\n"
+    )
+    assert tokenizer.decode(answer_ids) == "很高兴认识你!<|im_end|>"