fix baichuan template #481
This commit is contained in:
parent
7f35487c4a
commit
273135f595
|
@ -195,8 +195,14 @@ def get_template_and_fix_tokenizer(
|
||||||
template = templates.get(name, None)
|
template = templates.get(name, None)
|
||||||
assert template is not None, "Template {} does not exist.".format(name)
|
assert template is not None, "Template {} does not exist.".format(name)
|
||||||
|
|
||||||
|
additional_special_tokens = template.stop_words
|
||||||
|
|
||||||
if len(template.stop_words): # inplace method
|
if len(template.stop_words): # inplace method
|
||||||
|
if tokenizer.eos_token_id is not None:
|
||||||
|
additional_special_tokens.append(tokenizer.eos_token)
|
||||||
|
|
||||||
tokenizer.eos_token = template.stop_words[0]
|
tokenizer.eos_token = template.stop_words[0]
|
||||||
|
additional_special_tokens.pop(0)
|
||||||
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
|
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
|
||||||
|
|
||||||
if tokenizer.eos_token_id is None:
|
if tokenizer.eos_token_id is None:
|
||||||
|
@ -210,7 +216,7 @@ def get_template_and_fix_tokenizer(
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
logger.info("Add pad token: {}".format(tokenizer.pad_token))
|
logger.info("Add pad token: {}".format(tokenizer.pad_token))
|
||||||
|
|
||||||
tokenizer.add_special_tokens(dict(additional_special_tokens=template.stop_words))
|
tokenizer.add_special_tokens(dict(additional_special_tokens=additional_special_tokens))
|
||||||
return template
|
return template
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue