diff --git a/src/llamafactory/data/data_utils.py b/src/llamafactory/data/data_utils.py index cc9761b1..76ded47e 100644 --- a/src/llamafactory/data/data_utils.py +++ b/src/llamafactory/data/data_utils.py @@ -13,7 +13,7 @@ # limitations under the License. from enum import Enum, unique -from typing import TYPE_CHECKING, Dict, List, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Sequence, Set, Union from datasets import concatenate_datasets, interleave_datasets @@ -30,6 +30,9 @@ if TYPE_CHECKING: logger = get_logger(__name__) +SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]] + + @unique class Role(str, Enum): USER = "user" @@ -39,13 +42,6 @@ class Role(str, Enum): OBSERVATION = "observation" -def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]: - max_target_len = int(max_len * (target_len / (source_len + target_len))) - max_target_len = max(max_target_len, reserved_label_len) - max_source_len = max_len - min(max_target_len, target_len) - return max_source_len, max_target_len - - def merge_dataset( all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py index 88ebf682..c1653a76 100644 --- a/src/llamafactory/data/formatter.py +++ b/src/llamafactory/data/formatter.py @@ -16,97 +16,10 @@ import json import re from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union - -SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]] - - -DEFAULT_TOOL_PROMPT = ( - "You have access to the following tools:\n{tool_text}" - "Use the following format if using a tool:\n" - "```\n" - "Action: tool name (one of [{tool_names}]).\n" - "Action Input: the input to the tool, in a JSON format representing the kwargs " - """(e.g. 
```{{"input": "hello world", "num_beams": 5}}```).\n""" - "```\n" -) - - -GLM4_TOOL_PROMPT = ( - "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的," - "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}" -) - - -def default_tool_formatter(tools: List[Dict[str, Any]]) -> str: - tool_text = "" - tool_names = [] - for tool in tools: - param_text = "" - for name, param in tool["parameters"]["properties"].items(): - required = ", required" if name in tool["parameters"].get("required", []) else "" - enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else "" - items = ( - ", where each item should be {}".format(param["items"].get("type", "")) if param.get("items") else "" - ) - param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format( - name=name, - type=param.get("type", ""), - required=required, - desc=param.get("description", ""), - enum=enum, - items=items, - ) - - tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format( - name=tool["name"], desc=tool.get("description", ""), args=param_text - ) - tool_names.append(tool["name"]) - - return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names)) - - -def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: - regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL) - action_match: List[Tuple[str, str]] = re.findall(regex, content) - if not action_match: - return content - - results = [] - for match in action_match: - tool_name = match[0].strip() - tool_input = match[1].strip().strip('"').strip("```") - try: - arguments = json.loads(tool_input) - results.append((tool_name, json.dumps(arguments, ensure_ascii=False))) - except json.JSONDecodeError: - return content - - return results - - -def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str: - tool_text = "" - for tool in tools: - tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format( - name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False) - ) - - return GLM4_TOOL_PROMPT.format(tool_text=tool_text) - - -def glm4_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: - if "\n" not in content: - return content - - tool_name, tool_input = content.split("\n", maxsplit=1) - try: - arguments = json.loads(tool_input) - except json.JSONDecodeError: - return content - - return [(tool_name, json.dumps(arguments, ensure_ascii=False))] +from .data_utils import SLOTS +from .tool_utils import DefaultToolUtils, GLM4ToolUtils @dataclass @@ -168,15 +81,12 @@ class StringFormatter(Formatter): @dataclass class FunctionFormatter(Formatter): def __post_init__(self): - has_name, has_args = False, False - for slot in filter(lambda s: isinstance(s, str), self.slots): - if "{{name}}" in slot: - has_name = True - if "{{arguments}}" in slot: - has_args = True - - if not has_name or not has_args: - raise ValueError("Name and arguments placeholders are required in the function formatter.") + if self.tool_format == "default": + self.slots = DefaultToolUtils.get_function_slots() + self.slots + elif self.tool_format == "glm4": + self.slots = GLM4ToolUtils.get_function_slots() + self.slots + else: + raise NotImplementedError("Tool format {} was not found.".format(self.tool_format)) def apply(self, **kwargs) -> SLOTS: content = kwargs.pop("content") @@ -210,11 +120,11 @@ class FunctionFormatter(Formatter): class ToolFormatter(Formatter): def __post_init__(self): if self.tool_format == "default": 
- self._tool_formatter = default_tool_formatter - self._tool_extractor = default_tool_extractor + self._tool_formatter = DefaultToolUtils.tool_formatter + self._tool_extractor = DefaultToolUtils.tool_extractor elif self.tool_format == "glm4": - self._tool_formatter = glm4_tool_formatter - self._tool_extractor = glm4_tool_extractor + self._tool_formatter = GLM4ToolUtils.tool_formatter + self._tool_extractor = GLM4ToolUtils.tool_extractor else: raise NotImplementedError("Tool format {} was not found.".format(self.tool_format)) diff --git a/src/llamafactory/data/processors/feedback.py b/src/llamafactory/data/processors/feedback.py index 219ab353..7ba05e23 100644 --- a/src/llamafactory/data/processors/feedback.py +++ b/src/llamafactory/data/processors/feedback.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger -from .processor_utils import get_paligemma_token_type_ids, get_pixel_values +from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen if TYPE_CHECKING: @@ -55,12 +55,8 @@ def _encode_feedback_example( else: kl_messages = prompt + [kl_response[1]] - prompt_ids, response_ids = template.encode_oneturn( - tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len - ) - _, kl_response_ids = template.encode_oneturn( - tokenizer, kl_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len - ) + prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools) + _, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools) if template.efficient_eos: response_ids += [tokenizer.eos_token_id] @@ -70,6 +66,12 @@ def _encode_feedback_example( image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids + # do not consider the kl_response + source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len) + prompt_ids = prompt_ids[:source_len] + response_ids = response_ids[:target_len] + kl_response_ids = kl_response_ids[:target_len] + input_ids = prompt_ids + response_ids labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids kl_input_ids = prompt_ids + kl_response_ids diff --git a/src/llamafactory/data/processors/pairwise.py b/src/llamafactory/data/processors/pairwise.py index b2939348..c6001e6e 100644 --- a/src/llamafactory/data/processors/pairwise.py +++ b/src/llamafactory/data/processors/pairwise.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger -from .processor_utils import get_paligemma_token_type_ids, get_pixel_values +from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen if TYPE_CHECKING: @@ -44,12 +44,8 @@ def _encode_pairwise_example( chosen_messages = prompt + [response[0]] rejected_messages = prompt + [response[1]] - prompt_ids, chosen_ids = template.encode_oneturn( - tokenizer, chosen_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len - ) - _, rejected_ids = template.encode_oneturn( - tokenizer, rejected_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len - ) + prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools) + _, rejected_ids = 
template.encode_oneturn(tokenizer, rejected_messages, system, tools)
 
     if template.efficient_eos:
         chosen_ids += [tokenizer.eos_token_id]
@@ -59,6 +55,13 @@
         image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
         prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
 
+    source_len, target_len = infer_seqlen(
+        len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), data_args.cutoff_len
+    )  # consider the response more important
+    prompt_ids = prompt_ids[:source_len]
+    chosen_ids = chosen_ids[:target_len]
+    rejected_ids = rejected_ids[:target_len]
+
     chosen_input_ids = prompt_ids + chosen_ids
     chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
     rejected_input_ids = prompt_ids + rejected_ids
diff --git a/src/llamafactory/data/processors/processor_utils.py b/src/llamafactory/data/processors/processor_utils.py
index 93df0cd5..455908ae 100644
--- a/src/llamafactory/data/processors/processor_utils.py
+++ b/src/llamafactory/data/processors/processor_utils.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import bisect
-from typing import TYPE_CHECKING, List, Sequence
+from typing import TYPE_CHECKING, List, Sequence, Tuple
 
 from ...extras.packages import is_pillow_available
 
@@ -76,3 +76,20 @@ def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") ->
     """
     image_seq_length = getattr(processor, "image_seq_length")
     return [0] * image_seq_length + [1] * (input_len - image_seq_length)
+
+
+def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
+    r"""
+    Computes the real sequence length after truncation by the cutoff_len.
+    """
+    if target_len * 2 < cutoff_len:  # truncate source
+        max_target_len = cutoff_len
+    elif source_len * 2 < cutoff_len:  # truncate target
+        max_target_len = cutoff_len - source_len
+    else:  # truncate both
+        max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
+
+    new_target_len = min(max_target_len, target_len)
+    max_source_len = max(cutoff_len - new_target_len, 0)
+    new_source_len = min(max_source_len, source_len)  # do not report more than the actual source length
+    return new_source_len, new_target_len
diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py
index eb5ffb1a..b283542d 100644
--- a/src/llamafactory/data/processors/supervised.py
+++ b/src/llamafactory/data/processors/supervised.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack, infer_seqlen
 
 
 if TYPE_CHECKING:
@@ -51,10 +51,17 @@ def _encode_supervised_example(
         input_ids += [image_token_id] * getattr(processor, "image_seq_length")
         labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
 
-    encoded_pairs = template.encode_multiturn(
-        tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
-    )
+    encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
+    total_length = 1 if template.efficient_eos else 0
     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
+        if total_length >= data_args.cutoff_len:
+            break
+
+        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), data_args.cutoff_len - total_length)
+        source_ids = source_ids[:source_len]
+        target_ids = target_ids[:target_len]
+        total_length += source_len + target_len
+
         if data_args.train_on_prompt:
             source_mask = source_ids
         elif turn_idx 
!= 0 and template.efficient_eos: diff --git a/src/llamafactory/data/processors/unsupervised.py b/src/llamafactory/data/processors/unsupervised.py index 75ad4d51..b3fc85c9 100644 --- a/src/llamafactory/data/processors/unsupervised.py +++ b/src/llamafactory/data/processors/unsupervised.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from ...extras.logging import get_logger from ..data_utils import Role -from .processor_utils import get_paligemma_token_type_ids, get_pixel_values +from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen if TYPE_CHECKING: @@ -47,9 +47,7 @@ def _encode_unsupervised_example( else: messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}] - input_ids, labels = template.encode_oneturn( - tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len - ) + input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools) if template.efficient_eos: labels += [tokenizer.eos_token_id] @@ -57,6 +55,9 @@ def _encode_unsupervised_example( image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids + source_len, target_len = infer_seqlen(len(input_ids), len(labels), data_args.cutoff_len) + input_ids = input_ids[:source_len] + labels = labels[:target_len] return input_ids, labels diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 53f16df4..aefd5195 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -16,7 +16,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union from ..extras.logging import get_logger -from .data_utils import Role, infer_max_len +from .data_utils import Role from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter @@ -48,36 +48,33 @@ class Template: def encode_oneturn( self, tokenizer: "PreTrainedTokenizer", - messages: List[Dict[str, str]], + messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - cutoff_len: int = 1_000_000, - reserved_label_len: int = 1, ) -> Tuple[List[int], List[int]]: r""" Returns a single pair of token ids representing prompt and response respectively. """ - encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len) + encoded_messages = self._encode(tokenizer, messages, system, tools) prompt_ids = [] - for query_ids, resp_ids in encoded_pairs[:-1]: - prompt_ids += query_ids + resp_ids - prompt_ids = prompt_ids + encoded_pairs[-1][0] - answer_ids = encoded_pairs[-1][1] + for encoded_ids in encoded_messages[:-1]: + prompt_ids += encoded_ids + + answer_ids = encoded_messages[-1] return prompt_ids, answer_ids def encode_multiturn( self, tokenizer: "PreTrainedTokenizer", - messages: List[Dict[str, str]], + messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - cutoff_len: int = 1_000_000, - reserved_label_len: int = 1, - ) -> Sequence[Tuple[List[int], List[int]]]: + ) -> List[Tuple[List[int], List[int]]]: r""" Returns multiple pairs of token ids representing prompts and responses respectively. 
""" - return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len) + encoded_messages = self._encode(tokenizer, messages, system, tools) + return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]: r""" @@ -88,16 +85,14 @@ class Template: def _encode( self, tokenizer: "PreTrainedTokenizer", - messages: List[Dict[str, str]], + messages: Sequence[Dict[str, str]], system: Optional[str], tools: Optional[str], - cutoff_len: int, - reserved_label_len: int, - ) -> Sequence[Tuple[List[int], List[int]]]: + ) -> List[List[int]]: r""" Encodes formatted inputs to pairs of token ids. - Turn 0: system + query resp - Turn t: sep + query resp + Turn 0: prefix + system + query resp + Turn t: sep + query resp """ system = system or self.default_system encoded_messages = [] @@ -106,10 +101,9 @@ class Template: if i == 0: elements += self.format_prefix.apply() - - if i == 0 and (system or tools): - tool_text = self.format_tools.apply(content=tools)[0] if tools else "" - elements += self.format_system.apply(content=(system + tool_text)) + if system or tools: + tool_text = self.format_tools.apply(content=tools)[0] if tools else "" + elements += self.format_system.apply(content=(system + tool_text)) if i > 0 and i % 2 == 0: elements += self.format_separator.apply() @@ -127,11 +121,9 @@ class Template: encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) - return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len) + return encoded_messages - def _convert_elements_to_ids( - self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]] - ) -> List[int]: + def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]: r""" Converts elements to token ids. """ @@ -152,60 +144,32 @@ class Template: return token_ids - def _make_pairs( - self, - encoded_messages: Sequence[List[int]], - cutoff_len: int, - reserved_label_len: int, - ) -> Sequence[Tuple[List[int], List[int]]]: - encoded_pairs = [] - total_length = 0 - for i in range(0, len(encoded_messages), 2): - if total_length >= cutoff_len: - break - - max_source_len, max_target_len = infer_max_len( - source_len=len(encoded_messages[i]), - target_len=len(encoded_messages[i + 1]), - max_len=(cutoff_len - total_length), - reserved_label_len=reserved_label_len, - ) - source_ids = encoded_messages[i][:max_source_len] - target_ids = encoded_messages[i + 1][:max_target_len] - total_length += len(source_ids) + len(target_ids) - encoded_pairs.append((source_ids, target_ids)) - - return encoded_pairs - @dataclass class Llama2Template(Template): def _encode( self, tokenizer: "PreTrainedTokenizer", - messages: List[Dict[str, str]], + messages: Sequence[Dict[str, str]], system: str, tools: str, - cutoff_len: int, - reserved_label_len: int, - ) -> Sequence[Tuple[List[int], List[int]]]: + ) -> List[List[int]]: r""" Encodes formatted inputs to pairs of token ids. 
- Turn 0: system + query resp - Turn t: sep + query resp + Turn 0: prefix + system + query resp + Turn t: sep + query resp """ system = system or self.default_system encoded_messages = [] for i, message in enumerate(messages): elements = [] + system_text = "" if i == 0: elements += self.format_prefix.apply() - - system_text = "" - if i == 0 and (system or tools): - tool_text = self.format_tools.apply(content=tools)[0] if tools else "" - system_text = self.format_system.apply(content=(system + tool_text))[0] + if system or tools: + tool_text = self.format_tools.apply(content=tools)[0] if tools else "" + system_text = self.format_system.apply(content=(system + tool_text))[0] if i > 0 and i % 2 == 0: elements += self.format_separator.apply() @@ -223,7 +187,7 @@ class Llama2Template(Template): encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) - return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len) + return encoded_messages TEMPLATES: Dict[str, Template] = {} @@ -240,7 +204,7 @@ def _register_template( format_separator: Optional["Formatter"] = None, format_prefix: Optional["Formatter"] = None, default_system: str = "", - stop_words: List[str] = [], + stop_words: Sequence[str] = [], image_token: str = "", efficient_eos: bool = False, replace_eos: bool = False, @@ -275,9 +239,7 @@ def _register_template( template_class = Llama2Template if name.startswith("llama2") else Template default_user_formatter = StringFormatter(slots=["{{content}}"]) default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots) - default_function_formatter = FunctionFormatter( - slots=["Action: {{name}}\nAction Input: {{arguments}}\n"] + eos_slots - ) + default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default") default_tool_formatter = ToolFormatter(tool_format="default") default_separator_formatter = EmptyFormatter() default_prefix_formatter = EmptyFormatter() @@ -390,7 +352,9 @@ def get_template_and_fix_tokenizer( if tool_format is not None: logger.info("Using tool format: {}.".format(tool_format)) + eos_slots = [] if template.efficient_eos else [{"eos_token"}] template.format_tools = ToolFormatter(tool_format=tool_format) + template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format) stop_words = template.stop_words if template.replace_eos: @@ -506,10 +470,11 @@ _register_template( format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), format_assistant=StringFormatter(slots=["\n", "{{content}}"]), format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]), - format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), + format_function=FunctionFormatter(slots=[], tool_format="glm4"), format_observation=StringFormatter( slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] ), + format_tools=ToolFormatter(tool_format="glm4"), format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), stop_words=["<|user|>", "<|observation|>"], efficient_eos=True, @@ -603,16 +568,15 @@ _register_template( _register_template( name="deepseekcoder", format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]), - format_assistant=StringFormatter(slots=["\n", "{{content}}"]), - format_separator=EmptyFormatter(slots=["\n<|EOT|>\n"]), + format_assistant=StringFormatter(slots=["\n{{content}}\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + 
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), default_system=( "You are an AI programming assistant, utilizing the Deepseek Coder model, " "developed by Deepseek Company, and you only answer questions related to computer science. " "For politically sensitive questions, security and privacy issues, " "and other non-computer science questions, you will refuse to answer\n" ), - stop_words=["<|EOT|>"], - efficient_eos=True, ) @@ -662,7 +626,7 @@ _register_template( format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), format_assistant=StringFormatter(slots=["\n{{content}}"]), format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), - format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), + format_function=FunctionFormatter(slots=[], tool_format="glm4"), format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), format_tools=ToolFormatter(tool_format="glm4"), format_prefix=EmptyFormatter(slots=["[gMASK]"]), diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py new file mode 100644 index 00000000..ac5565d5 --- /dev/null +++ b/src/llamafactory/data/tool_utils.py @@ -0,0 +1,140 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple, Union + +from .data_utils import SLOTS + + +DEFAULT_TOOL_PROMPT = ( + "You have access to the following tools:\n{tool_text}" + "Use the following format if using a tool:\n" + "```\n" + "Action: tool name (one of [{tool_names}]).\n" + "Action Input: the input to the tool, in a JSON format representing the kwargs " + """(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n""" + "```\n" +) + + +GLM4_TOOL_PROMPT = ( + "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的," + "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}" +) + + +@dataclass +class ToolUtils(ABC): + @staticmethod + @abstractmethod + def get_function_slots() -> SLOTS: ... + + @staticmethod + @abstractmethod + def tool_formatter(tools: List[Dict[str, Any]]) -> str: ... + + @staticmethod + @abstractmethod + def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ... 
+ + +class DefaultToolUtils(ToolUtils): + @staticmethod + def get_function_slots() -> SLOTS: + return ["Action: {{name}}\nAction Input: {{arguments}}\n"] + + @staticmethod + def tool_formatter(tools: List[Dict[str, Any]]) -> str: + tool_text = "" + tool_names = [] + for tool in tools: + param_text = "" + for name, param in tool["parameters"]["properties"].items(): + required, enum, items = "", "", "" + if name in tool["parameters"].get("required", []): + required = ", required" + + if param.get("enum", None): + enum = ", should be one of [{}]".format(", ".join(param["enum"])) + + if param.get("items", None): + items = ", where each item should be {}".format(param["items"].get("type", "")) + + param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format( + name=name, + type=param.get("type", ""), + required=required, + desc=param.get("description", ""), + enum=enum, + items=items, + ) + + tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format( + name=tool["name"], desc=tool.get("description", ""), args=param_text + ) + tool_names.append(tool["name"]) + + return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names)) + + @staticmethod + def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: + regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL) + action_match: List[Tuple[str, str]] = re.findall(regex, content) + if not action_match: + return content + + results = [] + for match in action_match: + tool_name = match[0].strip() + tool_input = match[1].strip().strip('"').strip("```") + try: + arguments = json.loads(tool_input) + results.append((tool_name, json.dumps(arguments, ensure_ascii=False))) + except json.JSONDecodeError: + return content + + return results + + +class GLM4ToolUtils(ToolUtils): + @staticmethod + def get_function_slots() -> SLOTS: + return ["{{name}}\n{{arguments}}"] + + @staticmethod + def tool_formatter(tools: List[Dict[str, Any]]) -> str: + tool_text = "" + for tool in tools: + tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format( + name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False) + ) + + return GLM4_TOOL_PROMPT.format(tool_text=tool_text) + + @staticmethod + def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: + if "\n" not in content: + return content + + tool_name, tool_input = content.split("\n", maxsplit=1) + try: + arguments = json.loads(tool_input) + except json.JSONDecodeError: + return content + + return [(tool_name, json.dumps(arguments, ensure_ascii=False))] diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index dad13820..880be84a 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -45,10 +45,6 @@ class DataArguments: default=1024, metadata={"help": "The cutoff length of the tokenized inputs in the dataset."}, ) - reserved_label_len: int = field( - default=1, - metadata={"help": "The minimum cutoff length reserved for the tokenized labels in the dataset."}, - ) train_on_prompt: bool = field( default=False, metadata={"help": "Whether to disable the mask on the prompt or not."}, @@ -111,9 +107,6 @@ class DataArguments: ) def __post_init__(self): - if self.reserved_label_len >= self.cutoff_len: - raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.") - if self.streaming and self.val_size > 1e-6 and self.val_size < 1: raise ValueError("Streaming mode should 
have an integer val size.") diff --git a/tests/data/test_formatter.py b/tests/data/test_formatter.py index 37b21dc5..1845df24 100644 --- a/tests/data/test_formatter.py +++ b/tests/data/test_formatter.py @@ -28,7 +28,7 @@ def test_string_formatter(): def test_function_formatter(): - formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}\n"]) + formatter = FunctionFormatter(slots=[], tool_format="default") tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}) assert formatter.apply(content=tool_calls) == [ """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""" @@ -36,7 +36,7 @@ def test_function_formatter(): def test_multi_function_formatter(): - formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}\n"]) + formatter = FunctionFormatter(slots=[], tool_format="default") tool_calls = json.dumps([{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}] * 2) assert formatter.apply(content=tool_calls) == [ """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""", diff --git a/tests/data/test_processor.py b/tests/data/test_processor.py new file mode 100644 index 00000000..fa8f7172 --- /dev/null +++ b/tests/data/test_processor.py @@ -0,0 +1,32 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Tuple + +import pytest + +from llamafactory.data.processors.processor_utils import infer_seqlen + + +@pytest.mark.parametrize( + "test_input,test_output", + [ + ((3000, 2000, 1000), (600, 400)), + ((2000, 3000, 1000), (400, 600)), + ((1000, 100, 1000), (900, 100)), + ((100, 1000, 1000), (100, 900)), + ], +) +def test_infer_seqlen(test_input: Tuple[int, int, int], test_output: Tuple[int, int]): + assert test_output == infer_seqlen(*test_input) diff --git a/tests/data/test_template.py b/tests/data/test_template.py index 9d73c116..e4728a84 100644 --- a/tests/data/test_template.py +++ b/tests/data/test_template.py @@ -21,15 +21,60 @@ from llamafactory.data import get_template_and_fix_tokenizer TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") +MESSAGES = [ + {"role": "user", "content": "How are you"}, + {"role": "assistant", "content": "I am fine!"}, + {"role": "user", "content": "你好"}, + {"role": "assistant", "content": "很高兴认识你!"}, +] + + +def test_encode_oneturn(): + tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA) + template = get_template_and_fix_tokenizer(tokenizer, name="llama3") + prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES) + assert tokenizer.decode(prompt_ids) == ( + "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + assert tokenizer.decode(answer_ids) == "很高兴认识你!<|eot_id|>" + + +def test_encode_multiturn(): + tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA) + template = get_template_and_fix_tokenizer(tokenizer, name="llama3") + encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES) + assert tokenizer.decode(encoded_pairs[0][0]) == ( + "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + assert tokenizer.decode(encoded_pairs[0][1]) == "I am fine!<|eot_id|>" + assert tokenizer.decode(encoded_pairs[1][0]) == ( + "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + assert tokenizer.decode(encoded_pairs[1][1]) == "很高兴认识你!<|eot_id|>" + def test_jinja_template(): tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA) ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA) get_template_and_fix_tokenizer(tokenizer, name="llama3") assert tokenizer.chat_template != ref_tokenizer.chat_template + assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES) - messages = [ - {"role": "user", "content": "hi!"}, - {"role": "assistant", "content": "hello there"}, - ] - assert tokenizer.apply_chat_template(messages) == ref_tokenizer.apply_chat_template(messages) + +def test_qwen_template(): + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct") + template = get_template_and_fix_tokenizer(tokenizer, name="qwen") + prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES) + assert tokenizer.decode(prompt_ids) == ( + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + "<|im_start|>user\nHow are you<|im_end|>\n" + "<|im_start|>assistant\nI am fine!<|im_end|>\n" + "<|im_start|>user\n你好<|im_end|>\n" + "<|im_start|>assistant\n" + ) + assert tokenizer.decode(answer_ids) == "很高兴认识你!<|im_end|>"
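
--
Appendix: reviewer-added usage sketches. Illustrative only, not part of the patch.

1. Truncation budget. infer_seqlen splits cutoff_len between a prompt and its response: a response shorter than half the budget is kept whole and the prompt absorbs the cut, a short prompt leaves the remainder to the response, and when both sides are long the budget is split proportionally. The following standalone check mirrors the parametrized cases already in tests/data/test_processor.py:

    from llamafactory.data.processors.processor_utils import infer_seqlen

    # both sides long: 1000 tokens split proportionally to the 3000:2000 ratio
    assert infer_seqlen(3000, 2000, 1000) == (600, 400)
    assert infer_seqlen(2000, 3000, 1000) == (400, 600)
    # short response (100 * 2 < 1000): response kept whole, prompt truncated to the rest
    assert infer_seqlen(1000, 100, 1000) == (900, 100)
    # short prompt (100 * 2 < 1000): prompt kept whole, response truncated to the rest
    assert infer_seqlen(100, 1000, 1000) == (100, 900)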
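2. Multi-turn accounting. _encode_supervised_example now truncates turn by turn: each turn receives whatever budget remains (cutoff_len - total_length), and once the budget is spent the remaining turns are dropped instead of every pair being squeezed under a shared limit up front. A toy sketch of that loop; the per-turn token counts are made up for illustration:

    from llamafactory.data.processors.processor_utils import infer_seqlen

    cutoff_len = 1000
    turns = [(800, 800), (100, 100)]  # hypothetical (source, target) lengths per turn

    total_length, kept = 0, []
    for source_len, target_len in turns:
        if total_length >= cutoff_len:
            break  # budget exhausted: later turns are dropped entirely
        new_source, new_target = infer_seqlen(source_len, target_len, cutoff_len - total_length)
        total_length += new_source + new_target
        kept.append((new_source, new_target))

    assert kept == [(500, 500)]  # turn 1 is split evenly; turn 2 is never encoded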
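3. Tool-call round trip. The new tool_utils module pairs a formatter (JSON tool schemas rendered into prompt text) with an extractor (model reply parsed back into a list of (name, json_arguments) tuples) for each tool format. A quick round trip; the get_weather schema below is made up for illustration:

    from llamafactory.data.tool_utils import DefaultToolUtils, GLM4ToolUtils

    tools = [{
        "name": "get_weather",
        "description": "Look up the weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string", "description": "City name."}},
            "required": ["city"],
        },
    }]
    prompt = DefaultToolUtils.tool_formatter(tools)  # begins "You have access to the following tools:"

    # ReAct-style reply parsed back into (name, arguments) pairs
    reply = 'Action: get_weather\nAction Input: {"city": "Beijing"}'
    assert DefaultToolUtils.tool_extractor(reply) == [("get_weather", '{"city": "Beijing"}')]

    # GLM-4 replies carry the function name on the first line and JSON arguments below it
    assert GLM4ToolUtils.tool_extractor('get_weather\n{"city": "Beijing"}') == [("get_weather", '{"city": "Beijing"}')]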