Implemented the tool_formatter and tool_extractor for glm4 tool_format

This commit is contained in:
mMrBun 2024-06-09 18:16:15 +08:00
parent 8bf9da659c
commit cb1cbcb293
2 changed files with 43 additions and 2 deletions

View File

@ -23,6 +23,17 @@ TOOL_SYSTEM_PROMPT = (
)
# Instruction placed after a tool's JSON description. In English it reads:
# "When calling the above functions, please use Json format to represent the
# call arguments."
GLM4_TOOL_SUFFIX_PROMPT = "在调用上述函数时,请使用 Json 格式表示调用的参数。"
# System prompt template used when tools are offered to GLM-4.
# ``{tool_text}`` is replaced with the rendered tool sections. Because that text
# begins with a blank line and a ``## <name>`` heading, the introduction has to
# end as a complete sentence ("。") instead of a dangling comma.
GLM4_TOOL_PROMPT = (
    "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。"
    "{tool_text}"
)
def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
tool_names = []
@ -53,6 +64,14 @@ def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
)
def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
    """Render tool definitions into the GLM-4 tool prompt.

    Every tool contributes one section: a ``## <name>`` heading, the tool
    dumped as 4-space indented JSON (non-ASCII characters kept as-is), and
    ``GLM4_TOOL_SUFFIX_PROMPT``. The sections are concatenated in input order
    and substituted for ``{tool_text}`` in ``GLM4_TOOL_PROMPT``.

    Args:
        tools: Tool descriptions; each one must provide a ``"name"`` key.

    Returns:
        The formatted prompt text.
    """
    sections = [
        f"\n\n## {spec['name']}\n\n{json.dumps(spec, ensure_ascii=False, indent=4)}\n{GLM4_TOOL_SUFFIX_PROMPT}"
        for spec in tools
    ]
    return GLM4_TOOL_PROMPT.format(tool_text="".join(sections))
def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL)
action_match = re.search(regex, content)
@ -69,10 +88,24 @@ def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
return tool_name, json.dumps(arguments, ensure_ascii=False)
def glm4_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
    """Extract a GLM-4 style tool call from model output.

    The expected layout is the tool name on the first line followed by the
    JSON-encoded arguments. The arguments may span several lines (for example
    pretty-printed JSON), so everything after the first line break is parsed
    as a single JSON document.

    Args:
        content: Raw text produced by the model.

    Returns:
        ``(tool_name, arguments)`` where ``arguments`` is the re-serialized
        JSON text with non-ASCII characters preserved, or ``content``
        unchanged when it cannot be read as a tool call.
    """
    tool_name, separator, tool_input = content.strip().partition("\n")
    if not separator:
        # Only one line: there is no argument payload, so this is plain text.
        return content

    try:
        arguments = json.loads(tool_input.strip())
    except json.JSONDecodeError:
        # The text after the first line is not valid JSON; hand the text back.
        return content

    return tool_name.strip(), json.dumps(arguments, ensure_ascii=False)
@dataclass
class Formatter(ABC):
slots: SLOTS = field(default_factory=list)
tool_format: Optional[Literal["default"]] = None
tool_format: Optional[Literal["default", "glm4"]] = None
@abstractmethod
def apply(self, **kwargs) -> SLOTS: ...
@ -175,6 +208,11 @@ class ToolFormatter(Formatter):
if self.tool_format == "default":
return [default_tool_formatter(tools)]
elif self.tool_format == "glm4":
"""
'[gMASK]<sop><|system|>\n你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n## get_current_weather\n\n{\n "name": "get_current_weather",\n "description": "Get the current weather",\n "parameters": {\n "type": "object",\n "properties": {\n "location": {\n "type": "string",\n "description": "The city and state, e.g. San Francisco, CA"\n },\n "format": {\n "type": "string",\n "enum": [\n "celsius",\n "fahrenheit"\n ],\n "description": "The temperature unit to use. Infer this from the users location."\n }\n },\n "required": [\n "location",\n "format"\n ]\n }\n}\n在调用上述函数时,请使用 Json 格式表示调用的参数。<|user|>\nWhat\'s the weather like in San Francisco, Tokyo, and Paris? use celsius<|assistant|>'
"""
return [glm4_tool_formatter(tools)]
else:
raise NotImplementedError
except Exception:
@ -183,5 +221,7 @@ class ToolFormatter(Formatter):
def extract(self, content: str) -> Union[str, Tuple[str, str]]:
    """Dispatch tool-call extraction according to ``self.tool_format``.

    Args:
        content: Text to search for a tool call.

    Returns:
        Whatever the extractor selected by ``self.tool_format`` returns.

    Raises:
        NotImplementedError: If ``self.tool_format`` is neither ``"default"``
            nor ``"glm4"``.
    """
    if self.tool_format == "default":
        return default_tool_extractor(content)

    if self.tool_format == "glm4":
        return glm4_tool_extractor(content)

    raise NotImplementedError

View File

@ -662,9 +662,10 @@ _register_template(
name="glm4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["[gMASK]<sop>{{content}}"]),
format_system=StringFormatter(slots=["[gMASK]<sop><|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
format_tools=ToolFormatter(tool_format="glm4"),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
force_system=True,