From 80a9e6bf94cf14fa63e6b6cdf7e1ce13722c8b5e Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Sat, 15 Jun 2024 20:06:17 +0800
Subject: [PATCH] use fixture

---
 src/llamafactory/hparams/model_args.py |  2 +-
 tests/model/test_base.py               | 15 ++++++++++-----
 tests/model/test_lora.py               | 14 +++++++++++++-
 3 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 0a91f0fa..53bdbdf2 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -163,7 +163,7 @@ class ModelArguments:
     )
     infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
         default="auto",
-        metadata={"help": "Data type for model weights and activations at inference."}
+        metadata={"help": "Data type for model weights and activations at inference."},
     )
     hf_hub_token: Optional[str] = field(
         default=None,
diff --git a/tests/model/test_base.py b/tests/model/test_base.py
index ee0b2886..2deedde2 100644
--- a/tests/model/test_base.py
+++ b/tests/model/test_base.py
@@ -15,6 +15,7 @@
 import os
 from typing import Dict
 
+import pytest
 import torch
 from transformers import AutoModelForCausalLM
 from trl import AutoModelForCausalLMWithValueHead
@@ -43,10 +44,14 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
         assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True
 
 
-def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
-    state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
-    self.v_head.load_state_dict(state_dict, strict=False)
-    del state_dict
+@pytest.fixture
+def fix_valuehead_cpu_loading():
+    def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
+        state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
+        self.v_head.load_state_dict(state_dict, strict=False)
+        del state_dict
+
+    AutoModelForCausalLMWithValueHead.post_init = post_init
 
 
 def test_base():
@@ -60,8 +65,8 @@ def test_base():
     compare_model(model, ref_model)
 
 
+@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
 def test_valuehead():
-    AutoModelForCausalLMWithValueHead.post_init = post_init  # patch for CPU test
     model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
     tokenizer_module = load_tokenizer(model_args)
     model = load_model(
diff --git a/tests/model/test_lora.py b/tests/model/test_lora.py
index 4923c8ad..fe032332 100644
--- a/tests/model/test_lora.py
+++ b/tests/model/test_lora.py
@@ -13,8 +13,9 @@
 # limitations under the License.
 
 import os
-from typing import Sequence
+from typing import Dict, Sequence
 
+import pytest
 import torch
 from peft import LoraModel, PeftModel
 from transformers import AutoModelForCausalLM
@@ -71,6 +72,16 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
         assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True
 
 
+@pytest.fixture
+def fix_valuehead_cpu_loading():
+    def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
+        state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
+        self.v_head.load_state_dict(state_dict, strict=False)
+        del state_dict
+
+    AutoModelForCausalLMWithValueHead.post_init = post_init
+
+
 def test_lora_train_qv_modules():
     model_args, _, _, finetuning_args, _ = get_train_args({"lora_target": "q_proj,v_proj", **TRAIN_ARGS})
     tokenizer_module = load_tokenizer(model_args)
@@ -166,6 +177,7 @@ def test_lora_train_new_adapters():
     )
 
 
+@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
 def test_lora_train_valuehead():
     model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
     tokenizer_module = load_tokenizer(model_args)
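
Note: the new fixture body is duplicated verbatim in tests/model/test_base.py and
tests/model/test_lora.py, and assigning AutoModelForCausalLMWithValueHead.post_init
directly leaves the class patched for the rest of the test session with no teardown.
Below is a minimal sketch of an alternative, assuming a shared tests/model/conftest.py
(hypothetical, not part of this patch) and pytest's built-in monkeypatch fixture, which
undoes the attribute patch after each test:

    # Hypothetical tests/model/conftest.py -- a sketch, not the project's actual layout.
    from typing import Dict

    import pytest
    import torch
    from trl import AutoModelForCausalLMWithValueHead


    @pytest.fixture
    def fix_valuehead_cpu_loading(monkeypatch: "pytest.MonkeyPatch"):
        def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
            # Keep only the value-head weights and strip the "v_head." prefix
            # (7 characters) so the head can be restored in a CPU-only test run.
            state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
            self.v_head.load_state_dict(state_dict, strict=False)
            del state_dict

        # Unlike direct assignment, monkeypatch.setattr restores the original
        # post_init at teardown, so the patch cannot leak into later tests.
        monkeypatch.setattr(AutoModelForCausalLMWithValueHead, "post_init", post_init)

With a shared conftest.py fixture, both module-level copies could be dropped and the
existing @pytest.mark.usefixtures("fix_valuehead_cpu_loading") markers would resolve to
the single definition unchanged.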