diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 77694c59..a12e9c88 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -38,12 +38,12 @@ class Template: format_observation: "Formatter" format_tools: "Formatter" format_separator: "Formatter" + format_prefix: "Formatter" default_system: str stop_words: List[str] image_token: str efficient_eos: bool replace_eos: bool - force_system: bool def encode_oneturn( self, @@ -102,8 +102,9 @@ class Template: system = system or self.default_system encoded_messages = [] for i, message in enumerate(messages): - elements = [] - if i == 0 and (system or tools or self.force_system): + elements = self.format_prefix.apply() + + if i == 0 and (system or tools): tool_text = self.format_tools.apply(content=tools)[0] if tools else "" elements += self.format_system.apply(content=(system + tool_text)) @@ -193,9 +194,10 @@ class Llama2Template(Template): system = system or self.default_system encoded_messages = [] for i, message in enumerate(messages): - elements = [] + elements = self.format_prefix.apply() + system_text = "" - if i == 0 and (system or tools or self.force_system): + if i == 0 and (system or tools): tool_text = self.format_tools.apply(content=tools)[0] if tools else "" system_text = self.format_system.apply(content=(system + tool_text))[0] @@ -230,12 +232,12 @@ def _register_template( format_observation: Optional["Formatter"] = None, format_tools: Optional["Formatter"] = None, format_separator: Optional["Formatter"] = None, + format_prefix: Optional["Formatter"] = None, default_system: str = "", stop_words: List[str] = [], image_token: str = "", efficient_eos: bool = False, replace_eos: bool = False, - force_system: bool = False, ) -> None: r""" Registers a chat template. @@ -272,6 +274,7 @@ def _register_template( ) default_tool_formatter = ToolFormatter(tool_format="default") default_separator_formatter = EmptyFormatter() + default_prefix_formatter = EmptyFormatter() TEMPLATES[name] = template_class( format_user=format_user or default_user_formatter, format_assistant=format_assistant or default_assistant_formatter, @@ -280,12 +283,12 @@ def _register_template( format_observation=format_observation or format_user or default_user_formatter, format_tools=format_tools or default_tool_formatter, format_separator=format_separator or default_separator_formatter, + format_prefix=format_prefix or default_prefix_formatter, default_system=default_system, stop_words=stop_words, image_token=image_token, efficient_eos=efficient_eos, replace_eos=replace_eos, - force_system=force_system, ) @@ -329,7 +332,7 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str: - jinja_template = "" + jinja_template = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer) if template.default_system: jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}" @@ -339,11 +342,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") ) system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message") - if isinstance(template, Llama2Template): - pass - elif template.force_system: - jinja_template += "{{ " + system_message + " }}" - else: + if not isinstance(template, Llama2Template): jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}" jinja_template += "{% for message in messages %}" @@ -459,9 +458,8 @@ _register_template( _register_template( name="belle", format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]), - format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), format_separator=EmptyFormatter(slots=["\n\n"]), - force_system=True, + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) @@ -486,10 +484,9 @@ _register_template( _register_template( name="chatglm2", format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]), - format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), format_separator=EmptyFormatter(slots=["\n\n"]), + format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), efficient_eos=True, - force_system=True, ) @@ -497,14 +494,14 @@ _register_template( name="chatglm3", format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), format_assistant=StringFormatter(slots=["\n", "{{content}}"]), - format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), + format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]), format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), format_observation=StringFormatter( slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] ), + format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), stop_words=["<|user|>", "<|observation|>"], efficient_eos=True, - force_system=True, ) @@ -512,13 +509,12 @@ _register_template( name="chatglm3_system", format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), format_assistant=StringFormatter(slots=["\n", "{{content}}"]), - format_system=StringFormatter( - slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"] - ), + format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]), format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), format_observation=StringFormatter( slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] ), + format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), default_system=( "You are ChatGLM3, a large language model trained by Zhipu.AI. " "Follow the user's instructions carefully. Respond using markdown." @@ -553,8 +549,7 @@ _register_template( _register_template( name="codegeex2", - format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), - force_system=True, + format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), ) @@ -581,8 +576,7 @@ _register_template( _register_template( name="cpm", format_user=StringFormatter(slots=["<用户>{{content}}"]), - format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), - force_system=True, + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) @@ -615,8 +609,7 @@ _register_template( _register_template( name="deepseek", format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]), - format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), - force_system=True, + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) @@ -648,9 +641,8 @@ _register_template( name="empty", format_user=StringFormatter(slots=["{{content}}"]), format_assistant=StringFormatter(slots=["{{content}}"]), - format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), efficient_eos=True, - force_system=True, ) @@ -672,13 +664,12 @@ _register_template( _register_template( name="gemma", format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), - format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), format_observation=StringFormatter( slots=["tool\n{{content}}\nmodel\n"] ), format_separator=EmptyFormatter(slots=["\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), efficient_eos=True, - force_system=True, ) @@ -686,13 +677,13 @@ _register_template( name="glm4", format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), format_assistant=StringFormatter(slots=["\n{{content}}"]), - format_system=StringFormatter(slots=["[gMASK]<|system|>\n{{content}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), stop_words=["<|user|>", "<|observation|>"], efficient_eos=True, - force_system=True, ) @@ -768,24 +759,21 @@ _register_template( _register_template( name="mistral", format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]), - format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), - force_system=True, + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) _register_template( name="olmo", format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]), - format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]), - force_system=True, + format_prefix=EmptyFormatter(slots=[{"eos_token"}]), ) _register_template( name="openchat", format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]), - format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), - force_system=True, + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) @@ -799,18 +787,16 @@ _register_template( ) ] ), - format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), stop_words=["<|eot_id|>"], replace_eos=True, - force_system=True, ) _register_template( name="orion", format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]), - format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]), - force_system=True, + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), ) @@ -852,7 +838,6 @@ _register_template( format_separator=EmptyFormatter(slots=["\n"]), stop_words=["<|end|>"], replace_eos=True, - force_system=True, )