fix yi template #1895

This commit is contained in:
hiyouga 2023-12-20 18:58:16 +08:00
parent 624cc21281
commit 5af8841c4f
1 changed files with 28 additions and 13 deletions

View File

@@ -21,6 +21,7 @@ class Template:
stop_words: List[str]
use_history: bool
efficient_eos: bool
replace_eos: bool
def encode_oneturn(
self,
@@ -38,7 +39,8 @@ class Template:
prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids = prompt_ids + query_ids + resp_ids
prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1]
prompt_ids = prompt_ids + encoded_pairs[-1][0]
answer_ids = encoded_pairs[-1][1]
return prompt_ids, answer_ids
def encode_multiturn(
@@ -77,13 +79,13 @@ class Template:
) -> Tuple[List[int], List[int]]:
if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
bos_ids = [tokenizer.bos_token_id]
else: # baichuan, qwen and gpt2 models have no bos token
else: # baichuan, gpt2, qwen, yi models have no bos token
bos_ids = []
if tokenizer.eos_token_id is None:
raise ValueError("EOS token is required.")
if self.efficient_eos: # used in baichuan, qwen, chatglm, etc.
if self.efficient_eos:
eos_ids = []
else:
eos_ids = [tokenizer.eos_token_id]
@@ -187,9 +189,10 @@ def register_template(
sep: List[Union[str, Dict[str, str]]],
stop_words: Optional[List[str]] = [],
use_history: Optional[bool] = True,
efficient_eos: Optional[bool] = False
efficient_eos: Optional[bool] = False,
replace_eos: Optional[bool] = False
) -> None:
template_class = Llama2Template if "llama2" in name else Template
template_class = Llama2Template if name.startswith("llama2") else Template
templates[name] = template_class(
prefix=prefix,
prompt=prompt,
@@ -197,7 +200,8 @@ def register_template(
sep=sep,
stop_words=stop_words,
use_history=use_history,
efficient_eos=efficient_eos
efficient_eos=efficient_eos,
replace_eos=replace_eos
)
@@ -213,15 +217,26 @@ def get_template_and_fix_tokenizer(
tokenizer.pad_token = tokenizer.eos_token
logger.info("Add pad token: {}".format(tokenizer.pad_token))
if name is None:
if name is None: # for pre-training
return None
template = templates.get(name, None)
assert template is not None, "Template {} does not exist.".format(name)
tokenizer.add_special_tokens(
dict(additional_special_tokens=template.stop_words),
replace_additional_special_tokens=False
)
if template.replace_eos:
if not template.stop_words:
raise ValueError("Stop words are required to replace the EOS token.")
tokenizer.eos_token = template.stop_words.pop(0)
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
if template.stop_words:
tokenizer.add_special_tokens(
dict(additional_special_tokens=template.stop_words),
replace_additional_special_tokens=False
)
logger.info("Add {} to stop words.".format(",".join(template.stop_words)))
return template
@@ -732,12 +747,12 @@ register_template(
],
system="",
sep=[
"<|im_end|>\n"
"\n"
],
stop_words=[
"<|im_end|>"
],
efficient_eos=True
replace_eos=True
)