fix tool formatter, allow parallel function #4362

This commit is contained in:
hiyouga 2024-06-19 03:23:51 +08:00
parent c0ca42566c
commit cd75b1fe9d
5 changed files with 207 additions and 86 deletions

View File

@ -92,9 +92,11 @@ def _process_request(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
name = message.tool_calls[0].function.name
arguments = message.tool_calls[0].function.arguments
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
tool_calls = [
{"name": tool_call.function.name, "argument": tool_call.function.arguments}
for tool_call in message.tool_calls
]
content = json.dumps(tool_calls, ensure_ascii=False)
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
elif isinstance(message.content, list):
for input_item in message.content:
@ -118,7 +120,7 @@ def _process_request(
if isinstance(tool_list, list) and len(tool_list):
try:
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
except Exception:
except json.JSONDecodeError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else:
tools = None
@ -160,17 +162,16 @@ async def create_chat_completion_response(
choices = []
for i, response in enumerate(responses):
if tools:
result = chat_model.engine.template.format_tools.extract(response.response_text)
result = chat_model.engine.template.extract_tool(response.response_text)
else:
result = response.response_text
if isinstance(result, list):
tool_calls = []
for tool in result:
name, arguments = tool
function = Function(name=name, arguments=arguments)
tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
tool_calls.append(tool_call)
function = Function(name=tool[0], arguments=tool[1])
tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
finish_reason = Finish.TOOL
else:

View File

@ -22,29 +22,20 @@ from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Uni
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
JSON_FORMAT_PROMPT = (
""", in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
)
TOOL_SYSTEM_PROMPT = (
DEFAULT_TOOL_PROMPT = (
"You have access to the following tools:\n{tool_text}"
"Use the following format if using a tool:\n"
"```\n"
"Action: tool name (one of [{tool_names}]).\n"
"Action Input: the input to the tool{format_prompt}.\n"
"Action Input: the input to the tool, in a JSON format representing the kwargs "
"""(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n"""
"```\n"
)
GLM4_TOOL_SUFFIX_PROMPT = (
"在调用上述函数时,请使用 Json 格式表示调用的参数。"
)
GLM4_TOOL_PROMPT = (
"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持,"
"{tool_text}"
"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
"你的任务是针对用户的问题和要求提供适当的答复和支持。{tool_text}"
)
@ -73,32 +64,19 @@ def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
)
tool_names.append(tool["name"])
return TOOL_SYSTEM_PROMPT.format(
tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
)
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
tool_name = tool["name"]
tool_text += f"\n\n## {tool_name}\n\n{json.dumps(tool, ensure_ascii=False, indent=4)}\n{GLM4_TOOL_SUFFIX_PROMPT}"
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*({.*?})(?=\nAction:|\Z)", re.DOTALL)
action_match = re.findall(regex, content)
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|$)", re.DOTALL)
action_match: List[Tuple[str, str]] = re.findall(regex, content)
if not action_match:
return content
results = []
for match in action_match:
tool_name, tool_input = match
tool_name = tool_name.strip()
tool_input = tool_input.strip().strip('"').strip("```")
tool_name = match[0].strip()
tool_input = match[1].strip().strip('"').strip("```")
try:
arguments = json.loads(tool_input)
results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
@ -108,19 +86,28 @@ def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
return results
def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
)
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
def glm4_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
lines = content.strip().split("\n")
if len(lines) != 2:
if "\n" not in content:
return content
tool_name = lines[0].strip()
tool_input = lines[1].strip()
tool_name, tool_input = content.split("\n", maxsplit=1)
try:
arguments = json.loads(tool_input)
except json.JSONDecodeError:
return content
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
@dataclass
class Formatter(ABC):
@ -193,22 +180,28 @@ class FunctionFormatter(Formatter):
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
functions: List[Tuple[str, str]] = []
try:
function = json.loads(content)
name = function["name"]
arguments = json.dumps(function["arguments"], ensure_ascii=False)
except Exception:
name, arguments = "", ""
tool_calls = json.loads(content)
if not isinstance(tool_calls, list): # parallel function call
tool_calls = [tool_calls]
for tool_call in tool_calls:
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
except json.JSONDecodeError:
functions = []
elements = []
for slot in self.slots:
if isinstance(slot, str):
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
elements.append(slot)
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
for name, arguments in functions:
for slot in self.slots:
if isinstance(slot, str):
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
elements.append(slot)
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
return elements
@ -216,29 +209,22 @@ class FunctionFormatter(Formatter):
@dataclass
class ToolFormatter(Formatter):
def __post_init__(self):
if self.tool_format is None:
if self.tool_format == "default":
self._tool_formatter = default_tool_formatter
self._tool_extractor = default_tool_extractor
elif self.tool_format == "glm4":
self._tool_formatter = glm4_tool_formatter
self._tool_extractor = glm4_tool_extractor
else:
raise ValueError("Tool format was not found.")
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
tools = json.loads(content)
if not len(tools):
return [""]
if self.tool_format == "default":
return [default_tool_formatter(tools)]
elif self.tool_format == "glm4":
return [glm4_tool_formatter(tools)]
else:
raise NotImplementedError
except Exception:
return [self._tool_formatter(tools) if len(tools) != 0 else ""]
except json.JSONDecodeError:
return [""]
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
if self.tool_format == "default":
return default_tool_extractor(content)
elif self.tool_format == "glm4":
return glm4_tool_extractor(content)
else:
raise NotImplementedError
return self._tool_extractor(content)

View File

@ -79,6 +79,12 @@ class Template:
"""
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
r"""
Extracts tool message.
"""
return self.format_tools.extract(content)
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
@ -100,7 +106,8 @@ class Template:
if i == 0 and (system or tools or self.force_system):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text))
elif i > 0 and i % 2 == 0:
if i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
if message["role"] == Role.USER.value:
@ -191,7 +198,8 @@ class Llama2Template(Template):
if i == 0 and (system or tools or self.force_system):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
system_text = self.format_system.apply(content=(system + tool_text))[0]
elif i > 0 and i % 2 == 0:
if i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
if message["role"] == Role.USER.value:
@ -259,7 +267,9 @@ def _register_template(
template_class = Llama2Template if name.startswith("llama2") else Template
default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
default_function_formatter = FunctionFormatter(
slots=["Action: {{name}}\nAction Input: {{arguments}}\n"] + eos_slots
)
default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter()
TEMPLATES[name] = template_class(

View File

@ -140,16 +140,15 @@ class WebChatModel(ChatModel):
):
response += new_text
if tools:
result = self.engine.template.format_tools.extract(response)
result = self.engine.template.extract_tool(response)
else:
result = response
if isinstance(result, tuple):
name, arguments = result
arguments = json.loads(arguments)
tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_call}]
bot_text = "```json\n" + tool_call + "\n```"
if isinstance(result, list):
tool_calls = [{"name": tool[0], "arguments": json.loads(tool[1])} for tool in result]
tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False)
output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
bot_text = "```json\n" + tool_calls + "\n```"
else:
output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
bot_text = result

View File

@ -0,0 +1,125 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
def test_empty_formatter():
formatter = EmptyFormatter(slots=["\n"])
assert formatter.apply() == ["\n"]
def test_string_formatter():
formatter = StringFormatter(slots=["<s>", "Human: {{content}}\nAssistant:"])
assert formatter.apply(content="Hi") == ["<s>", "Human: Hi\nAssistant:"]
def test_function_formatter():
formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}\n"])
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
assert formatter.apply(content=tool_calls) == [
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n"""
]
def test_multi_function_formatter():
formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}\n"])
tool_calls = json.dumps([{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}] * 2)
assert formatter.apply(content=tool_calls) == [
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""",
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""",
]
def test_default_tool_formatter():
formatter = ToolFormatter(tool_format="default")
tools = [
{
"name": "test_tool",
"description": "tool_desc",
"parameters": {
"type": "object",
"properties": {
"foo": {"type": "string", "description": "foo_desc"},
"bar": {"type": "number", "description": "bar_desc"},
},
"required": ["foo"],
},
}
]
assert formatter.apply(content=json.dumps(tools)) == [
"You have access to the following tools:\n"
"> Tool Name: test_tool\n"
"Tool Description: tool_desc\n"
"Tool Args:\n"
" - foo (string, required): foo_desc\n"
" - bar (number): bar_desc\n\n"
"Use the following format if using a tool:\n"
"```\n"
"Action: tool name (one of [test_tool]).\n"
"Action Input: the input to the tool, in a JSON format representing the kwargs "
"""(e.g. ```{"input": "hello world", "num_beams": 5}```).\n"""
"```\n"
]
def test_default_tool_extractor():
formatter = ToolFormatter(tool_format="default")
result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_default_multi_tool_extractor():
formatter = ToolFormatter(tool_format="default")
result = (
"""Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
"""Action: another_tool\nAction Input: {"foo": "job", "size": 2}\n"""
)
assert formatter.extract(result) == [
("test_tool", """{"foo": "bar", "size": 10}"""),
("another_tool", """{"foo": "job", "size": 2}"""),
]
def test_glm4_tool_formatter():
formatter = ToolFormatter(tool_format="glm4")
tools = [
{
"name": "test_tool",
"description": "tool_desc",
"parameters": {
"type": "object",
"properties": {
"foo": {"type": "string", "description": "foo_desc"},
"bar": {"type": "number", "description": "bar_desc"},
},
"required": ["foo"],
},
}
]
assert formatter.apply(content=json.dumps(tools)) == [
"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
"你的任务是针对用户的问题和要求提供适当的答复和支持。"
"\n\n## test_tool\n\n{}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
json.dumps(tools[0], indent=4)
)
]
def test_glm4_tool_extractor():
formatter = ToolFormatter(tool_format="glm4")
result = """test_tool\n{"foo": "bar", "size": 10}\n"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]