From 6d2bf216ac3a48450e861148ce664dad717fd019 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Wed, 19 Jun 2024 03:49:23 +0800 Subject: [PATCH] fix bug --- src/llamafactory/data/template.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index a12e9c88..c9af9605 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -102,7 +102,10 @@ class Template: system = system or self.default_system encoded_messages = [] for i, message in enumerate(messages): - elements = self.format_prefix.apply() + elements = [] + + if i == 0: + elements += self.format_prefix.apply() if i == 0 and (system or tools): tool_text = self.format_tools.apply(content=tools)[0] if tools else "" @@ -194,7 +197,10 @@ class Llama2Template(Template): system = system or self.default_system encoded_messages = [] for i, message in enumerate(messages): - elements = self.format_prefix.apply() + elements = [] + + if i == 0: + elements += self.format_prefix.apply() system_text = "" if i == 0 and (system or tools):