update test template

This commit is contained in:
hiyouga 2024-07-15 00:49:34 +08:00
parent f1d8d29bc3
commit a4ae3ab4ab
1 changed file with 5 additions and 1 deletion

View File

@@ -25,6 +25,8 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
HF_TOKEN = os.environ.get("HF_TOKEN", None)
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
MESSAGES = [
@@ -46,7 +48,7 @@ def _check_tokenization(
def _check_single_template(
model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str, use_fast: bool
):
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=os.environ.get("HF_TOKEN", None))
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
template = get_template_and_fix_tokenizer(tokenizer, name=template_name)
@@ -119,6 +121,7 @@ def test_jinja_template(use_fast: bool):
assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
@pytest.mark.skipif(HF_TOKEN is None, reason="Gated model.")
def test_gemma_template():
prompt_str = (
"<bos><start_of_turn>user\nHow are you<end_of_turn>\n"
@@ -130,6 +133,7 @@ def test_gemma_template():
_check_template("google/gemma-2-9b-it", "gemma", prompt_str, answer_str, extra_str="<end_of_turn>\n")
@pytest.mark.skipif(HF_TOKEN is None, reason="Gated model.")
def test_llama3_template():
prompt_str = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"