Update test_template.py
This commit is contained in:
parent
da990f76b8
commit
10289eab15
|
@ -26,7 +26,6 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
||||
raise ValueError("test: " + str(HF_TOKEN))
|
||||
|
||||
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||
|
||||
|
@ -122,7 +121,7 @@ def test_jinja_template(use_fast: bool):
|
|||
assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
|
||||
|
||||
|
||||
@pytest.mark.skipif(HF_TOKEN is None, reason="Gated model.")
|
||||
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
|
||||
def test_gemma_template():
|
||||
prompt_str = (
|
||||
"<bos><start_of_turn>user\nHow are you<end_of_turn>\n"
|
||||
|
@ -134,7 +133,7 @@ def test_gemma_template():
|
|||
_check_template("google/gemma-2-9b-it", "gemma", prompt_str, answer_str, extra_str="<end_of_turn>\n")
|
||||
|
||||
|
||||
@pytest.mark.skipif(HF_TOKEN is None, reason="Gated model.")
|
||||
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
|
||||
def test_llama3_template():
|
||||
prompt_str = (
|
||||
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
|
||||
|
|
Loading…
Reference in New Issue