fix tool formatter, allow parallel function #4362
This commit is contained in:
parent
c0ca42566c
commit
cd75b1fe9d
|
@ -92,9 +92,11 @@ def _process_request(
|
|||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||
|
||||
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
|
||||
name = message.tool_calls[0].function.name
|
||||
arguments = message.tool_calls[0].function.arguments
|
||||
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
|
||||
tool_calls = [
|
||||
{"name": tool_call.function.name, "argument": tool_call.function.arguments}
|
||||
for tool_call in message.tool_calls
|
||||
]
|
||||
content = json.dumps(tool_calls, ensure_ascii=False)
|
||||
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
|
||||
elif isinstance(message.content, list):
|
||||
for input_item in message.content:
|
||||
|
@ -118,7 +120,7 @@ def _process_request(
|
|||
if isinstance(tool_list, list) and len(tool_list):
|
||||
try:
|
||||
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
|
||||
except Exception:
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
|
||||
else:
|
||||
tools = None
|
||||
|
@ -160,17 +162,16 @@ async def create_chat_completion_response(
|
|||
choices = []
|
||||
for i, response in enumerate(responses):
|
||||
if tools:
|
||||
result = chat_model.engine.template.format_tools.extract(response.response_text)
|
||||
result = chat_model.engine.template.extract_tool(response.response_text)
|
||||
else:
|
||||
result = response.response_text
|
||||
|
||||
if isinstance(result, list):
|
||||
tool_calls = []
|
||||
for tool in result:
|
||||
name, arguments = tool
|
||||
function = Function(name=name, arguments=arguments)
|
||||
tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
|
||||
tool_calls.append(tool_call)
|
||||
function = Function(name=tool[0], arguments=tool[1])
|
||||
tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
|
||||
|
||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
|
||||
finish_reason = Finish.TOOL
|
||||
else:
|
||||
|
|
|
@ -22,29 +22,20 @@ from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Uni
|
|||
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
||||
|
||||
|
||||
JSON_FORMAT_PROMPT = (
|
||||
""", in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
|
||||
)
|
||||
|
||||
|
||||
TOOL_SYSTEM_PROMPT = (
|
||||
DEFAULT_TOOL_PROMPT = (
|
||||
"You have access to the following tools:\n{tool_text}"
|
||||
"Use the following format if using a tool:\n"
|
||||
"```\n"
|
||||
"Action: tool name (one of [{tool_names}]).\n"
|
||||
"Action Input: the input to the tool{format_prompt}.\n"
|
||||
"Action Input: the input to the tool, in a JSON format representing the kwargs "
|
||||
"""(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n"""
|
||||
"```\n"
|
||||
)
|
||||
|
||||
|
||||
GLM4_TOOL_SUFFIX_PROMPT = (
|
||||
"在调用上述函数时,请使用 Json 格式表示调用的参数。"
|
||||
)
|
||||
|
||||
GLM4_TOOL_PROMPT = (
|
||||
"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持,"
|
||||
"{tool_text}"
|
||||
|
||||
"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
|
||||
"你的任务是针对用户的问题和要求提供适当的答复和支持。{tool_text}"
|
||||
)
|
||||
|
||||
|
||||
|
@ -73,32 +64,19 @@ def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
|||
)
|
||||
tool_names.append(tool["name"])
|
||||
|
||||
return TOOL_SYSTEM_PROMPT.format(
|
||||
tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
|
||||
)
|
||||
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
|
||||
|
||||
|
||||
def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
for tool in tools:
|
||||
tool_name = tool["name"]
|
||||
tool_text += f"\n\n## {tool_name}\n\n{json.dumps(tool, ensure_ascii=False, indent=4)}\n{GLM4_TOOL_SUFFIX_PROMPT}"
|
||||
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
|
||||
|
||||
|
||||
def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*({.*?})(?=\nAction:|\Z)", re.DOTALL)
|
||||
action_match = re.findall(regex, content)
|
||||
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|$)", re.DOTALL)
|
||||
action_match: List[Tuple[str, str]] = re.findall(regex, content)
|
||||
if not action_match:
|
||||
return content
|
||||
|
||||
results = []
|
||||
|
||||
for match in action_match:
|
||||
tool_name, tool_input = match
|
||||
tool_name = tool_name.strip()
|
||||
tool_input = tool_input.strip().strip('"').strip("```")
|
||||
|
||||
tool_name = match[0].strip()
|
||||
tool_input = match[1].strip().strip('"').strip("```")
|
||||
try:
|
||||
arguments = json.loads(tool_input)
|
||||
results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
|
||||
|
@ -108,19 +86,28 @@ def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
|||
return results
|
||||
|
||||
|
||||
def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
for tool in tools:
|
||||
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
|
||||
name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
|
||||
)
|
||||
|
||||
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
|
||||
|
||||
|
||||
def glm4_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
lines = content.strip().split("\n")
|
||||
if len(lines) != 2:
|
||||
if "\n" not in content:
|
||||
return content
|
||||
tool_name = lines[0].strip()
|
||||
tool_input = lines[1].strip()
|
||||
|
||||
tool_name, tool_input = content.split("\n", maxsplit=1)
|
||||
try:
|
||||
arguments = json.loads(tool_input)
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class Formatter(ABC):
|
||||
|
@ -193,22 +180,28 @@ class FunctionFormatter(Formatter):
|
|||
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
content = kwargs.pop("content")
|
||||
functions: List[Tuple[str, str]] = []
|
||||
try:
|
||||
function = json.loads(content)
|
||||
name = function["name"]
|
||||
arguments = json.dumps(function["arguments"], ensure_ascii=False)
|
||||
except Exception:
|
||||
name, arguments = "", ""
|
||||
tool_calls = json.loads(content)
|
||||
if not isinstance(tool_calls, list): # parallel function call
|
||||
tool_calls = [tool_calls]
|
||||
|
||||
for tool_call in tool_calls:
|
||||
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
|
||||
|
||||
except json.JSONDecodeError:
|
||||
functions = []
|
||||
|
||||
elements = []
|
||||
for slot in self.slots:
|
||||
if isinstance(slot, str):
|
||||
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
|
||||
elements.append(slot)
|
||||
elif isinstance(slot, (dict, set)):
|
||||
elements.append(slot)
|
||||
else:
|
||||
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
||||
for name, arguments in functions:
|
||||
for slot in self.slots:
|
||||
if isinstance(slot, str):
|
||||
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
|
||||
elements.append(slot)
|
||||
elif isinstance(slot, (dict, set)):
|
||||
elements.append(slot)
|
||||
else:
|
||||
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
||||
|
||||
return elements
|
||||
|
||||
|
@ -216,29 +209,22 @@ class FunctionFormatter(Formatter):
|
|||
@dataclass
|
||||
class ToolFormatter(Formatter):
|
||||
def __post_init__(self):
|
||||
if self.tool_format is None:
|
||||
if self.tool_format == "default":
|
||||
self._tool_formatter = default_tool_formatter
|
||||
self._tool_extractor = default_tool_extractor
|
||||
elif self.tool_format == "glm4":
|
||||
self._tool_formatter = glm4_tool_formatter
|
||||
self._tool_extractor = glm4_tool_extractor
|
||||
else:
|
||||
raise ValueError("Tool format was not found.")
|
||||
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
content = kwargs.pop("content")
|
||||
try:
|
||||
tools = json.loads(content)
|
||||
if not len(tools):
|
||||
return [""]
|
||||
|
||||
if self.tool_format == "default":
|
||||
return [default_tool_formatter(tools)]
|
||||
elif self.tool_format == "glm4":
|
||||
return [glm4_tool_formatter(tools)]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
except Exception:
|
||||
return [self._tool_formatter(tools) if len(tools) != 0 else ""]
|
||||
except json.JSONDecodeError:
|
||||
return [""]
|
||||
|
||||
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
if self.tool_format == "default":
|
||||
return default_tool_extractor(content)
|
||||
elif self.tool_format == "glm4":
|
||||
return glm4_tool_extractor(content)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return self._tool_extractor(content)
|
||||
|
|
|
@ -79,6 +79,12 @@ class Template:
|
|||
"""
|
||||
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
|
||||
|
||||
def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
r"""
|
||||
Extracts tool message.
|
||||
"""
|
||||
return self.format_tools.extract(content)
|
||||
|
||||
def _encode(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
|
@ -100,7 +106,8 @@ class Template:
|
|||
if i == 0 and (system or tools or self.force_system):
|
||||
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
||||
elements += self.format_system.apply(content=(system + tool_text))
|
||||
elif i > 0 and i % 2 == 0:
|
||||
|
||||
if i > 0 and i % 2 == 0:
|
||||
elements += self.format_separator.apply()
|
||||
|
||||
if message["role"] == Role.USER.value:
|
||||
|
@ -191,7 +198,8 @@ class Llama2Template(Template):
|
|||
if i == 0 and (system or tools or self.force_system):
|
||||
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
||||
system_text = self.format_system.apply(content=(system + tool_text))[0]
|
||||
elif i > 0 and i % 2 == 0:
|
||||
|
||||
if i > 0 and i % 2 == 0:
|
||||
elements += self.format_separator.apply()
|
||||
|
||||
if message["role"] == Role.USER.value:
|
||||
|
@ -259,7 +267,9 @@ def _register_template(
|
|||
template_class = Llama2Template if name.startswith("llama2") else Template
|
||||
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
||||
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
|
||||
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
|
||||
default_function_formatter = FunctionFormatter(
|
||||
slots=["Action: {{name}}\nAction Input: {{arguments}}\n"] + eos_slots
|
||||
)
|
||||
default_tool_formatter = ToolFormatter(tool_format="default")
|
||||
default_separator_formatter = EmptyFormatter()
|
||||
TEMPLATES[name] = template_class(
|
||||
|
|
|
@ -140,16 +140,15 @@ class WebChatModel(ChatModel):
|
|||
):
|
||||
response += new_text
|
||||
if tools:
|
||||
result = self.engine.template.format_tools.extract(response)
|
||||
result = self.engine.template.extract_tool(response)
|
||||
else:
|
||||
result = response
|
||||
|
||||
if isinstance(result, tuple):
|
||||
name, arguments = result
|
||||
arguments = json.loads(arguments)
|
||||
tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
|
||||
output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_call}]
|
||||
bot_text = "```json\n" + tool_call + "\n```"
|
||||
if isinstance(result, list):
|
||||
tool_calls = [{"name": tool[0], "arguments": json.loads(tool[1])} for tool in result]
|
||||
tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False)
|
||||
output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
|
||||
bot_text = "```json\n" + tool_calls + "\n```"
|
||||
else:
|
||||
output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
|
||||
bot_text = result
|
||||
|
|
|
@ -0,0 +1,125 @@
|
|||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
|
||||
from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
||||
|
||||
|
||||
def test_empty_formatter():
|
||||
formatter = EmptyFormatter(slots=["\n"])
|
||||
assert formatter.apply() == ["\n"]
|
||||
|
||||
|
||||
def test_string_formatter():
|
||||
formatter = StringFormatter(slots=["<s>", "Human: {{content}}\nAssistant:"])
|
||||
assert formatter.apply(content="Hi") == ["<s>", "Human: Hi\nAssistant:"]
|
||||
|
||||
|
||||
def test_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}\n"])
|
||||
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
|
||||
assert formatter.apply(content=tool_calls) == [
|
||||
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n"""
|
||||
]
|
||||
|
||||
|
||||
def test_multi_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}\n"])
|
||||
tool_calls = json.dumps([{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}] * 2)
|
||||
assert formatter.apply(content=tool_calls) == [
|
||||
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""",
|
||||
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""",
|
||||
]
|
||||
|
||||
|
||||
def test_default_tool_formatter():
|
||||
formatter = ToolFormatter(tool_format="default")
|
||||
tools = [
|
||||
{
|
||||
"name": "test_tool",
|
||||
"description": "tool_desc",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"foo": {"type": "string", "description": "foo_desc"},
|
||||
"bar": {"type": "number", "description": "bar_desc"},
|
||||
},
|
||||
"required": ["foo"],
|
||||
},
|
||||
}
|
||||
]
|
||||
assert formatter.apply(content=json.dumps(tools)) == [
|
||||
"You have access to the following tools:\n"
|
||||
"> Tool Name: test_tool\n"
|
||||
"Tool Description: tool_desc\n"
|
||||
"Tool Args:\n"
|
||||
" - foo (string, required): foo_desc\n"
|
||||
" - bar (number): bar_desc\n\n"
|
||||
"Use the following format if using a tool:\n"
|
||||
"```\n"
|
||||
"Action: tool name (one of [test_tool]).\n"
|
||||
"Action Input: the input to the tool, in a JSON format representing the kwargs "
|
||||
"""(e.g. ```{"input": "hello world", "num_beams": 5}```).\n"""
|
||||
"```\n"
|
||||
]
|
||||
|
||||
|
||||
def test_default_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="default")
|
||||
result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
|
||||
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||
|
||||
|
||||
def test_default_multi_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="default")
|
||||
result = (
|
||||
"""Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
|
||||
"""Action: another_tool\nAction Input: {"foo": "job", "size": 2}\n"""
|
||||
)
|
||||
assert formatter.extract(result) == [
|
||||
("test_tool", """{"foo": "bar", "size": 10}"""),
|
||||
("another_tool", """{"foo": "job", "size": 2}"""),
|
||||
]
|
||||
|
||||
|
||||
def test_glm4_tool_formatter():
|
||||
formatter = ToolFormatter(tool_format="glm4")
|
||||
tools = [
|
||||
{
|
||||
"name": "test_tool",
|
||||
"description": "tool_desc",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"foo": {"type": "string", "description": "foo_desc"},
|
||||
"bar": {"type": "number", "description": "bar_desc"},
|
||||
},
|
||||
"required": ["foo"],
|
||||
},
|
||||
}
|
||||
]
|
||||
assert formatter.apply(content=json.dumps(tools)) == [
|
||||
"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
|
||||
"你的任务是针对用户的问题和要求提供适当的答复和支持。"
|
||||
"\n\n## test_tool\n\n{}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
|
||||
json.dumps(tools[0], indent=4)
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def test_glm4_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="glm4")
|
||||
result = """test_tool\n{"foo": "bar", "size": 10}\n"""
|
||||
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
Loading…
Reference in New Issue