fix baichuan template #481

This commit is contained in:
hiyouga 2023-08-15 11:38:21 +08:00
parent 7f35487c4a
commit 273135f595
1 changed file with 7 additions and 1 deletion

View File

@@ -195,8 +195,14 @@ def get_template_and_fix_tokenizer(
 template = templates.get(name, None)
 assert template is not None, "Template {} does not exist.".format(name)
+additional_special_tokens = template.stop_words
 if len(template.stop_words): # inplace method
+if tokenizer.eos_token_id is not None:
+additional_special_tokens.append(tokenizer.eos_token)
 tokenizer.eos_token = template.stop_words[0]
+additional_special_tokens.pop(0)
 logger.info("Replace eos token: {}".format(tokenizer.eos_token))
 if tokenizer.eos_token_id is None:
@@ -210,7 +216,7 @@ def get_template_and_fix_tokenizer(
 tokenizer.pad_token = tokenizer.eos_token
 logger.info("Add pad token: {}".format(tokenizer.pad_token))
-tokenizer.add_special_tokens(dict(additional_special_tokens=template.stop_words))
+tokenizer.add_special_tokens(dict(additional_special_tokens=additional_special_tokens))
 return template