add gemma test
This commit is contained in:
parent
173921419d
commit
f1d8d29bc3
|
@ -43,21 +43,33 @@ def _check_tokenization(
|
||||||
assert tokenizer.decode(input_ids) == text
|
assert tokenizer.decode(input_ids) == text
|
||||||
|
|
||||||
|
|
||||||
def _check_single_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, use_fast: bool):
|
def _check_single_template(
|
||||||
|
model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str, use_fast: bool
|
||||||
|
):
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=os.environ.get("HF_TOKEN", None))
|
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=os.environ.get("HF_TOKEN", None))
|
||||||
content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False).rstrip("\n") # avoid extra newline
|
content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
|
||||||
content_ids = tokenizer.encode(content_str, add_special_tokens=False)
|
content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
|
||||||
template = get_template_and_fix_tokenizer(tokenizer, name=template_name)
|
template = get_template_and_fix_tokenizer(tokenizer, name=template_name)
|
||||||
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
|
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
|
||||||
assert content_str == prompt_str + answer_str
|
assert content_str == prompt_str + answer_str + extra_str
|
||||||
assert content_ids == prompt_ids + answer_ids
|
assert content_ids == prompt_ids + answer_ids + tokenizer.encode(extra_str, add_special_tokens=False)
|
||||||
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
|
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
|
||||||
return content_ids
|
return content_ids
|
||||||
|
|
||||||
|
|
||||||
def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str):
|
def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str = ""):
|
||||||
slow_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, use_fast=False)
|
"""
|
||||||
fast_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, use_fast=True)
|
Checks template for both the slow tokenizer and the fast tokenizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id: the model id on hugging face hub.
|
||||||
|
template_name: the template name.
|
||||||
|
prompt_str: the string corresponding to the prompt part.
|
||||||
|
answer_str: the string corresponding to the answer part.
|
||||||
|
extra_str: the extra string in the jinja template of the original tokenizer.
|
||||||
|
"""
|
||||||
|
slow_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, extra_str, use_fast=False)
|
||||||
|
fast_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, extra_str, use_fast=True)
|
||||||
assert slow_ids == fast_ids
|
assert slow_ids == fast_ids
|
||||||
|
|
||||||
|
|
||||||
|
@ -107,6 +119,17 @@ def test_jinja_template(use_fast: bool):
|
||||||
assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
|
assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
|
||||||
|
|
||||||
|
|
||||||
|
def test_gemma_template():
|
||||||
|
prompt_str = (
|
||||||
|
"<bos><start_of_turn>user\nHow are you<end_of_turn>\n"
|
||||||
|
"<start_of_turn>model\nI am fine!<end_of_turn>\n"
|
||||||
|
"<start_of_turn>user\n你好<end_of_turn>\n"
|
||||||
|
"<start_of_turn>model\n"
|
||||||
|
)
|
||||||
|
answer_str = "很高兴认识你!"
|
||||||
|
_check_template("google/gemma-2-9b-it", "gemma", prompt_str, answer_str, extra_str="<end_of_turn>\n")
|
||||||
|
|
||||||
|
|
||||||
def test_llama3_template():
|
def test_llama3_template():
|
||||||
prompt_str = (
|
prompt_str = (
|
||||||
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
|
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
|
||||||
|
@ -127,7 +150,7 @@ def test_qwen_template():
|
||||||
"<|im_start|>assistant\n"
|
"<|im_start|>assistant\n"
|
||||||
)
|
)
|
||||||
answer_str = "很高兴认识你!<|im_end|>"
|
answer_str = "很高兴认识你!<|im_end|>"
|
||||||
_check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str)
|
_check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str, extra_str="\n")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="The fast tokenizer of Yi model is corrupted.")
|
@pytest.mark.skip(reason="The fast tokenizer of Yi model is corrupted.")
|
||||||
|
|
Loading…
Reference in New Issue