From f1d8d29bc3ba1b41a72a24834a16a3d125d56461 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Sun, 14 Jul 2024 18:01:45 +0800 Subject: [PATCH] add gemma test --- tests/data/test_template.py | 41 +++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/tests/data/test_template.py b/tests/data/test_template.py index ceb8acc4..3dd83546 100644 --- a/tests/data/test_template.py +++ b/tests/data/test_template.py @@ -43,21 +43,33 @@ def _check_tokenization( assert tokenizer.decode(input_ids) == text -def _check_single_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, use_fast: bool): +def _check_single_template( + model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str, use_fast: bool +): tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=os.environ.get("HF_TOKEN", None)) - content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False).rstrip("\n") # avoid extra newline - content_ids = tokenizer.encode(content_str, add_special_tokens=False) + content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False) + content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True) template = get_template_and_fix_tokenizer(tokenizer, name=template_name) prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES) - assert content_str == prompt_str + answer_str - assert content_ids == prompt_ids + answer_ids + assert content_str == prompt_str + answer_str + extra_str + assert content_ids == prompt_ids + answer_ids + tokenizer.encode(extra_str, add_special_tokens=False) _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str)) return content_ids -def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str): - slow_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, use_fast=False) - fast_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, use_fast=True) +def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str = ""): + """ + Checks template for both the slow tokenizer and the fast tokenizer. + + Args: + model_id: the model id on hugging face hub. + template_name: the template name. + prompt_str: the string corresponding to the prompt part. + answer_str: the string corresponding to the answer part. + extra_str: the extra string in the jinja template of the original tokenizer. + """ + slow_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, extra_str, use_fast=False) + fast_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, extra_str, use_fast=True) assert slow_ids == fast_ids @@ -107,6 +119,17 @@ def test_jinja_template(use_fast: bool): assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES) +def test_gemma_template(): + prompt_str = ( + "user\nHow are you\n" + "model\nI am fine!\n" + "user\n你好\n" + "model\n" + ) + answer_str = "很高兴认识你!" + _check_template("google/gemma-2-9b-it", "gemma", prompt_str, answer_str, extra_str="\n") + + def test_llama3_template(): prompt_str = ( "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>" @@ -127,7 +150,7 @@ def test_qwen_template(): "<|im_start|>assistant\n" ) answer_str = "很高兴认识你!<|im_end|>" - _check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str) + _check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str, extra_str="\n") @pytest.mark.skip(reason="The fast tokenizer of Yi model is corrupted.")