From cb1cbcb293917e960cad8f0eac7a11a122ab644a Mon Sep 17 00:00:00 2001 From: mMrBun <2015711377@qq.com> Date: Sun, 9 Jun 2024 18:16:15 +0800 Subject: [PATCH 1/3] Implemented the tool_formatter and tool_extractor for glm4 tool_format --- src/llamafactory/data/formatter.py | 42 +++++++++++++++++++++++++++++- src/llamafactory/data/template.py | 3 ++- 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py index 0cd3d6c1..344e01db 100644 --- a/src/llamafactory/data/formatter.py +++ b/src/llamafactory/data/formatter.py @@ -23,6 +23,17 @@ TOOL_SYSTEM_PROMPT = ( ) +GLM4_TOOL_SUFFIX_PROMPT = ( + "在调用上述函数时,请使用 Json 格式表示调用的参数。" +) + +GLM4_TOOL_PROMPT = ( + "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持," + "{tool_text}" + +) + + def default_tool_formatter(tools: List[Dict[str, Any]]) -> str: tool_text = "" tool_names = [] @@ -53,6 +64,14 @@ def default_tool_formatter(tools: List[Dict[str, Any]]) -> str: ) +def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str: + tool_text = "" + for tool in tools: + tool_name = tool["name"] + tool_text += f"\n\n## {tool_name}\n\n{json.dumps(tool, ensure_ascii=False, indent=4)}\n{GLM4_TOOL_SUFFIX_PROMPT}" + return GLM4_TOOL_PROMPT.format(tool_text=tool_text) + + def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]: regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL) action_match = re.search(regex, content) @@ -69,10 +88,24 @@ def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]: return tool_name, json.dumps(arguments, ensure_ascii=False) +def glm4_tool_extractor(content: str) -> Union[str, Tuple[str, str]]: + lines = content.strip().split("\n") + if len(lines) != 2: + return content + tool_name = lines[0].strip() + tool_input = lines[1].strip() + try: + arguments = json.loads(tool_input) + except json.JSONDecodeError: + return content + return tool_name, json.dumps(arguments, ensure_ascii=False) + + + @dataclass class Formatter(ABC): slots: SLOTS = field(default_factory=list) - tool_format: Optional[Literal["default"]] = None + tool_format: Optional[Literal["default", "glm4"]] = None @abstractmethod def apply(self, **kwargs) -> SLOTS: ... @@ -175,6 +208,11 @@ class ToolFormatter(Formatter): if self.tool_format == "default": return [default_tool_formatter(tools)] + elif self.tool_format == "glm4": + """ + '[gMASK]<|system|>\n你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n## get_current_weather\n\n{\n "name": "get_current_weather",\n "description": "Get the current weather",\n "parameters": {\n "type": "object",\n "properties": {\n "location": {\n "type": "string",\n "description": "The city and state, e.g. San Francisco, CA"\n },\n "format": {\n "type": "string",\n "enum": [\n "celsius",\n "fahrenheit"\n ],\n "description": "The temperature unit to use. Infer this from the users location."\n }\n },\n "required": [\n "location",\n "format"\n ]\n }\n}\n在调用上述函数时,请使用 Json 格式表示调用的参数。<|user|>\nWhat\'s the weather like in San Francisco, Tokyo, and Paris? 
use celsius<|assistant|>' + """ + return [glm4_tool_formatter(tools)] else: raise NotImplementedError except Exception: @@ -183,5 +221,7 @@ class ToolFormatter(Formatter): def extract(self, content: str) -> Union[str, Tuple[str, str]]: if self.tool_format == "default": return default_tool_extractor(content) + elif self.tool_format == "glm4": + return glm4_tool_extractor(content) else: raise NotImplementedError diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 3dce5ec6..b2aea217 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -662,9 +662,10 @@ _register_template( name="glm4", format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), format_assistant=StringFormatter(slots=["\n{{content}}"]), - format_system=StringFormatter(slots=["[gMASK]{{content}}"]), + format_system=StringFormatter(slots=["[gMASK]<|system|>\n{{content}}"]), format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4"), stop_words=["<|user|>", "<|observation|>"], efficient_eos=True, force_system=True, From 6ed0b0c800d416379acf8395aa852c188c107eb9 Mon Sep 17 00:00:00 2001 From: mMrBun <2015711377@qq.com> Date: Sun, 9 Jun 2024 18:25:22 +0800 Subject: [PATCH 2/3] Removed unnecessary comments. --- src/llamafactory/data/formatter.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py index 344e01db..9f58915b 100644 --- a/src/llamafactory/data/formatter.py +++ b/src/llamafactory/data/formatter.py @@ -209,9 +209,6 @@ class ToolFormatter(Formatter): if self.tool_format == "default": return [default_tool_formatter(tools)] elif self.tool_format == "glm4": - """ - '[gMASK]<|system|>\n你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n## get_current_weather\n\n{\n "name": "get_current_weather",\n "description": "Get the current weather",\n "parameters": {\n "type": "object",\n "properties": {\n "location": {\n "type": "string",\n "description": "The city and state, e.g. San Francisco, CA"\n },\n "format": {\n "type": "string",\n "enum": [\n "celsius",\n "fahrenheit"\n ],\n "description": "The temperature unit to use. Infer this from the users location."\n }\n },\n "required": [\n "location",\n "format"\n ]\n }\n}\n在调用上述函数时,请使用 Json 格式表示调用的参数。<|user|>\nWhat\'s the weather like in San Francisco, Tokyo, and Paris? use celsius<|assistant|>' - """ return [glm4_tool_formatter(tools)] else: raise NotImplementedError From 950e360ca00c29febadc14d5995de7d57b5c43a7 Mon Sep 17 00:00:00 2001 From: mMrBun <2015711377@qq.com> Date: Mon, 10 Jun 2024 02:00:14 +0800 Subject: [PATCH 3/3] Optimize the handling of QWEN2 in scenarios involving multiple tool calls. 
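After this change default_tool_extractor returns a list of (name, arguments) pairs instead of a single tuple, so a completion that chains several "Action:" / "Action Input:" blocks yields one entry per block, and create_chat_completion_response turns each pair into its own FunctionCall. A minimal sketch of the intended behaviour, assuming llamafactory is installed from this branch; the sample completion text is made up for illustration:

    from llamafactory.data.formatter import default_tool_extractor

    # Illustrative model reply that requests two tool calls back to back.
    content = (
        'Action: get_current_weather\n'
        'Action Input: {"location": "San Francisco, CA", "format": "celsius"}\n'
        'Action: get_current_weather\n'
        'Action Input: {"location": "Tokyo", "format": "celsius"}'
    )

    result = default_tool_extractor(content)
    # Plain text or malformed JSON still falls through unchanged as a string;
    # otherwise each (name, arguments) pair becomes a separate tool call in the API layer.
    if isinstance(result, list):
        for name, arguments in result:
            print(name, arguments)
    # get_current_weather {"location": "San Francisco, CA", "format": "celsius"}
    # get_current_weather {"location": "Tokyo", "format": "celsius"}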
--- src/llamafactory/api/chat.py | 13 +++++++----- src/llamafactory/data/formatter.py | 34 ++++++++++++++++++------------ 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py index 98957bc1..d4db1eea 100644 --- a/src/llamafactory/api/chat.py +++ b/src/llamafactory/api/chat.py @@ -150,11 +150,14 @@ async def create_chat_completion_response( else: result = response.response_text - if isinstance(result, tuple): - name, arguments = result - function = Function(name=name, arguments=arguments) - tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function) - response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[tool_call]) + if isinstance(result, list): + tool_calls = [] + for tool in result: + name, arguments = tool + function = Function(name=name, arguments=arguments) + tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function) + tool_calls.append(tool_call) + response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls) finish_reason = Finish.TOOL else: response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result) diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py index 9f58915b..1d917887 100644 --- a/src/llamafactory/data/formatter.py +++ b/src/llamafactory/data/formatter.py @@ -72,23 +72,29 @@ def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str: return GLM4_TOOL_PROMPT.format(tool_text=tool_text) -def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]: - regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL) - action_match = re.search(regex, content) +def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: + regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*({.*?})(?=\nAction:|\Z)", re.DOTALL) + action_match = re.findall(regex, content) if not action_match: return content - tool_name = action_match.group(1).strip() - tool_input = action_match.group(2).strip().strip('"').strip("```") - try: - arguments = json.loads(tool_input) - except json.JSONDecodeError: - return content + results = [] + + for match in action_match: + tool_name, tool_input = match + tool_name = tool_name.strip() + tool_input = tool_input.strip().strip('"').strip("```") - return tool_name, json.dumps(arguments, ensure_ascii=False) + try: + arguments = json.loads(tool_input) + results.append((tool_name, json.dumps(arguments, ensure_ascii=False))) + except json.JSONDecodeError: + return content + + return results -def glm4_tool_extractor(content: str) -> Union[str, Tuple[str, str]]: +def glm4_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: lines = content.strip().split("\n") if len(lines) != 2: return content @@ -98,7 +104,7 @@ def glm4_tool_extractor(content: str) -> Union[str, Tuple[str, str]]: arguments = json.loads(tool_input) except json.JSONDecodeError: return content - return tool_name, json.dumps(arguments, ensure_ascii=False) + return [(tool_name, json.dumps(arguments, ensure_ascii=False))] @@ -110,7 +116,7 @@ class Formatter(ABC): @abstractmethod def apply(self, **kwargs) -> SLOTS: ... 
- def extract(self, content: str) -> Union[str, Tuple[str, str]]: + def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]: raise NotImplementedError @@ -215,7 +221,7 @@ class ToolFormatter(Formatter): except Exception: return [""] - def extract(self, content: str) -> Union[str, Tuple[str, str]]: + def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]: if self.tool_format == "default": return default_tool_extractor(content) elif self.tool_format == "glm4":
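As a usage sketch for the glm4 tool format added in PATCH 1/3: glm4_tool_formatter renders the GLM-4 system prompt with one "## <name>" section per tool followed by its JSON schema, and glm4_tool_extractor expects a tool call to come back as exactly two lines, the function name and then the JSON arguments. Illustrative only; the tool schema and model reply below are made up, and it assumes llamafactory is installed from this branch.

    from llamafactory.data.formatter import glm4_tool_formatter, glm4_tool_extractor

    # Hypothetical tool definition in the same function-schema style used by the default formatter.
    tools = [
        {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        }
    ]

    # System prompt: GLM-4 identity text, then "## get_current_weather", the tool's JSON schema,
    # and the Chinese suffix asking the model to pass call arguments as JSON.
    print(glm4_tool_formatter(tools))

    # GLM-4 emits a tool call as two lines: the function name, then the JSON arguments.
    reply = 'get_current_weather\n{"location": "San Francisco, CA"}'
    print(glm4_tool_extractor(reply))
    # [('get_current_weather', '{"location": "San Francisco, CA"}')]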