add unittest
This commit is contained in: parent 779aae83d2, commit 608de799a2

File diff suppressed because one or more lines are too long
@@ -35,7 +35,6 @@ Dependency graph:
  transformers==4.41.2 (chatglm)
"""


from .cli import VERSION
@@ -119,7 +119,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
            num_training_steps = training_args.max_steps
        else:
            total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
-            num_training_steps = training_args.num_train_epochs * math.ceil(len(train_dataset) / total_train_batch_size)
+            num_training_steps = training_args.num_train_epochs * math.ceil(
+                len(train_dataset) / total_train_batch_size
+            )

        optimizer = self.create_optimizer(model, training_args, finetuning_args)
        scheduler = self.create_scheduler(training_args, num_training_steps, optimizer)
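
For reference, a toy calculation with assumed numbers (not taken from the repository) showing how the expression above turns the effective batch size into a PPO step count:

import math

# Assumed example values, for illustration only.
backward_batch_size = 8   # per-device batch size * gradient accumulation steps
ppo_buffer_size = 1
world_size = 2
num_train_epochs = 3
len_train_dataset = 1000

total_train_batch_size = backward_batch_size * ppo_buffer_size * world_size  # 16
num_training_steps = num_train_epochs * math.ceil(len_train_dataset / total_train_batch_size)
print(num_training_steps)  # 3 * ceil(1000 / 16) = 3 * 63 = 189
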
@@ -0,0 +1,118 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict, Optional, Sequence, Set, Tuple, Union

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead

from ..data import get_dataset
from ..extras.misc import get_current_device
from ..hparams import get_infer_args, get_train_args
from ..model import load_model, load_tokenizer


if TYPE_CHECKING:
    from datasets import Dataset
    from peft import LoraModel
    from transformers import PreTrainedModel


def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: Sequence[str] = []) -> None:
    state_dict_a = model_a.state_dict()
    state_dict_b = model_b.state_dict()
    assert set(state_dict_a.keys()) == set(state_dict_b.keys())
    for name in state_dict_a.keys():
        if any(key in name for key in diff_keys):
            assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is False
        else:
            assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True


def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]:
    linear_modules, extra_modules = set(), set()
    for name, param in model.named_parameters():
        if any(module in name for module in ["lora_A", "lora_B"]):
            linear_modules.add(name.split(".lora_", maxsplit=1)[0].split(".")[-1])
            assert param.requires_grad is True
            assert param.dtype == torch.float32
        elif "modules_to_save" in name:
            extra_modules.add(name.split(".modules_to_save", maxsplit=1)[0].split(".")[-1])
            assert param.requires_grad is True
            assert param.dtype == torch.float32
        else:
            assert param.requires_grad is False
            assert param.dtype == torch.float16

    return linear_modules, extra_modules


def load_train_model(add_valuehead: bool = False, **kwargs) -> "PreTrainedModel":
    model_args, _, _, finetuning_args, _ = get_train_args(kwargs)
    tokenizer = load_tokenizer(model_args)["tokenizer"]
    return load_model(tokenizer, model_args, finetuning_args, is_trainable=True, add_valuehead=add_valuehead)


def load_infer_model(add_valuehead: bool = False, **kwargs) -> "PreTrainedModel":
    model_args, _, finetuning_args, _ = get_infer_args(kwargs)
    tokenizer = load_tokenizer(model_args)["tokenizer"]
    return load_model(tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead)


def load_reference_model(
    model_path: str,
    lora_path: Optional[str] = None,
    use_lora: bool = False,
    use_pissa: bool = False,
    is_trainable: bool = False,
    add_valuehead: bool = False,
) -> Union["PreTrainedModel", "LoraModel"]:
    if add_valuehead:
        model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map=get_current_device()
        )
        if not is_trainable:
            model.v_head = model.v_head.to(torch.float16)

        return model

    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, device_map=get_current_device()
    )
    if use_lora or use_pissa:
        model = PeftModel.from_pretrained(
            model, lora_path, subfolder="pissa_init" if use_pissa else None, is_trainable=is_trainable
        )
        for param in filter(lambda p: p.requires_grad, model.parameters()):
            param.data = param.data.to(torch.float32)

    return model


def load_train_dataset(**kwargs) -> "Dataset":
    model_args, data_args, training_args, _, _ = get_train_args(kwargs)
    tokenizer_module = load_tokenizer(model_args)
    dataset_module = get_dataset(model_args, data_args, training_args, stage=kwargs["stage"], **tokenizer_module)
    return dataset_module["train_dataset"]


def patch_valuehead_model():
    def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]) -> None:
        state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
        self.v_head.load_state_dict(state_dict, strict=False)
        del state_dict

    AutoModelForCausalLMWithValueHead.post_init = post_init
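
A minimal, self-contained sketch (not part of the commit, assuming LLaMA-Factory is installed) of what the new compare_model helper checks: matching weights pass by default, while any parameter whose name contains an entry of diff_keys must differ:

import torch

from llamafactory.train.test_utils import compare_model

model_a = torch.nn.Linear(4, 4)
model_b = torch.nn.Linear(4, 4)
model_b.load_state_dict(model_a.state_dict())
compare_model(model_a, model_b)  # identical weight and bias -> passes

with torch.no_grad():
    model_b.bias.add_(1.0)
compare_model(model_a, model_b, diff_keys=["bias"])  # bias must differ, weight must still match -> passes
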
@@ -19,20 +19,21 @@ import pytest
from datasets import load_dataset
from transformers import AutoTokenizer

-from llamafactory.data import get_dataset
-from llamafactory.hparams import get_train_args
-from llamafactory.model import load_tokenizer
+from llamafactory.extras.constants import IGNORE_INDEX
+from llamafactory.train.test_utils import load_train_dataset


+DEMO_DATA = os.environ.get("DEMO_DATA", "llamafactory/demo_data")
+
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")

TRAIN_ARGS = {
    "model_name_or_path": TINY_LLAMA,
-    "stage": "sft",
+    "stage": "kto",
    "do_train": True,
    "finetuning_type": "full",
-    "dataset": "llamafactory/tiny-supervised-dataset",
-    "dataset_dir": "ONLINE",
+    "dataset": "kto_en_demo",
+    "dataset_dir": "REMOTE:" + DEMO_DATA,
    "template": "llama3",
    "cutoff_len": 8192,
    "overwrite_cache": True,
@@ -43,25 +44,16 @@ TRAIN_ARGS = {


@pytest.mark.parametrize("num_samples", [16])
-def test_supervised(num_samples: int):
-    model_args, data_args, training_args, _, _ = get_train_args(TRAIN_ARGS)
-    tokenizer_module = load_tokenizer(model_args)
-    tokenizer = tokenizer_module["tokenizer"]
-    dataset_module = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
-
+def test_feedback_data(num_samples: int):
+    train_dataset = load_train_dataset(**TRAIN_ARGS)
    ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)

-    original_data = load_dataset(TRAIN_ARGS["dataset"], split="train")
+    original_data = load_dataset(DEMO_DATA, name="kto_en_demo", split="train")
    indexes = random.choices(range(len(original_data)), k=num_samples)
    for index in indexes:
-        prompt = original_data[index]["instruction"]
-        if original_data[index]["input"]:
-            prompt += "\n" + original_data[index]["input"]
-
-        messages = [
-            {"role": "user", "content": prompt},
-            {"role": "assistant", "content": original_data[index]["output"]},
-        ]
-        templated_result = ref_tokenizer.apply_chat_template(messages, tokenize=False)
-        decoded_result = tokenizer.decode(dataset_module["train_dataset"]["input_ids"][index])
-        assert templated_result == decoded_result
+        messages = original_data["messages"][index]
+        ref_input_ids = ref_tokenizer.apply_chat_template(messages)
+        prompt_len = len(ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True))
+        ref_labels = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
+        assert train_dataset["input_ids"][index] == ref_input_ids
+        assert train_dataset["labels"][index] == ref_labels
+        assert train_dataset["kto_tags"][index] == original_data["label"][index]
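
A toy illustration (assumed token ids, not real tokenizer output) of the reference labels built above: prompt positions are masked with IGNORE_INDEX so that only the response is supervised, while kto_tags carries the per-example preference label:

IGNORE_INDEX = -100  # value used to exclude positions from the loss

ref_input_ids = [1, 10, 11, 12, 20, 21, 2]  # templated prompt + response, assumed ids
prompt_len = 4                              # length of the prompt part (with generation prompt)
ref_labels = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
assert ref_labels == [-100, -100, -100, -100, 20, 21, 2]
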
@@ -0,0 +1,78 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
from typing import Dict, List

import pytest
from datasets import load_dataset
from transformers import AutoTokenizer

from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.train.test_utils import load_train_dataset


DEMO_DATA = os.environ.get("DEMO_DATA", "llamafactory/demo_data")

TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")

TRAIN_ARGS = {
    "model_name_or_path": TINY_LLAMA,
    "stage": "rm",
    "do_train": True,
    "finetuning_type": "full",
    "dataset": "dpo_en_demo",
    "dataset_dir": "REMOTE:" + DEMO_DATA,
    "template": "llama3",
    "cutoff_len": 8192,
    "overwrite_cache": True,
    "output_dir": "dummy_dir",
    "overwrite_output_dir": True,
    "fp16": True,
}


def _convert_sharegpt_to_openai(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    role_mapping = {"human": "user", "gpt": "assistant", "system": "system"}
    new_messages = []
    for message in messages:
        new_messages.append({"role": role_mapping[message["from"]], "content": message["value"]})

    return new_messages


@pytest.mark.parametrize("num_samples", [16])
def test_pairwise_data(num_samples: int):
    train_dataset = load_train_dataset(**TRAIN_ARGS)
    ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
    original_data = load_dataset(DEMO_DATA, name="dpo_en_demo", split="train")
    indexes = random.choices(range(len(original_data)), k=num_samples)
    for index in indexes:
        chosen_messages = original_data["conversations"][index] + [original_data["chosen"][index]]
        rejected_messages = original_data["conversations"][index] + [original_data["rejected"][index]]
        chosen_messages = _convert_sharegpt_to_openai(chosen_messages)
        rejected_messages = _convert_sharegpt_to_openai(rejected_messages)
        ref_chosen_input_ids = ref_tokenizer.apply_chat_template(chosen_messages)
        chosen_prompt_len = len(ref_tokenizer.apply_chat_template(chosen_messages[:-1], add_generation_prompt=True))
        ref_chosen_labels = [IGNORE_INDEX] * chosen_prompt_len + ref_chosen_input_ids[chosen_prompt_len:]
        ref_rejected_input_ids = ref_tokenizer.apply_chat_template(rejected_messages)
        rejected_prompt_len = len(
            ref_tokenizer.apply_chat_template(rejected_messages[:-1], add_generation_prompt=True)
        )
        ref_rejected_labels = [IGNORE_INDEX] * rejected_prompt_len + ref_rejected_input_ids[rejected_prompt_len:]
        assert train_dataset["chosen_input_ids"][index] == ref_chosen_input_ids
        assert train_dataset["chosen_labels"][index] == ref_chosen_labels
        assert train_dataset["rejected_input_ids"][index] == ref_rejected_input_ids
        assert train_dataset["rejected_labels"][index] == ref_rejected_labels
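
An inline illustration (made-up messages) of the role mapping performed by the test's _convert_sharegpt_to_openai helper, turning ShareGPT-style "from"/"value" records into OpenAI-style "role"/"content" messages:

role_mapping = {"human": "user", "gpt": "assistant", "system": "system"}

sharegpt_messages = [
    {"from": "human", "value": "Hello"},
    {"from": "gpt", "value": "Hi, how can I help?"},
]
openai_messages = [{"role": role_mapping[m["from"]], "content": m["value"]} for m in sharegpt_messages]
print(openai_messages)
# [{'role': 'user', 'content': 'Hello'}, {'role': 'assistant', 'content': 'Hi, how can I help?'}]
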
@@ -0,0 +1,104 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random

import pytest
from datasets import load_dataset
from transformers import AutoTokenizer

from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.train.test_utils import load_train_dataset


DEMO_DATA = os.environ.get("DEMO_DATA", "llamafactory/demo_data")

TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")

TINY_DATA = os.environ.get("TINY_DATA", "llamafactory/tiny-supervised-dataset")

TRAIN_ARGS = {
    "model_name_or_path": TINY_LLAMA,
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "full",
    "template": "llama3",
    "cutoff_len": 8192,
    "overwrite_cache": True,
    "output_dir": "dummy_dir",
    "overwrite_output_dir": True,
    "fp16": True,
}


@pytest.mark.parametrize("num_samples", [16])
def test_supervised_single_turn(num_samples: int):
    train_dataset = load_train_dataset(dataset_dir="ONLINE", dataset=TINY_DATA, **TRAIN_ARGS)
    ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
    original_data = load_dataset(TINY_DATA, split="train")
    indexes = random.choices(range(len(original_data)), k=num_samples)
    for index in indexes:
        prompt = original_data["instruction"][index]
        if original_data["input"][index]:
            prompt += "\n" + original_data["input"][index]

        messages = [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": original_data["output"][index]},
        ]
        ref_input_ids = ref_tokenizer.apply_chat_template(messages)
        assert train_dataset["input_ids"][index] == ref_input_ids


@pytest.mark.parametrize("num_samples", [8])
def test_supervised_multi_turn(num_samples: int):
    train_dataset = load_train_dataset(dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", **TRAIN_ARGS)
    ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
    original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
    indexes = random.choices(range(len(original_data)), k=num_samples)
    for index in indexes:
        ref_input_ids = ref_tokenizer.apply_chat_template(original_data["messages"][index])
        assert train_dataset["input_ids"][index] == ref_input_ids


@pytest.mark.parametrize("num_samples", [4])
def test_supervised_train_on_prompt(num_samples: int):
    train_dataset = load_train_dataset(
        dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", train_on_prompt=True, **TRAIN_ARGS
    )
    ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
    original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
    indexes = random.choices(range(len(original_data)), k=num_samples)
    for index in indexes:
        ref_ids = ref_tokenizer.apply_chat_template(original_data["messages"][index])
        assert train_dataset["input_ids"][index] == ref_ids
        assert train_dataset["labels"][index] == ref_ids


@pytest.mark.parametrize("num_samples", [4])
def test_supervised_mask_history(num_samples: int):
    train_dataset = load_train_dataset(
        dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", mask_history=True, **TRAIN_ARGS
    )
    ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
    original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
    indexes = random.choices(range(len(original_data)), k=num_samples)
    for index in indexes:
        messages = original_data["messages"][index]
        ref_input_ids = ref_tokenizer.apply_chat_template(messages)
        prompt_len = len(ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True))
        ref_label_ids = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
        assert train_dataset["input_ids"][index] == ref_input_ids
        assert train_dataset["labels"][index] == ref_label_ids
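
For clarity, a small example with made-up alpaca-style fields showing how test_supervised_single_turn assembles the reference chat messages from the "instruction", "input" and "output" columns:

example = {"instruction": "Translate to French", "input": "Good morning", "output": "Bonjour"}

prompt = example["instruction"]
if example["input"]:
    prompt += "\n" + example["input"]

messages = [
    {"role": "user", "content": prompt},
    {"role": "assistant", "content": example["output"]},
]
print(repr(messages[0]["content"]))  # 'Translate to French\nGood morning'
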
@@ -0,0 +1,58 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random

import pytest
from datasets import load_dataset
from transformers import AutoTokenizer

from llamafactory.train.test_utils import load_train_dataset


DEMO_DATA = os.environ.get("DEMO_DATA", "llamafactory/demo_data")

TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")

TINY_DATA = os.environ.get("TINY_DATA", "llamafactory/tiny-supervised-dataset")

TRAIN_ARGS = {
    "model_name_or_path": TINY_LLAMA,
    "stage": "sft",
    "do_predict": True,
    "finetuning_type": "full",
    "template": "llama3",
    "cutoff_len": 8192,
    "overwrite_cache": True,
    "output_dir": "dummy_dir",
    "overwrite_output_dir": True,
    "predict_with_generate": True,
    "fp16": True,
}


@pytest.mark.parametrize("num_samples", [16])
def test_unsupervised_data(num_samples: int):
    train_dataset = load_train_dataset(dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", **TRAIN_ARGS)
    ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
    original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
    indexes = random.choices(range(len(original_data)), k=num_samples)
    for index in indexes:
        messages = original_data["messages"][index]
        ref_ids = ref_tokenizer.apply_chat_template(messages)
        ref_input_ids = ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True)
        ref_labels = ref_ids[len(ref_input_ids) :]
        assert train_dataset["input_ids"][index] == ref_input_ids
        assert train_dataset["labels"][index] == ref_labels
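
A toy sketch (assumed ids) of the split that test_unsupervised_data verifies in predict mode: input_ids hold the templated prompt including the generation prompt, and labels hold only the reference response tokens that follow it:

ref_ids = [1, 10, 11, 12, 20, 21, 2]  # full templated conversation, assumed ids
ref_input_ids = [1, 10, 11, 12]       # prompt part with add_generation_prompt=True
ref_labels = ref_ids[len(ref_input_ids):]
assert ref_labels == [20, 21, 2]
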
@@ -13,7 +13,7 @@
# limitations under the License.

import os
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, List, Sequence

import pytest
from transformers import AutoTokenizer
@@ -39,7 +39,7 @@ MESSAGES = [

def _check_tokenization(
    tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str]
-):
+) -> None:
    for input_ids, text in zip(batch_input_ids, batch_text):
        assert input_ids == tokenizer.encode(text, add_special_tokens=False)
        assert tokenizer.decode(input_ids) == text
@@ -47,7 +47,7 @@ def _check_tokenization(

def _check_single_template(
    model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str, use_fast: bool
-):
+) -> List[str]:
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
    content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
    content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
@@ -59,7 +59,7 @@ def _check_single_template(
    return content_ids


-def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str = ""):
+def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, extra_str: str = "") -> None:
    """
    Checks template for both the slow tokenizer and the fast tokenizer.
@@ -16,8 +16,7 @@ import os

from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available

-from llamafactory.hparams import get_infer_args
-from llamafactory.model import load_model, load_tokenizer
+from llamafactory.train.test_utils import load_infer_model


TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
@@ -42,9 +41,7 @@ def test_attention():
        "fa2": "LlamaFlashAttention2",
    }
    for requested_attention in attention_available:
-        model_args, _, finetuning_args, _ = get_infer_args({"flash_attn": requested_attention, **INFER_ARGS})
-        tokenizer_module = load_tokenizer(model_args)
-        model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args)
+        model = load_infer_model(flash_attn=requested_attention, **INFER_ARGS)
        for module in model.modules():
            if "Attention" in module.__class__.__name__:
                assert module.__class__.__name__ == llama_attention_classes[requested_attention]
@@ -17,8 +17,7 @@ import os
import torch

from llamafactory.extras.misc import get_current_device
-from llamafactory.hparams import get_train_args
-from llamafactory.model import load_model, load_tokenizer
+from llamafactory.train.test_utils import load_train_model


TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
@@ -41,34 +40,26 @@ TRAIN_ARGS = {


def test_checkpointing_enable():
-    model_args, _, _, finetuning_args, _ = get_train_args({"disable_gradient_checkpointing": False, **TRAIN_ARGS})
-    tokenizer_module = load_tokenizer(model_args)
-    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+    model = load_train_model(disable_gradient_checkpointing=False, **TRAIN_ARGS)
    for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
        assert getattr(module, "gradient_checkpointing") is True


def test_checkpointing_disable():
-    model_args, _, _, finetuning_args, _ = get_train_args({"disable_gradient_checkpointing": True, **TRAIN_ARGS})
-    tokenizer_module = load_tokenizer(model_args)
-    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+    model = load_train_model(disable_gradient_checkpointing=True, **TRAIN_ARGS)
    for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
        assert getattr(module, "gradient_checkpointing") is False


def test_upcast_layernorm():
-    model_args, _, _, finetuning_args, _ = get_train_args({"upcast_layernorm": True, **TRAIN_ARGS})
-    tokenizer_module = load_tokenizer(model_args)
-    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+    model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS)
    for name, param in model.named_parameters():
        if param.ndim == 1 and "norm" in name:
            assert param.dtype == torch.float32


def test_upcast_lmhead_output():
-    model_args, _, _, finetuning_args, _ = get_train_args({"upcast_lmhead_output": True, **TRAIN_ARGS})
-    tokenizer_module = load_tokenizer(model_args)
-    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
+    model = load_train_model(upcast_lmhead_output=True, **TRAIN_ARGS)
    inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device())
    outputs: "torch.Tensor" = model.get_output_embeddings()(inputs)
    assert outputs.dtype == torch.float32
@@ -13,16 +13,15 @@
# limitations under the License.

import os
-from typing import Dict

import pytest
-import torch
-from transformers import AutoModelForCausalLM
-from trl import AutoModelForCausalLMWithValueHead

-from llamafactory.extras.misc import get_current_device
-from llamafactory.hparams import get_infer_args
-from llamafactory.model import load_model, load_tokenizer
+from llamafactory.train.test_utils import (
+    compare_model,
+    load_infer_model,
+    load_reference_model,
+    patch_valuehead_model,
+)


TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
@@ -36,45 +35,19 @@ INFER_ARGS = {
}


-def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
-    state_dict_a = model_a.state_dict()
-    state_dict_b = model_b.state_dict()
-    assert set(state_dict_a.keys()) == set(state_dict_b.keys())
-    for name in state_dict_a.keys():
-        assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5)
-
-
@pytest.fixture
def fix_valuehead_cpu_loading():
-    def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
-        state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
-        self.v_head.load_state_dict(state_dict, strict=False)
-        del state_dict
-
-    AutoModelForCausalLMWithValueHead.post_init = post_init
+    patch_valuehead_model()


def test_base():
-    model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
-    tokenizer_module = load_tokenizer(model_args)
-    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
-
-    ref_model = AutoModelForCausalLM.from_pretrained(
-        TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device()
-    )
+    model = load_infer_model(**INFER_ARGS)
+    ref_model = load_reference_model(TINY_LLAMA)
    compare_model(model, ref_model)


@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
def test_valuehead():
-    model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
-    tokenizer_module = load_tokenizer(model_args)
-    model = load_model(
-        tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False, add_valuehead=True
-    )
-
-    ref_model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
-        TINY_LLAMA_VALUEHEAD, torch_dtype=torch.float16, device_map=get_current_device()
-    )
-    ref_model.v_head = ref_model.v_head.to(torch.float16)
+    model = load_infer_model(add_valuehead=True, **INFER_ARGS)
+    ref_model = load_reference_model(TINY_LLAMA_VALUEHEAD, add_valuehead=True)
    compare_model(model, ref_model)
@@ -16,8 +16,7 @@ import os

import torch

-from llamafactory.hparams import get_infer_args, get_train_args
-from llamafactory.model import load_model, load_tokenizer
+from llamafactory.train.test_utils import load_infer_model, load_train_model


TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
@@ -46,10 +45,7 @@ INFER_ARGS = {


def test_freeze_train_all_modules():
-    model_args, _, _, finetuning_args, _ = get_train_args({"freeze_trainable_layers": 1, **TRAIN_ARGS})
-    tokenizer_module = load_tokenizer(model_args)
-    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
-
+    model = load_train_model(freeze_trainable_layers=1, **TRAIN_ARGS)
    for name, param in model.named_parameters():
        if name.startswith("model.layers.1."):
            assert param.requires_grad is True
@@ -60,12 +56,7 @@ def test_freeze_train_all_modules():


def test_freeze_train_extra_modules():
-    model_args, _, _, finetuning_args, _ = get_train_args(
-        {"freeze_trainable_layers": 1, "freeze_extra_modules": "embed_tokens,lm_head", **TRAIN_ARGS}
-    )
-    tokenizer_module = load_tokenizer(model_args)
-    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
-
+    model = load_train_model(freeze_trainable_layers=1, freeze_extra_modules="embed_tokens,lm_head", **TRAIN_ARGS)
    for name, param in model.named_parameters():
        if name.startswith("model.layers.1.") or any(module in name for module in ["embed_tokens", "lm_head"]):
            assert param.requires_grad is True
@@ -76,10 +67,7 @@ def test_freeze_train_extra_modules():


def test_freeze_inference():
-    model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
-    tokenizer_module = load_tokenizer(model_args)
-    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
-
+    model = load_infer_model(**INFER_ARGS)
    for param in model.parameters():
        assert param.requires_grad is False
        assert param.dtype == torch.float16
@@ -16,8 +16,7 @@ import os

import torch

-from llamafactory.hparams import get_infer_args, get_train_args
-from llamafactory.model import load_model, load_tokenizer
+from llamafactory.train.test_utils import load_infer_model, load_train_model


TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
@@ -46,20 +45,14 @@ INFER_ARGS = {


def test_full_train():
-    model_args, _, _, finetuning_args, _ = get_train_args(TRAIN_ARGS)
-    tokenizer_module = load_tokenizer(model_args)
-    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
-
+    model = load_train_model(**TRAIN_ARGS)
    for param in model.parameters():
        assert param.requires_grad is True
        assert param.dtype == torch.float32


def test_full_inference():
-    model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
-    tokenizer_module = load_tokenizer(model_args)
-    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
-
+    model = load_infer_model(**INFER_ARGS)
    for param in model.parameters():
        assert param.requires_grad is False
        assert param.dtype == torch.float16
@@ -13,17 +13,18 @@
# limitations under the License.

import os
-from typing import Dict, Sequence

import pytest
import torch
-from peft import LoraModel, PeftModel
-from transformers import AutoModelForCausalLM
-from trl import AutoModelForCausalLMWithValueHead

-from llamafactory.extras.misc import get_current_device
-from llamafactory.hparams import get_infer_args, get_train_args
-from llamafactory.model import load_model, load_tokenizer
+from llamafactory.train.test_utils import (
+    check_lora_model,
+    compare_model,
+    load_infer_model,
+    load_reference_model,
+    load_train_model,
+    patch_valuehead_model,
+)


TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
@@ -56,116 +57,38 @@ INFER_ARGS = {
}


-def load_reference_model(is_trainable: bool = False) -> "LoraModel":
-    model = AutoModelForCausalLM.from_pretrained(
-        TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device()
-    )
-    lora_model = PeftModel.from_pretrained(model, TINY_LLAMA_ADAPTER, is_trainable=is_trainable)
-    for param in filter(lambda p: p.requires_grad, lora_model.parameters()):
-        param.data = param.data.to(torch.float32)
-
-    return lora_model
-
-
-def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: Sequence[str] = []):
-    state_dict_a = model_a.state_dict()
-    state_dict_b = model_b.state_dict()
-    assert set(state_dict_a.keys()) == set(state_dict_b.keys())
-    for name in state_dict_a.keys():
-        if any(key in name for key in diff_keys):
-            assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is False
-        else:
-            assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True
-
-
@pytest.fixture
def fix_valuehead_cpu_loading():
-    def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
-        state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
-        self.v_head.load_state_dict(state_dict, strict=False)
-        del state_dict
-
-    AutoModelForCausalLMWithValueHead.post_init = post_init
+    patch_valuehead_model()


def test_lora_train_qv_modules():
-    model_args, _, _, finetuning_args, _ = get_train_args({"lora_target": "q_proj,v_proj", **TRAIN_ARGS})
-    tokenizer_module = load_tokenizer(model_args)
-    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
-
-    linear_modules = set()
-    for name, param in model.named_parameters():
-        if any(module in name for module in ["lora_A", "lora_B"]):
-            linear_modules.add(name.split(".lora_", maxsplit=1)[0].split(".")[-1])
-            assert param.requires_grad is True
-            assert param.dtype == torch.float32
-        else:
-            assert param.requires_grad is False
-            assert param.dtype == torch.float16
-
+    model = load_train_model(lora_target="q_proj,v_proj", **TRAIN_ARGS)
+    linear_modules, _ = check_lora_model(model)
    assert linear_modules == {"q_proj", "v_proj"}


def test_lora_train_all_modules():
-    model_args, _, _, finetuning_args, _ = get_train_args({"lora_target": "all", **TRAIN_ARGS})
-    tokenizer_module = load_tokenizer(model_args)
-    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
-
-    linear_modules = set()
-    for name, param in model.named_parameters():
-        if any(module in name for module in ["lora_A", "lora_B"]):
-            linear_modules.add(name.split(".lora_", maxsplit=1)[0].split(".")[-1])
-            assert param.requires_grad is True
-            assert param.dtype == torch.float32
-        else:
-            assert param.requires_grad is False
-            assert param.dtype == torch.float16
-
+    model = load_train_model(lora_target="all", **TRAIN_ARGS)
+    linear_modules, _ = check_lora_model(model)
    assert linear_modules == {"q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"}


def test_lora_train_extra_modules():
-    model_args, _, _, finetuning_args, _ = get_train_args(
-        {"lora_target": "all", "additional_target": "embed_tokens,lm_head", **TRAIN_ARGS}
-    )
-    tokenizer_module = load_tokenizer(model_args)
-    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
-
-    extra_modules = set()
-    for name, param in model.named_parameters():
-        if any(module in name for module in ["lora_A", "lora_B"]):
-            assert param.requires_grad is True
-            assert param.dtype == torch.float32
-        elif "modules_to_save" in name:
-            extra_modules.add(name.split(".modules_to_save", maxsplit=1)[0].split(".")[-1])
-            assert param.requires_grad is True
-            assert param.dtype == torch.float32
-        else:
-            assert param.requires_grad is False
-            assert param.dtype == torch.float16
-
+    model = load_train_model(additional_target="embed_tokens,lm_head", **TRAIN_ARGS)
+    _, extra_modules = check_lora_model(model)
    assert extra_modules == {"embed_tokens", "lm_head"}


def test_lora_train_old_adapters():
-    model_args, _, _, finetuning_args, _ = get_train_args(
-        {"adapter_name_or_path": TINY_LLAMA_ADAPTER, "create_new_adapter": False, **TRAIN_ARGS}
-    )
-    tokenizer_module = load_tokenizer(model_args)
-    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
-
-    ref_model = load_reference_model(is_trainable=True)
+    model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=False, **TRAIN_ARGS)
+    ref_model = load_reference_model(TINY_LLAMA, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
    compare_model(model, ref_model)


def test_lora_train_new_adapters():
-    model_args, _, _, finetuning_args, _ = get_train_args(
-        {"adapter_name_or_path": TINY_LLAMA_ADAPTER, "create_new_adapter": True, **TRAIN_ARGS}
-    )
-    tokenizer_module = load_tokenizer(model_args)
-    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
-
-    ref_model = load_reference_model(is_trainable=True)
+    model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=True, **TRAIN_ARGS)
+    ref_model = load_reference_model(TINY_LLAMA, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
    compare_model(
        model, ref_model, diff_keys=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"]
    )
@@ -173,26 +96,15 @@ def test_lora_train_new_adapters():

@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
def test_lora_train_valuehead():
-    model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
-    tokenizer_module = load_tokenizer(model_args)
-    model = load_model(
-        tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True, add_valuehead=True
-    )
-
-    ref_model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
-        TINY_LLAMA_VALUEHEAD, torch_dtype=torch.float16, device_map=get_current_device()
-    )
+    model = load_train_model(add_valuehead=True, **TRAIN_ARGS)
+    ref_model = load_reference_model(TINY_LLAMA_VALUEHEAD, is_trainable=True, add_valuehead=True)
    state_dict = model.state_dict()
    ref_state_dict = ref_model.state_dict()

    assert torch.allclose(state_dict["v_head.summary.weight"], ref_state_dict["v_head.summary.weight"])
    assert torch.allclose(state_dict["v_head.summary.bias"], ref_state_dict["v_head.summary.bias"])


def test_lora_inference():
-    model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
-    tokenizer_module = load_tokenizer(model_args)
-    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
-
-    ref_model = load_reference_model().merge_and_unload()
+    model = load_infer_model(**INFER_ARGS)
+    ref_model = load_reference_model(TINY_LLAMA, TINY_LLAMA_ADAPTER, use_lora=True).merge_and_unload()
    compare_model(model, ref_model)
@@ -14,13 +14,7 @@

import os

-import torch
-from peft import LoraModel, PeftModel
-from transformers import AutoModelForCausalLM
-
-from llamafactory.extras.misc import get_current_device
-from llamafactory.hparams import get_infer_args, get_train_args
-from llamafactory.model import load_model, load_tokenizer
+from llamafactory.train.test_utils import compare_model, load_infer_model, load_reference_model, load_train_model


TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
@@ -54,37 +48,14 @@ INFER_ARGS = {
}


-def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
-    state_dict_a = model_a.state_dict()
-    state_dict_b = model_b.state_dict()
-    assert set(state_dict_a.keys()) == set(state_dict_b.keys())
-    for name in state_dict_a.keys():
-        assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5)
-
-
-def test_pissa_init():
-    model_args, _, _, finetuning_args, _ = get_train_args(TRAIN_ARGS)
-    tokenizer_module = load_tokenizer(model_args)
-    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
-
-    base_model = AutoModelForCausalLM.from_pretrained(
-        TINY_LLAMA_PISSA, torch_dtype=torch.float16, device_map=get_current_device()
-    )
-    ref_model = PeftModel.from_pretrained(base_model, TINY_LLAMA_PISSA, subfolder="pissa_init", is_trainable=True)
-    for param in filter(lambda p: p.requires_grad, ref_model.parameters()):
-        param.data = param.data.to(torch.float32)
-
+def test_pissa_train():
+    model = load_train_model(**TRAIN_ARGS)
+    ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=True)
    compare_model(model, ref_model)


def test_pissa_inference():
-    model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
-    tokenizer_module = load_tokenizer(model_args)
-    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
-
-    base_model = AutoModelForCausalLM.from_pretrained(
-        TINY_LLAMA_PISSA, torch_dtype=torch.float16, device_map=get_current_device()
-    )
-    ref_model: "LoraModel" = PeftModel.from_pretrained(base_model, TINY_LLAMA_PISSA, subfolder="pissa_init")
+    model = load_infer_model(**INFER_ARGS)
+    ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=False)
    ref_model = ref_model.merge_and_unload()
    compare_model(model, ref_model)