parent d3c01552e0
commit 88a20ba797
@@ -79,7 +79,11 @@ def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
     transformers.utils.logging.enable_explicit_format()
 
 
-def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
+def _verify_model_args(
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    finetuning_args: "FinetuningArguments",
+) -> None:
     if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
         raise ValueError("Adapter is only valid for the LoRA method.")
 
@@ -99,6 +103,10 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
         if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
             raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
 
+    if data_args.template == "yi" and model_args.use_fast_tokenizer:
+        logger.warning("We should use slow tokenizer for the Yi models.")
+        model_args.use_fast_tokenizer = False
+
 
 def _check_extra_dependencies(
     model_args: "ModelArguments",
@@ -237,7 +245,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         logger.warning("`neat_packing` requires `packing` is True. Change it to True.")
         data_args.packing = True
 
-    _verify_model_args(model_args, finetuning_args)
+    _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)
 
     if (
@@ -361,7 +369,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     if finetuning_args.stage == "rm" and model_args.visual_inputs:
         raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
 
-    _verify_model_args(model_args, finetuning_args)
+    _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
 
     if model_args.export_dir is not None and model_args.export_device == "cpu":
@@ -384,7 +392,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
     if model_args.infer_backend == "vllm":
         raise ValueError("vLLM backend is only available for API, CLI and Web.")
 
-    _verify_model_args(model_args, finetuning_args)
+    _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
 
     model_args.device_map = "auto"
@@ -13,12 +13,18 @@
 # limitations under the License.
 
 import os
+from typing import TYPE_CHECKING, Sequence
+
 import pytest
 from transformers import AutoTokenizer
 
 from llamafactory.data import get_template_and_fix_tokenizer
 
 
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer
+
+
 TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
 
 MESSAGES = [
@@ -29,52 +35,108 @@ MESSAGES = [
 ]
 
 
-def test_encode_oneturn():
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+def _check_tokenization(
+    tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str]
+):
+    for input_ids, text in zip(batch_input_ids, batch_text):
+        assert input_ids == tokenizer.encode(text, add_special_tokens=False)
+        assert tokenizer.decode(input_ids) == text
+
+
+def _check_single_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, use_fast: bool):
+    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast)
+    content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False).rstrip("\n")  # avoid extra newline
+    content_ids = tokenizer.encode(content_str, add_special_tokens=False)
+    template = get_template_and_fix_tokenizer(tokenizer, name=template_name)
+    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
+    assert content_str == prompt_str + answer_str
+    assert content_ids == prompt_ids + answer_ids
+    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
+    return content_ids
+
+
+def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str):
+    slow_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, use_fast=False)
+    fast_ids = _check_single_template(model_id, template_name, prompt_str, answer_str, use_fast=True)
+    assert slow_ids == fast_ids
+
+
+@pytest.mark.parametrize("use_fast", [True, False])
+def test_encode_oneturn(use_fast: bool):
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
     template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
     prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
-    assert tokenizer.decode(prompt_ids) == (
+    prompt_str = (
         "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
         "<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
         "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
         "<|start_header_id|>assistant<|end_header_id|>\n\n"
     )
-    assert tokenizer.decode(answer_ids) == "很高兴认识你!<|eot_id|>"
+    answer_str = "很高兴认识你!<|eot_id|>"
+    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
 
 
-def test_encode_multiturn():
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+@pytest.mark.parametrize("use_fast", [True, False])
+def test_encode_multiturn(use_fast: bool):
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
     template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
     encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
-    assert tokenizer.decode(encoded_pairs[0][0]) == (
+    prompt_str_1 = (
         "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
         "<|start_header_id|>assistant<|end_header_id|>\n\n"
     )
-    assert tokenizer.decode(encoded_pairs[0][1]) == "I am fine!<|eot_id|>"
-    assert tokenizer.decode(encoded_pairs[1][0]) == (
+    answer_str_1 = "I am fine!<|eot_id|>"
+    prompt_str_2 = (
         "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
         "<|start_header_id|>assistant<|end_header_id|>\n\n"
     )
-    assert tokenizer.decode(encoded_pairs[1][1]) == "很高兴认识你!<|eot_id|>"
+    answer_str_2 = "很高兴认识你!<|eot_id|>"
+    _check_tokenization(
+        tokenizer,
+        (encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
+        (prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
+    )
 
 
-def test_jinja_template():
-    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
-    ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
+@pytest.mark.parametrize("use_fast", [True, False])
+def test_jinja_template(use_fast: bool):
+    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
+    ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
     get_template_and_fix_tokenizer(tokenizer, name="llama3")
     assert tokenizer.chat_template != ref_tokenizer.chat_template
     assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
 
 
+def test_llama3_template():
+    prompt_str = (
+        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
+        "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    )
+    answer_str = "很高兴认识你!<|eot_id|>"
+    _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str)
+
+
 def test_qwen_template():
-    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
-    template = get_template_and_fix_tokenizer(tokenizer, name="qwen")
-    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
-    assert tokenizer.decode(prompt_ids) == (
+    prompt_str = (
         "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
         "<|im_start|>user\nHow are you<|im_end|>\n"
         "<|im_start|>assistant\nI am fine!<|im_end|>\n"
         "<|im_start|>user\n你好<|im_end|>\n"
         "<|im_start|>assistant\n"
     )
-    assert tokenizer.decode(answer_ids) == "很高兴认识你!<|im_end|>"
+    answer_str = "很高兴认识你!<|im_end|>"
+    _check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str)
+
+
+@pytest.mark.skip(reason="The fast tokenizer of Yi model is corrupted.")
+def test_yi_template():
+    prompt_str = (
+        "<|im_start|>user\nHow are you<|im_end|>\n"
+        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
+        "<|im_start|>user\n你好<|im_end|>\n"
+        "<|im_start|>assistant\n"
+    )
+    answer_str = "很高兴认识你!<|im_end|>"
+    _check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str)
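For reference, a minimal sketch of how the updated template tests could be exercised locally. The test file path and the TINY_LLAMA override below are assumptions based on this diff, not part of the change itself, and the gated Hugging Face models referenced above must be reachable (or their tests deselected):

# Hypothetical local run of the template tests; the file path is an assumption.
import os
import pytest

os.environ.setdefault("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
raise SystemExit(pytest.main(["-q", "tests/data/test_template.py"]))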