Merge pull request #4417 from mMrBun/main

Add tool_format parameter to rewrite templates for different function call formats.
This commit is contained in:
hoshi-hiyouga 2024-06-24 23:17:55 +08:00 committed by GitHub
commit def6d280db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 14 additions and 6 deletions

View File

@ -54,7 +54,7 @@ class HuggingfaceEngine(BaseEngine):
self.tokenizer = tokenizer_module["tokenizer"] self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"] self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left" if self.can_generate else "right" self.tokenizer.padding_side = "left" if self.can_generate else "right"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template) self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
self.model = load_model( self.model = load_model(
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate) self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
) # must after fixing tokenizer to resize vocab ) # must after fixing tokenizer to resize vocab

View File

@ -59,7 +59,7 @@ class VllmEngine(BaseEngine):
self.tokenizer = tokenizer_module["tokenizer"] self.tokenizer = tokenizer_module["tokenizer"]
self.processor = tokenizer_module["processor"] self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left" self.tokenizer.padding_side = "left"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template) self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
self.generating_args = generating_args.to_dict() self.generating_args = generating_args.to_dict()
engine_args = { engine_args = {

View File

@ -148,7 +148,7 @@ def get_dataset(
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None, processor: Optional["ProcessorMixin"] = None,
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
template = get_template_and_fix_tokenizer(tokenizer, data_args.template) template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
if data_args.train_on_prompt and template.efficient_eos: if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.") raise ValueError("Current template does not support `train_on_prompt`.")

View File

@ -379,6 +379,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
def get_template_and_fix_tokenizer( def get_template_and_fix_tokenizer(
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
name: Optional[str] = None, name: Optional[str] = None,
tool_format: Optional[str] = None,
) -> Template: ) -> Template:
if name is None: if name is None:
template = TEMPLATES["empty"] # placeholder template = TEMPLATES["empty"] # placeholder
@ -386,6 +387,9 @@ def get_template_and_fix_tokenizer(
template = TEMPLATES.get(name, None) template = TEMPLATES.get(name, None)
if template is None: if template is None:
raise ValueError("Template {} does not exist.".format(name)) raise ValueError("Template {} does not exist.".format(name))
if tool_format:
template.format_tools = ToolFormatter(tool_format=tool_format)
stop_words = template.stop_words stop_words = template.stop_words
if template.replace_eos: if template.replace_eos:

View File

@ -29,6 +29,10 @@ class DataArguments:
default=None, default=None,
metadata={"help": "Which template to use for constructing prompts in training and inference."}, metadata={"help": "Which template to use for constructing prompts in training and inference."},
) )
tool_format: Optional[str] = field(
default=None,
metadata={"help": "Specifies the tool format template for function calling ."},
)
dataset: Optional[str] = field( dataset: Optional[str] = field(
default=None, default=None,
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}, metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},

View File

@ -111,9 +111,9 @@ def test_glm4_tool_formatter():
} }
] ]
assert formatter.apply(content=json.dumps(tools)) == [ assert formatter.apply(content=json.dumps(tools)) == [
"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的," "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
"你的任务是针对用户的问题和要求提供适当的答复和支持。" "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
"\n\n## test_tool\n\n{}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format( "## test_tool\n\n{}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
json.dumps(tools[0], indent=4) json.dumps(tools[0], indent=4)
) )
] ]