diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index 9e60175b..22a24339 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -54,7 +54,7 @@ class HuggingfaceEngine(BaseEngine):
         self.tokenizer = tokenizer_module["tokenizer"]
         self.processor = tokenizer_module["processor"]
         self.tokenizer.padding_side = "left" if self.can_generate else "right"
-        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
         self.model = load_model(
             self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
         )  # must after fixing tokenizer to resize vocab
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index 2626d612..f0d23676 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -59,7 +59,7 @@ class VllmEngine(BaseEngine):
         self.tokenizer = tokenizer_module["tokenizer"]
         self.processor = tokenizer_module["processor"]
         self.tokenizer.padding_side = "left"
-        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
         self.generating_args = generating_args.to_dict()
 
         engine_args = {
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index b5bf688c..3d8ded3b 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -379,6 +379,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
 def get_template_and_fix_tokenizer(
     tokenizer: "PreTrainedTokenizer",
     name: Optional[str] = None,
+    tool_format: Optional[str] = None,
 ) -> Template:
     if name is None:
         template = TEMPLATES["empty"]  # placeholder
@@ -386,6 +387,9 @@ def get_template_and_fix_tokenizer(
         template = TEMPLATES.get(name, None)
         if template is None:
             raise ValueError("Template {} does not exist.".format(name))
+
+    if tool_format:
+        template.format_tools = ToolFormatter(tool_format=tool_format)
 
     stop_words = template.stop_words
     if template.replace_eos:
@@ -660,7 +664,6 @@ _register_template(
     format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
     format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
     format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
-    format_tools=ToolFormatter(tool_format="glm4"),
     format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
     stop_words=["<|user|>", "<|observation|>"],
     efficient_eos=True,
diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py
index 39290e21..959742e3 100644
--- a/src/llamafactory/hparams/data_args.py
+++ b/src/llamafactory/hparams/data_args.py
@@ -29,6 +29,10 @@ class DataArguments:
         default=None,
         metadata={"help": "Which template to use for constructing prompts in training and inference."},
     )
+    tool_format: Optional[str] = field(
+        default=None,
+        metadata={"help": "Specifies the tool format template for function calling."},
+    )
     dataset: Optional[str] = field(
         default=None,
         metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
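
Usage note (not part of the patch): this change moves tool-format selection out of the hard-coded template definition (see the removed format_tools line in the glm4 template above) into a user-facing tool_format argument, exposed through DataArguments and threaded into both chat engines. A minimal sketch of the resulting API follows, assuming a GLM-4 tokenizer; the checkpoint name is illustrative, not taken from the patch:

from transformers import AutoTokenizer

from llamafactory.data.template import get_template_and_fix_tokenizer

# Checkpoint name is an assumption for illustration.
tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)

# Before this patch the glm4 template always carried ToolFormatter(tool_format="glm4");
# now the formatter is attached only when tool_format is passed explicitly.
template = get_template_and_fix_tokenizer(tokenizer, name="glm4", tool_format="glm4")

Since DataArguments is parsed with HfArgumentParser, the new field should also be reachable from the command line as --tool_format glm4 (or as tool_format: glm4 in a YAML config) alongside --template glm4.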