diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index a12e9c88..c9af9605 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -102,7 +102,10 @@ class Template: system = system or self.default_system encoded_messages = [] for i, message in enumerate(messages): - elements = self.format_prefix.apply() + elements = [] + + if i == 0: + elements += self.format_prefix.apply() if i == 0 and (system or tools): tool_text = self.format_tools.apply(content=tools)[0] if tools else "" @@ -194,7 +197,10 @@ class Llama2Template(Template): system = system or self.default_system encoded_messages = [] for i, message in enumerate(messages): - elements = self.format_prefix.apply() + elements = [] + + if i == 0: + elements += self.format_prefix.apply() system_text = "" if i == 0 and (system or tools):