diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index 9e60175b..22a24339 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -54,7 +54,7 @@ class HuggingfaceEngine(BaseEngine):
         self.tokenizer = tokenizer_module["tokenizer"]
         self.processor = tokenizer_module["processor"]
         self.tokenizer.padding_side = "left" if self.can_generate else "right"
-        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
         self.model = load_model(
             self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
         )  # must after fixing tokenizer to resize vocab
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index 2626d612..f0d23676 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -59,7 +59,7 @@ class VllmEngine(BaseEngine):
         self.tokenizer = tokenizer_module["tokenizer"]
         self.processor = tokenizer_module["processor"]
         self.tokenizer.padding_side = "left"
-        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
         self.generating_args = generating_args.to_dict()
 
         engine_args = {
diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index f44ef5de..8e7062db 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -148,7 +148,7 @@ def get_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"] = None,
 ) -> Union["Dataset", "IterableDataset"]:
-    template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")
 
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index b5bf688c..3a72a858 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -379,6 +379,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
 def get_template_and_fix_tokenizer(
     tokenizer: "PreTrainedTokenizer",
     name: Optional[str] = None,
+    tool_format: Optional[str] = None,
 ) -> Template:
     if name is None:
         template = TEMPLATES["empty"]  # placeholder
@@ -386,6 +387,9 @@ def get_template_and_fix_tokenizer(
         template = TEMPLATES.get(name, None)
         if template is None:
             raise ValueError("Template {} does not exist.".format(name))
+
+    if tool_format:
+        template.format_tools = ToolFormatter(tool_format=tool_format)
 
     stop_words = template.stop_words
     if template.replace_eos:
diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py
index 39290e21..959742e3 100644
--- a/src/llamafactory/hparams/data_args.py
+++ b/src/llamafactory/hparams/data_args.py
@@ -29,6 +29,10 @@ class DataArguments:
         default=None,
         metadata={"help": "Which template to use for constructing prompts in training and inference."},
     )
+    tool_format: Optional[str] = field(
+        default=None,
+        metadata={"help": "Specifies the tool format template for function calling."},
+    )
     dataset: Optional[str] = field(
         default=None,
         metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
diff --git a/tests/data/test_formatter.py b/tests/data/test_formatter.py
index 430eb0e6..a01e8a7e 100644
--- a/tests/data/test_formatter.py
+++ b/tests/data/test_formatter.py
@@ -111,9 +111,9 @@ def test_glm4_tool_formatter():
         }
     ]
     assert formatter.apply(content=json.dumps(tools)) == [
-        "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
-        "你的任务是针对用户的问题和要求提供适当的答复和支持。"
-        "\n\n## test_tool\n\n{}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
+        "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
+        "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
+        "## test_tool\n\n{}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
            json.dumps(tools[0], indent=4)
        )
    ]
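
For reviewers, a minimal usage sketch of the new `tool_format` override. This is not part of the diff; the tokenizer checkpoint is a hypothetical example, and using `"glm4"` as both the template name and the tool format is an assumption suggested by `test_glm4_tool_formatter` above.

```python
# Minimal sketch, assuming "glm4" is both a registered template name and a
# supported ToolFormatter format; the checkpoint name is illustrative only.
from transformers import AutoTokenizer

from llamafactory.data.template import get_template_and_fix_tokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)

# Without tool_format, the template keeps its built-in tool formatter.
# With it, the new branch in get_template_and_fix_tokenizer replaces
# template.format_tools with ToolFormatter(tool_format="glm4").
template = get_template_and_fix_tokenizer(tokenizer, name="glm4", tool_format="glm4")
```

The same value can also be supplied through the new `tool_format` field on `DataArguments`, which the chat engines and the dataset loader now forward to `get_template_and_fix_tokenizer`.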