use prefix to replace force system
This commit is contained in:
parent
cd75b1fe9d
commit
4f22eae8f4
|
@ -38,12 +38,12 @@ class Template:
|
|||
format_observation: "Formatter"
|
||||
format_tools: "Formatter"
|
||||
format_separator: "Formatter"
|
||||
format_prefix: "Formatter"
|
||||
default_system: str
|
||||
stop_words: List[str]
|
||||
image_token: str
|
||||
efficient_eos: bool
|
||||
replace_eos: bool
|
||||
force_system: bool
|
||||
|
||||
def encode_oneturn(
|
||||
self,
|
||||
|
@ -102,8 +102,9 @@ class Template:
|
|||
system = system or self.default_system
|
||||
encoded_messages = []
|
||||
for i, message in enumerate(messages):
|
||||
elements = []
|
||||
if i == 0 and (system or tools or self.force_system):
|
||||
elements = self.format_prefix.apply()
|
||||
|
||||
if i == 0 and (system or tools):
|
||||
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
||||
elements += self.format_system.apply(content=(system + tool_text))
|
||||
|
||||
|
@ -193,9 +194,10 @@ class Llama2Template(Template):
|
|||
system = system or self.default_system
|
||||
encoded_messages = []
|
||||
for i, message in enumerate(messages):
|
||||
elements = []
|
||||
elements = self.format_prefix.apply()
|
||||
|
||||
system_text = ""
|
||||
if i == 0 and (system or tools or self.force_system):
|
||||
if i == 0 and (system or tools):
|
||||
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
||||
system_text = self.format_system.apply(content=(system + tool_text))[0]
|
||||
|
||||
|
@ -230,12 +232,12 @@ def _register_template(
|
|||
format_observation: Optional["Formatter"] = None,
|
||||
format_tools: Optional["Formatter"] = None,
|
||||
format_separator: Optional["Formatter"] = None,
|
||||
format_prefix: Optional["Formatter"] = None,
|
||||
default_system: str = "",
|
||||
stop_words: List[str] = [],
|
||||
image_token: str = "<image>",
|
||||
efficient_eos: bool = False,
|
||||
replace_eos: bool = False,
|
||||
force_system: bool = False,
|
||||
) -> None:
|
||||
r"""
|
||||
Registers a chat template.
|
||||
|
@ -272,6 +274,7 @@ def _register_template(
|
|||
)
|
||||
default_tool_formatter = ToolFormatter(tool_format="default")
|
||||
default_separator_formatter = EmptyFormatter()
|
||||
default_prefix_formatter = EmptyFormatter()
|
||||
TEMPLATES[name] = template_class(
|
||||
format_user=format_user or default_user_formatter,
|
||||
format_assistant=format_assistant or default_assistant_formatter,
|
||||
|
@ -280,12 +283,12 @@ def _register_template(
|
|||
format_observation=format_observation or format_user or default_user_formatter,
|
||||
format_tools=format_tools or default_tool_formatter,
|
||||
format_separator=format_separator or default_separator_formatter,
|
||||
format_prefix=format_prefix or default_prefix_formatter,
|
||||
default_system=default_system,
|
||||
stop_words=stop_words,
|
||||
image_token=image_token,
|
||||
efficient_eos=efficient_eos,
|
||||
replace_eos=replace_eos,
|
||||
force_system=force_system,
|
||||
)
|
||||
|
||||
|
||||
|
@ -329,7 +332,7 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
|
|||
|
||||
|
||||
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
|
||||
jinja_template = ""
|
||||
jinja_template = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
|
||||
|
||||
if template.default_system:
|
||||
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
|
||||
|
@ -339,11 +342,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
|
|||
)
|
||||
|
||||
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
|
||||
if isinstance(template, Llama2Template):
|
||||
pass
|
||||
elif template.force_system:
|
||||
jinja_template += "{{ " + system_message + " }}"
|
||||
else:
|
||||
if not isinstance(template, Llama2Template):
|
||||
jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
|
||||
|
||||
jinja_template += "{% for message in messages %}"
|
||||
|
@ -459,9 +458,8 @@ _register_template(
|
|||
_register_template(
|
||||
name="belle",
|
||||
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
force_system=True,
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
)
|
||||
|
||||
|
||||
|
@ -486,10 +484,9 @@ _register_template(
|
|||
_register_template(
|
||||
name="chatglm2",
|
||||
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
|
||||
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
|
||||
format_separator=EmptyFormatter(slots=["\n\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
|
||||
efficient_eos=True,
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
|
@ -497,14 +494,14 @@ _register_template(
|
|||
name="chatglm3",
|
||||
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
||||
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
||||
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
|
||||
format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
|
||||
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
|
||||
format_observation=StringFormatter(
|
||||
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
|
||||
),
|
||||
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
|
||||
stop_words=["<|user|>", "<|observation|>"],
|
||||
efficient_eos=True,
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
|
@ -512,13 +509,12 @@ _register_template(
|
|||
name="chatglm3_system",
|
||||
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
||||
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
||||
format_system=StringFormatter(
|
||||
slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
|
||||
),
|
||||
format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
|
||||
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
|
||||
format_observation=StringFormatter(
|
||||
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
|
||||
),
|
||||
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
|
||||
default_system=(
|
||||
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
|
||||
"Follow the user's instructions carefully. Respond using markdown."
|
||||
|
@ -553,8 +549,7 @@ _register_template(
|
|||
|
||||
_register_template(
|
||||
name="codegeex2",
|
||||
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
|
||||
)
|
||||
|
||||
|
||||
|
@ -581,8 +576,7 @@ _register_template(
|
|||
_register_template(
|
||||
name="cpm",
|
||||
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
)
|
||||
|
||||
|
||||
|
@ -615,8 +609,7 @@ _register_template(
|
|||
_register_template(
|
||||
name="deepseek",
|
||||
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
)
|
||||
|
||||
|
||||
|
@ -648,9 +641,8 @@ _register_template(
|
|||
name="empty",
|
||||
format_user=StringFormatter(slots=["{{content}}"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
efficient_eos=True,
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
|
@ -672,13 +664,12 @@ _register_template(
|
|||
_register_template(
|
||||
name="gemma",
|
||||
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
|
||||
),
|
||||
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
efficient_eos=True,
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
|
@ -686,13 +677,13 @@ _register_template(
|
|||
name="glm4",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
||||
format_assistant=StringFormatter(slots=["\n{{content}}"]),
|
||||
format_system=StringFormatter(slots=["[gMASK]<sop><|system|>\n{{content}}"]),
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
|
||||
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
|
||||
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
|
||||
format_tools=ToolFormatter(tool_format="glm4"),
|
||||
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
|
||||
stop_words=["<|user|>", "<|observation|>"],
|
||||
efficient_eos=True,
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
|
@ -768,24 +759,21 @@ _register_template(
|
|||
_register_template(
|
||||
name="mistral",
|
||||
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="olmo",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
|
||||
format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="openchat",
|
||||
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
)
|
||||
|
||||
|
||||
|
@ -799,18 +787,16 @@ _register_template(
|
|||
)
|
||||
]
|
||||
),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<|eot_id|>"],
|
||||
replace_eos=True,
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="orion",
|
||||
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
|
||||
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
|
||||
force_system=True,
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
)
|
||||
|
||||
|
||||
|
@ -852,7 +838,6 @@ _register_template(
|
|||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
stop_words=["<|end|>"],
|
||||
replace_eos=True,
|
||||
force_system=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue