add gemma test

This commit is contained in:
hiyouga 2024-07-14 18:01:45 +08:00
parent 173921419d
commit f1d8d29bc3
1 changed files with 32 additions and 9 deletions

View File

@ -43,21 +43,33 @@ def _check_tokenization(
assert tokenizer.decode(input_ids) == text assert tokenizer.decode(input_ids) == text
def _check_single_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, use_fast: bool): def _check_single_template(
model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str, use_fast: bool
):
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=os.environ.get("HF_TOKEN", None)) tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=os.environ.get("HF_TOKEN", None))
content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False).rstrip("\n") # avoid extra newline content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
content_ids = tokenizer.encode(content_str, add_special_tokens=False) content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
template = get_template_and_fix_tokenizer(tokenizer, name=template_name) template = get_template_and_fix_tokenizer(tokenizer, name=template_name)
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES) prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
assert content_str == prompt_str + answer_str assert content_str == prompt_str + answer_str + extra_str
assert content_ids == prompt_ids + answer_ids assert content_ids == prompt_ids + answer_ids + tokenizer.encode(extra_str, add_special_tokens=False)
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str)) _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
return content_ids return content_ids
def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str): def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str = ""):
slow_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, use_fast=False) """
fast_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, use_fast=True) Checks template for both the slow tokenizer and the fast tokenizer.
Args:
model_id: the model id on hugging face hub.
template_name: the template name.
prompt_str: the string corresponding to the prompt part.
answer_str: the string corresponding to the answer part.
extra_str: the extra string in the jinja template of the original tokenizer.
"""
slow_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, extra_str, use_fast=False)
fast_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, extra_str, use_fast=True)
assert slow_ids == fast_ids assert slow_ids == fast_ids
@ -107,6 +119,17 @@ def test_jinja_template(use_fast: bool):
assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES) assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
def test_gemma_template():
prompt_str = (
"<bos><start_of_turn>user\nHow are you<end_of_turn>\n"
"<start_of_turn>model\nI am fine!<end_of_turn>\n"
"<start_of_turn>user\n你好<end_of_turn>\n"
"<start_of_turn>model\n"
)
answer_str = "很高兴认识你!"
_check_template("google/gemma-2-9b-it", "gemma", prompt_str, answer_str, extra_str="<end_of_turn>\n")
def test_llama3_template(): def test_llama3_template():
prompt_str = ( prompt_str = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>" "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
@ -127,7 +150,7 @@ def test_qwen_template():
"<|im_start|>assistant\n" "<|im_start|>assistant\n"
) )
answer_str = "很高兴认识你!<|im_end|>" answer_str = "很高兴认识你!<|im_end|>"
_check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str) _check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str, extra_str="\n")
@pytest.mark.skip(reason="The fast tokenizer of Yi model is corrupted.") @pytest.mark.skip(reason="The fast tokenizer of Yi model is corrupted.")