use prefix to replace force system

This commit is contained in:
hiyouga 2024-06-19 03:39:52 +08:00
parent cd75b1fe9d
commit 4f22eae8f4
1 changed files with 30 additions and 45 deletions

View File

@ -38,12 +38,12 @@ class Template:
format_observation: "Formatter"
format_tools: "Formatter"
format_separator: "Formatter"
format_prefix: "Formatter"
default_system: str
stop_words: List[str]
image_token: str
efficient_eos: bool
replace_eos: bool
force_system: bool
def encode_oneturn(
self,
@ -102,8 +102,9 @@ class Template:
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
elements = []
if i == 0 and (system or tools or self.force_system):
elements = self.format_prefix.apply()
if i == 0 and (system or tools):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text))
@ -193,9 +194,10 @@ class Llama2Template(Template):
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
elements = []
elements = self.format_prefix.apply()
system_text = ""
if i == 0 and (system or tools or self.force_system):
if i == 0 and (system or tools):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
system_text = self.format_system.apply(content=(system + tool_text))[0]
@ -230,12 +232,12 @@ def _register_template(
format_observation: Optional["Formatter"] = None,
format_tools: Optional["Formatter"] = None,
format_separator: Optional["Formatter"] = None,
format_prefix: Optional["Formatter"] = None,
default_system: str = "",
stop_words: List[str] = [],
image_token: str = "<image>",
efficient_eos: bool = False,
replace_eos: bool = False,
force_system: bool = False,
) -> None:
r"""
Registers a chat template.
@ -272,6 +274,7 @@ def _register_template(
)
default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter()
default_prefix_formatter = EmptyFormatter()
TEMPLATES[name] = template_class(
format_user=format_user or default_user_formatter,
format_assistant=format_assistant or default_assistant_formatter,
@ -280,12 +283,12 @@ def _register_template(
format_observation=format_observation or format_user or default_user_formatter,
format_tools=format_tools or default_tool_formatter,
format_separator=format_separator or default_separator_formatter,
format_prefix=format_prefix or default_prefix_formatter,
default_system=default_system,
stop_words=stop_words,
image_token=image_token,
efficient_eos=efficient_eos,
replace_eos=replace_eos,
force_system=force_system,
)
@ -329,7 +332,7 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
jinja_template = ""
jinja_template = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
if template.default_system:
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
@ -339,11 +342,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
)
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
if isinstance(template, Llama2Template):
pass
elif template.force_system:
jinja_template += "{{ " + system_message + " }}"
else:
if not isinstance(template, Llama2Template):
jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
jinja_template += "{% for message in messages %}"
@ -459,9 +458,8 @@ _register_template(
_register_template(
name="belle",
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
force_system=True,
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
@ -486,10 +484,9 @@ _register_template(
_register_template(
name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
efficient_eos=True,
force_system=True,
)
@ -497,14 +494,14 @@ _register_template(
name="chatglm3",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_observation=StringFormatter(
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
),
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
force_system=True,
)
@ -512,13 +509,12 @@ _register_template(
name="chatglm3_system",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_system=StringFormatter(
slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
),
format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_observation=StringFormatter(
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
),
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
default_system=(
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
"Follow the user's instructions carefully. Respond using markdown."
@ -553,8 +549,7 @@ _register_template(
_register_template(
name="codegeex2",
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
force_system=True,
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
)
@ -581,8 +576,7 @@ _register_template(
_register_template(
name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
@ -615,8 +609,7 @@ _register_template(
_register_template(
name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
@ -648,9 +641,8 @@ _register_template(
name="empty",
format_user=StringFormatter(slots=["{{content}}"]),
format_assistant=StringFormatter(slots=["{{content}}"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True,
force_system=True,
)
@ -672,13 +664,12 @@ _register_template(
_register_template(
name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True,
force_system=True,
)
@ -686,13 +677,13 @@ _register_template(
name="glm4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["[gMASK]<sop><|system|>\n{{content}}"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
format_tools=ToolFormatter(tool_format="glm4"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
force_system=True,
)
@ -768,24 +759,21 @@ _register_template(
_register_template(
name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template(
name="olmo",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]),
force_system=True,
format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
)
_register_template(
name="openchat",
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
@ -799,18 +787,16 @@ _register_template(
)
]
),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"],
replace_eos=True,
force_system=True,
)
_register_template(
name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
@ -852,7 +838,6 @@ _register_template(
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|end|>"],
replace_eos=True,
force_system=True,
)