fix qwen template

This commit is contained in:
hiyouga 2024-01-05 16:14:56 +08:00
parent 33f2c0d4f8
commit ed216bbc46
1 changed file with 6 additions and 13 deletions

View File

@@ -1,5 +1,4 @@
import tiktoken import tiktoken
from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
@@ -224,12 +223,13 @@ def get_template_and_fix_tokenizer(
template = templates.get(name, None) template = templates.get(name, None)
assert template is not None, "Template {} does not exist.".format(name) assert template is not None, "Template {} does not exist.".format(name)
stop_words = deepcopy(template.stop_words) stop_words = template.stop_words
if template.replace_eos: if template.replace_eos:
if not stop_words: if not stop_words:
raise ValueError("Stop words are required to replace the EOS token.") raise ValueError("Stop words are required to replace the EOS token.")
tokenizer.eos_token = stop_words.pop(0) tokenizer.eos_token = stop_words[0]
stop_words = stop_words[1:]
logger.info("Replace eos token: {}".format(tokenizer.eos_token)) logger.info("Replace eos token: {}".format(tokenizer.eos_token))
if stop_words: if stop_words:
@@ -601,26 +601,19 @@ register_template(
register_template( register_template(
name="qwen", name="qwen",
prefix=[ prefix=[
{"token": "<|im_start|>"}, "<|im_start|>system\n{{system}}<|im_end|>"
"system\n{{system}}"
], ],
prompt=[ prompt=[
{"token": "<|im_start|>"}, "<|im_start|>user\n{{query}}<|im_end|>\n<|im_start|>assistant\n"
"user\n{{query}}",
{"token": "<|im_end|>"},
"\n",
{"token": "<|im_start|>"},
"assistant\n"
], ],
system="You are a helpful assistant.", system="You are a helpful assistant.",
sep=[ sep=[
{"token": "<|im_end|>"},
"\n" "\n"
], ],
stop_words=[ stop_words=[
"<|im_end|>" "<|im_end|>"
], ],
efficient_eos=True replace_eos=True
) )