From 273135f59500a36cc30333ef2dd3689c6030e2ef Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 15 Aug 2023 11:38:21 +0800 Subject: [PATCH] fix baichuan template #481 --- src/llmtuner/extras/template.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index 62153b0d..8ad42ac8 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -195,8 +195,14 @@ def get_template_and_fix_tokenizer( template = templates.get(name, None) assert template is not None, "Template {} does not exist.".format(name) + additional_special_tokens = template.stop_words + if len(template.stop_words): # inplace method + if tokenizer.eos_token_id is not None: + additional_special_tokens.append(tokenizer.eos_token) + tokenizer.eos_token = template.stop_words[0] + additional_special_tokens.pop(0) logger.info("Replace eos token: {}".format(tokenizer.eos_token)) if tokenizer.eos_token_id is None: @@ -210,7 +216,7 @@ def get_template_and_fix_tokenizer( tokenizer.pad_token = tokenizer.eos_token logger.info("Add pad token: {}".format(tokenizer.pad_token)) - tokenizer.add_special_tokens(dict(additional_special_tokens=template.stop_words)) + tokenizer.add_special_tokens(dict(additional_special_tokens=additional_special_tokens)) return template