From 20e2e6fdcb0cd1771906be035745a2d9fcd3e138 Mon Sep 17 00:00:00 2001 From: mMrBun <2015711377@qq.com> Date: Sat, 22 Jun 2024 02:00:13 +0800 Subject: [PATCH 1/4] Add tool_format to overwrite tool formatter template --- src/llamafactory/chat/hf_engine.py | 2 +- src/llamafactory/chat/vllm_engine.py | 2 +- src/llamafactory/data/template.py | 5 ++++- src/llamafactory/hparams/data_args.py | 4 ++++ 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index 9e60175b..22a24339 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -54,7 +54,7 @@ class HuggingfaceEngine(BaseEngine): self.tokenizer = tokenizer_module["tokenizer"] self.processor = tokenizer_module["processor"] self.tokenizer.padding_side = "left" if self.can_generate else "right" - self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template) + self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format) self.model = load_model( self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate) ) # must after fixing tokenizer to resize vocab diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py index 2626d612..f0d23676 100644 --- a/src/llamafactory/chat/vllm_engine.py +++ b/src/llamafactory/chat/vllm_engine.py @@ -59,7 +59,7 @@ class VllmEngine(BaseEngine): self.tokenizer = tokenizer_module["tokenizer"] self.processor = tokenizer_module["processor"] self.tokenizer.padding_side = "left" - self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template) + self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format) self.generating_args = generating_args.to_dict() engine_args = { diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index b5bf688c..3d8ded3b 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -379,6 +379,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") def get_template_and_fix_tokenizer( tokenizer: "PreTrainedTokenizer", name: Optional[str] = None, + tool_format: Optional[str] = None, ) -> Template: if name is None: template = TEMPLATES["empty"] # placeholder @@ -386,6 +387,9 @@ def get_template_and_fix_tokenizer( template = TEMPLATES.get(name, None) if template is None: raise ValueError("Template {} does not exist.".format(name)) + + if tool_format: + template.format_tools = ToolFormatter(tool_format=tool_format) stop_words = template.stop_words if template.replace_eos: @@ -660,7 +664,6 @@ _register_template( format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), - format_tools=ToolFormatter(tool_format="glm4"), format_prefix=EmptyFormatter(slots=["[gMASK]"]), stop_words=["<|user|>", "<|observation|>"], efficient_eos=True, diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 39290e21..959742e3 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -29,6 +29,10 @@ class DataArguments: default=None, metadata={"help": "Which template to use for constructing prompts in training and inference."}, ) + tool_format: Optional[str] = field( + default=None, + metadata={"help": "Specifies the tool format template for function calling ."}, + ) dataset: Optional[str] = field( default=None, metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}, From dddfd516ee66e9937e21f05300832aab45034b12 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Mon, 24 Jun 2024 23:06:18 +0800 Subject: [PATCH 2/4] Update loader.py --- src/llamafactory/data/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index f44ef5de..8e7062db 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -148,7 +148,7 @@ def get_dataset( tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"] = None, ) -> Union["Dataset", "IterableDataset"]: - template = get_template_and_fix_tokenizer(tokenizer, data_args.template) + template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format) if data_args.train_on_prompt and template.efficient_eos: raise ValueError("Current template does not support `train_on_prompt`.") From 1240bd57d8a21540c636a6da839e6b3112d1395a Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Mon, 24 Jun 2024 23:12:59 +0800 Subject: [PATCH 3/4] Update template.py --- src/llamafactory/data/template.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 3d8ded3b..3a72a858 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -664,6 +664,7 @@ _register_template( format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4"), format_prefix=EmptyFormatter(slots=["[gMASK]"]), stop_words=["<|user|>", "<|observation|>"], efficient_eos=True, From 672152d2ce6b49d7668c70100d877a1c34c08eae Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Mon, 24 Jun 2024 23:14:36 +0800 Subject: [PATCH 4/4] Update test_formatter.py --- tests/data/test_formatter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/data/test_formatter.py b/tests/data/test_formatter.py index 430eb0e6..a01e8a7e 100644 --- a/tests/data/test_formatter.py +++ b/tests/data/test_formatter.py @@ -111,9 +111,9 @@ def test_glm4_tool_formatter(): } ] assert formatter.apply(content=json.dumps(tools)) == [ - "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的," - "你的任务是针对用户的问题和要求提供适当的答复和支持。" - "\n\n## test_tool\n\n{}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format( + "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的," + "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n" + "## test_tool\n\n{}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format( json.dumps(tools[0], indent=4) ) ]