use prefix to replace force system

This commit is contained in:
hiyouga 2024-06-19 03:39:52 +08:00
parent cd75b1fe9d
commit 4f22eae8f4
1 changed file with 30 additions and 45 deletions

View File

@ -38,12 +38,12 @@ class Template:
format_observation: "Formatter" format_observation: "Formatter"
format_tools: "Formatter" format_tools: "Formatter"
format_separator: "Formatter" format_separator: "Formatter"
format_prefix: "Formatter"
default_system: str default_system: str
stop_words: List[str] stop_words: List[str]
image_token: str image_token: str
efficient_eos: bool efficient_eos: bool
replace_eos: bool replace_eos: bool
force_system: bool
def encode_oneturn( def encode_oneturn(
self, self,
@ -102,8 +102,9 @@ class Template:
system = system or self.default_system system = system or self.default_system
encoded_messages = [] encoded_messages = []
for i, message in enumerate(messages): for i, message in enumerate(messages):
elements = [] elements = self.format_prefix.apply()
if i == 0 and (system or tools or self.force_system):
if i == 0 and (system or tools):
tool_text = self.format_tools.apply(content=tools)[0] if tools else "" tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text)) elements += self.format_system.apply(content=(system + tool_text))
@ -193,9 +194,10 @@ class Llama2Template(Template):
system = system or self.default_system system = system or self.default_system
encoded_messages = [] encoded_messages = []
for i, message in enumerate(messages): for i, message in enumerate(messages):
elements = [] elements = self.format_prefix.apply()
system_text = "" system_text = ""
if i == 0 and (system or tools or self.force_system): if i == 0 and (system or tools):
tool_text = self.format_tools.apply(content=tools)[0] if tools else "" tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
system_text = self.format_system.apply(content=(system + tool_text))[0] system_text = self.format_system.apply(content=(system + tool_text))[0]
@ -230,12 +232,12 @@ def _register_template(
format_observation: Optional["Formatter"] = None, format_observation: Optional["Formatter"] = None,
format_tools: Optional["Formatter"] = None, format_tools: Optional["Formatter"] = None,
format_separator: Optional["Formatter"] = None, format_separator: Optional["Formatter"] = None,
format_prefix: Optional["Formatter"] = None,
default_system: str = "", default_system: str = "",
stop_words: List[str] = [], stop_words: List[str] = [],
image_token: str = "<image>", image_token: str = "<image>",
efficient_eos: bool = False, efficient_eos: bool = False,
replace_eos: bool = False, replace_eos: bool = False,
force_system: bool = False,
) -> None: ) -> None:
r""" r"""
Registers a chat template. Registers a chat template.
@ -272,6 +274,7 @@ def _register_template(
) )
default_tool_formatter = ToolFormatter(tool_format="default") default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter() default_separator_formatter = EmptyFormatter()
default_prefix_formatter = EmptyFormatter()
TEMPLATES[name] = template_class( TEMPLATES[name] = template_class(
format_user=format_user or default_user_formatter, format_user=format_user or default_user_formatter,
format_assistant=format_assistant or default_assistant_formatter, format_assistant=format_assistant or default_assistant_formatter,
@ -280,12 +283,12 @@ def _register_template(
format_observation=format_observation or format_user or default_user_formatter, format_observation=format_observation or format_user or default_user_formatter,
format_tools=format_tools or default_tool_formatter, format_tools=format_tools or default_tool_formatter,
format_separator=format_separator or default_separator_formatter, format_separator=format_separator or default_separator_formatter,
format_prefix=format_prefix or default_prefix_formatter,
default_system=default_system, default_system=default_system,
stop_words=stop_words, stop_words=stop_words,
image_token=image_token, image_token=image_token,
efficient_eos=efficient_eos, efficient_eos=efficient_eos,
replace_eos=replace_eos, replace_eos=replace_eos,
force_system=force_system,
) )
@ -329,7 +332,7 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str: def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
jinja_template = "" jinja_template = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
if template.default_system: if template.default_system:
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}" jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
@ -339,11 +342,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
) )
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message") system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
if isinstance(template, Llama2Template): if not isinstance(template, Llama2Template):
pass
elif template.force_system:
jinja_template += "{{ " + system_message + " }}"
else:
jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}" jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
jinja_template += "{% for message in messages %}" jinja_template += "{% for message in messages %}"
@ -459,9 +458,8 @@ _register_template(
_register_template( _register_template(
name="belle", name="belle",
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]), format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
format_separator=EmptyFormatter(slots=["\n\n"]), format_separator=EmptyFormatter(slots=["\n\n"]),
force_system=True, format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
) )
@ -486,10 +484,9 @@ _register_template(
_register_template( _register_template(
name="chatglm2", name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]), format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
format_separator=EmptyFormatter(slots=["\n\n"]), format_separator=EmptyFormatter(slots=["\n\n"]),
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
efficient_eos=True, efficient_eos=True,
force_system=True,
) )
@ -497,14 +494,14 @@ _register_template(
name="chatglm3", name="chatglm3",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]), format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_observation=StringFormatter( format_observation=StringFormatter(
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
), ),
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
stop_words=["<|user|>", "<|observation|>"], stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True, efficient_eos=True,
force_system=True,
) )
@ -512,13 +509,12 @@ _register_template(
name="chatglm3_system", name="chatglm3_system",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]), format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_system=StringFormatter( format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_observation=StringFormatter( format_observation=StringFormatter(
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
), ),
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
default_system=( default_system=(
"You are ChatGLM3, a large language model trained by Zhipu.AI. " "You are ChatGLM3, a large language model trained by Zhipu.AI. "
"Follow the user's instructions carefully. Respond using markdown." "Follow the user's instructions carefully. Respond using markdown."
@ -553,8 +549,7 @@ _register_template(
_register_template( _register_template(
name="codegeex2", name="codegeex2",
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
force_system=True,
) )
@ -581,8 +576,7 @@ _register_template(
_register_template( _register_template(
name="cpm", name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]), format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
force_system=True,
) )
@ -615,8 +609,7 @@ _register_template(
_register_template( _register_template(
name="deepseek", name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]), format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
force_system=True,
) )
@ -648,9 +641,8 @@ _register_template(
name="empty", name="empty",
format_user=StringFormatter(slots=["{{content}}"]), format_user=StringFormatter(slots=["{{content}}"]),
format_assistant=StringFormatter(slots=["{{content}}"]), format_assistant=StringFormatter(slots=["{{content}}"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True, efficient_eos=True,
force_system=True,
) )
@ -672,13 +664,12 @@ _register_template(
_register_template( _register_template(
name="gemma", name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]), format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
format_observation=StringFormatter( format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"] slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
), ),
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]), format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True, efficient_eos=True,
force_system=True,
) )
@ -686,13 +677,13 @@ _register_template(
name="glm4", name="glm4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]), format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["[gMASK]<sop><|system|>\n{{content}}"]), format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
format_tools=ToolFormatter(tool_format="glm4"), format_tools=ToolFormatter(tool_format="glm4"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
stop_words=["<|user|>", "<|observation|>"], stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True, efficient_eos=True,
force_system=True,
) )
@ -768,24 +759,21 @@ _register_template(
_register_template( _register_template(
name="mistral", name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]), format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
force_system=True,
) )
_register_template( _register_template(
name="olmo", name="olmo",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]), format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
force_system=True,
) )
_register_template( _register_template(
name="openchat", name="openchat",
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]), format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
force_system=True,
) )
@ -799,18 +787,16 @@ _register_template(
) )
] ]
), ),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"], stop_words=["<|eot_id|>"],
replace_eos=True, replace_eos=True,
force_system=True,
) )
_register_template( _register_template(
name="orion", name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]), format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
force_system=True,
) )
@ -852,7 +838,6 @@ _register_template(
format_separator=EmptyFormatter(slots=["\n"]), format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|end|>"], stop_words=["<|end|>"],
replace_eos=True, replace_eos=True,
force_system=True,
) )