From 88a20ba7972c533d650967a118d612471fe2b2e8 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Sun, 14 Jul 2024 15:34:22 +0800
Subject: [PATCH] fix #4699 slow tokenizer for yi models

---
 src/llamafactory/hparams/parser.py | 16 +++--
 tests/data/test_template.py        | 98 ++++++++++++++++++++++++------
 2 files changed, 92 insertions(+), 22 deletions(-)

diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index ca9a9589..d4ac405a 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -79,7 +79,11 @@ def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
     transformers.utils.logging.enable_explicit_format()
 
 
-def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
+def _verify_model_args(
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    finetuning_args: "FinetuningArguments",
+) -> None:
     if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
         raise ValueError("Adapter is only valid for the LoRA method.")
 
@@ -99,6 +103,10 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
         if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
             raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
 
+    if data_args.template == "yi" and model_args.use_fast_tokenizer:
+        logger.warning("We should use slow tokenizer for the Yi models.")
+        model_args.use_fast_tokenizer = False
+
 
 def _check_extra_dependencies(
     model_args: "ModelArguments",
@@ -237,7 +245,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         logger.warning("`neat_packing` requires `packing` is True. Change it to True.")
         data_args.packing = True
 
-    _verify_model_args(model_args, finetuning_args)
+    _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)
 
     if (
@@ -361,7 +369,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     if finetuning_args.stage == "rm" and model_args.visual_inputs:
         raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
 
-    _verify_model_args(model_args, finetuning_args)
+    _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
 
     if model_args.export_dir is not None and model_args.export_device == "cpu":
@@ -384,7 +392,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
     if model_args.infer_backend == "vllm":
         raise ValueError("vLLM backend is only available for API, CLI and Web.")
 
-    _verify_model_args(model_args, finetuning_args)
+    _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
 
     model_args.device_map = "auto"
diff --git a/tests/data/test_template.py b/tests/data/test_template.py
index e4728a84..fa82973b 100644
--- a/tests/data/test_template.py
+++ b/tests/data/test_template.py
@@ -13,12 +13,18 @@
 # limitations under the License.
 
 import os
+from typing import TYPE_CHECKING, Sequence
 
+import pytest
 from transformers import AutoTokenizer
 
 from llamafactory.data import get_template_and_fix_tokenizer
 
 
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer
+
+
 TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
 
 MESSAGES = [
@@ -29,52 +35,108 @@ MESSAGES = [
 ]
 
 
-def test_encode_oneturn():
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+def _check_tokenization(
+    tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str]
+):
+    for input_ids, text in zip(batch_input_ids, batch_text):
+        assert input_ids == tokenizer.encode(text, add_special_tokens=False)
+        assert tokenizer.decode(input_ids) == text
+
+
+def _check_single_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, use_fast: bool):
+    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast)
+    content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False).rstrip("\n")  # avoid extra newline
+    content_ids = tokenizer.encode(content_str, add_special_tokens=False)
+    template = get_template_and_fix_tokenizer(tokenizer, name=template_name)
+    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
+    assert content_str == prompt_str + answer_str
+    assert content_ids == prompt_ids + answer_ids
+    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
+    return content_ids
+
+
+def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str):
+    slow_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, use_fast=False)
+    fast_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, use_fast=True)
+    assert slow_ids == fast_ids
+
+
+@pytest.mark.parametrize("use_fast", [True, False])
+def test_encode_oneturn(use_fast: bool):
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
     template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
     prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
-    assert tokenizer.decode(prompt_ids) == (
+    prompt_str = (
         "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
         "<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
         "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
         "<|start_header_id|>assistant<|end_header_id|>\n\n"
     )
-    assert tokenizer.decode(answer_ids) == "很高兴认识你!<|eot_id|>"
+    answer_str = "很高兴认识你!<|eot_id|>"
+    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
 
 
-def test_encode_multiturn():
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+@pytest.mark.parametrize("use_fast", [True, False])
+def test_encode_multiturn(use_fast: bool):
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
     template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
     encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
-    assert tokenizer.decode(encoded_pairs[0][0]) == (
+    prompt_str_1 = (
         "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
         "<|start_header_id|>assistant<|end_header_id|>\n\n"
     )
-    assert tokenizer.decode(encoded_pairs[0][1]) == "I am fine!<|eot_id|>"
-    assert tokenizer.decode(encoded_pairs[1][0]) == (
+    answer_str_1 = "I am fine!<|eot_id|>"
+    prompt_str_2 = (
         "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
         "<|start_header_id|>assistant<|end_header_id|>\n\n"
     )
-    assert tokenizer.decode(encoded_pairs[1][1]) == "很高兴认识你!<|eot_id|>"
+    answer_str_2 = "很高兴认识你!<|eot_id|>"
+    _check_tokenization(
+        tokenizer,
+        (encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
+        (prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
+    )
 
 
-def test_jinja_template():
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
-    ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+@pytest.mark.parametrize("use_fast", [True, False])
+def test_jinja_template(use_fast: bool):
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
+    ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
     get_template_and_fix_tokenizer(tokenizer, name="llama3")
     assert tokenizer.chat_template != ref_tokenizer.chat_template
     assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
 
 
+def test_llama3_template():
+    prompt_str = (
+        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
+        "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    )
+    answer_str = "很高兴认识你!<|eot_id|>"
+    _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str)
+
+
 def test_qwen_template():
-    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
-    template = get_template_and_fix_tokenizer(tokenizer, name="qwen")
-    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
-    assert tokenizer.decode(prompt_ids) == (
+    prompt_str = (
         "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
         "<|im_start|>user\nHow are you<|im_end|>\n"
         "<|im_start|>assistant\nI am fine!<|im_end|>\n"
         "<|im_start|>user\n你好<|im_end|>\n"
         "<|im_start|>assistant\n"
     )
-    assert tokenizer.decode(answer_ids) == "很高兴认识你!<|im_end|>"
+    answer_str = "很高兴认识你!<|im_end|>"
+    _check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str)
+
+
+@pytest.mark.skip(reason="The fast tokenizer of Yi model is corrupted.")
+def test_yi_template():
+    prompt_str = (
+        "<|im_start|>user\nHow are you<|im_end|>\n"
+        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
+        "<|im_start|>user\n你好<|im_end|>\n"
+        "<|im_start|>assistant\n"
+    )
+    answer_str = "很高兴认识你!<|im_end|>"
+    _check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str)
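
Repro note (not part of the patch): a minimal sketch of the fast/slow tokenizer mismatch behind #4699 that motivates both the skipped test above and the new fallback to the slow tokenizer in parser.py. It reuses the "01-ai/Yi-1.5-6B-Chat" checkpoint from test_yi_template; the expectation that the two tokenizers disagree on Yi chat markup comes from the issue and is an assumption here, not something this snippet guarantees.

    # Sketch: compare fast vs. slow tokenization of a Yi-style chat prompt.
    # Requires network access to download the tokenizer from the Hugging Face Hub.
    from transformers import AutoTokenizer

    MODEL_ID = "01-ai/Yi-1.5-6B-Chat"  # same checkpoint as test_yi_template
    PROMPT = "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n"

    slow = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
    fast = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

    slow_ids = slow.encode(PROMPT, add_special_tokens=False)
    fast_ids = fast.encode(PROMPT, add_special_tokens=False)

    # If the fast tokenizer is corrupted, the two id sequences (or their decoded
    # round-trips) differ; parser.py therefore forces use_fast_tokenizer=False
    # whenever the "yi" template is selected.
    print("slow == fast:", slow_ids == fast_ids)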