slow tokenizer for yi models
This commit is contained in:
hiyouga 2024-07-14 15:34:22 +08:00
parent d3c01552e0
commit 88a20ba797
2 changed files with 92 additions and 22 deletions

View File

@ -79,7 +79,11 @@ def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
transformers.utils.logging.enable_explicit_format()
def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
def _verify_model_args(
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
) -> None:
if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Adapter is only valid for the LoRA method.")
@ -99,6 +103,10 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
if data_args.template == "yi" and model_args.use_fast_tokenizer:
logger.warning("We should use slow tokenizer for the Yi models.")
model_args.use_fast_tokenizer = False
def _check_extra_dependencies(
model_args: "ModelArguments",
@ -237,7 +245,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
logger.warning("`neat_packing` requires `packing` is True. Change it to True.")
data_args.packing = True
_verify_model_args(model_args, finetuning_args)
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args)
if (
@ -361,7 +369,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
if finetuning_args.stage == "rm" and model_args.visual_inputs:
raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
_verify_model_args(model_args, finetuning_args)
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
if model_args.export_dir is not None and model_args.export_device == "cpu":
@ -384,7 +392,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")
_verify_model_args(model_args, finetuning_args)
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
model_args.device_map = "auto"

View File

@ -13,12 +13,18 @@
# limitations under the License.
import os
from typing import TYPE_CHECKING, Sequence
import pytest
from transformers import AutoTokenizer
from llamafactory.data import get_template_and_fix_tokenizer
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
MESSAGES = [
@ -29,52 +35,108 @@ MESSAGES = [
]
def test_encode_oneturn():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
def _check_tokenization(
tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str]
):
for input_ids, text in zip(batch_input_ids, batch_text):
assert input_ids == tokenizer.encode(text, add_special_tokens=False)
assert tokenizer.decode(input_ids) == text
def _check_single_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast)
content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False).rstrip("\n") # avoid extra newline
content_ids = tokenizer.encode(content_str, add_special_tokens=False)
template = get_template_and_fix_tokenizer(tokenizer, name=template_name)
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
assert content_str == prompt_str + answer_str
assert content_ids == prompt_ids + answer_ids
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
return content_ids
def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str):
slow_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, use_fast=False)
fast_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, use_fast=True)
assert slow_ids == fast_ids
@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_oneturn(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
assert tokenizer.decode(prompt_ids) == (
prompt_str = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
assert tokenizer.decode(answer_ids) == "很高兴认识你!<|eot_id|>"
answer_str = "很高兴认识你!<|eot_id|>"
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
def test_encode_multiturn():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_multiturn(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
assert tokenizer.decode(encoded_pairs[0][0]) == (
prompt_str_1 = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
assert tokenizer.decode(encoded_pairs[0][1]) == "I am fine!<|eot_id|>"
assert tokenizer.decode(encoded_pairs[1][0]) == (
answer_str_1 = "I am fine!<|eot_id|>"
prompt_str_2 = (
"<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
assert tokenizer.decode(encoded_pairs[1][1]) == "很高兴认识你!<|eot_id|>"
answer_str_2 = "很高兴认识你!<|eot_id|>"
_check_tokenization(
tokenizer,
(encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
(prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
)
def test_jinja_template():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
@pytest.mark.parametrize("use_fast", [True, False])
def test_jinja_template(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
get_template_and_fix_tokenizer(tokenizer, name="llama3")
assert tokenizer.chat_template != ref_tokenizer.chat_template
assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
def test_llama3_template():
prompt_str = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
answer_str = "很高兴认识你!<|eot_id|>"
_check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str)
def test_qwen_template():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
template = get_template_and_fix_tokenizer(tokenizer, name="qwen")
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
assert tokenizer.decode(prompt_ids) == (
prompt_str = (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\nHow are you<|im_end|>\n"
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
"<|im_start|>user\n你好<|im_end|>\n"
"<|im_start|>assistant\n"
)
assert tokenizer.decode(answer_ids) == "很高兴认识你!<|im_end|>"
answer_str = "很高兴认识你!<|im_end|>"
_check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str)
@pytest.mark.skip(reason="The fast tokenizer of Yi model is corrupted.")
def test_yi_template():
prompt_str = (
"<|im_start|>user\nHow are you<|im_end|>\n"
"<|im_start|>assistant\nI am fine!<|im_end|>\n"
"<|im_start|>user\n你好<|im_end|>\n"
"<|im_start|>assistant\n"
)
answer_str = "很高兴认识你!<|im_end|>"
_check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str)