use fixture

parent 1b834f50be
commit 80a9e6bf94

@@ -163,7 +163,7 @@ class ModelArguments:
     )
     infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
         default="auto",
-        metadata={"help": "Data type for model weights and activations at inference."}
+        metadata={"help": "Data type for model weights and activations at inference."},
     )
     hf_hub_token: Optional[str] = field(
         default=None,

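Note (not part of the diff): the change above only adds a trailing comma to the field definition. For readers wondering how a value like infer_dtype is typically consumed, here is a minimal sketch, assuming a hypothetical helper _resolve_infer_dtype that is not part of this repository and may differ from the project's actual logic:

    import torch

    def _resolve_infer_dtype(infer_dtype: str) -> "torch.dtype":
        # Hypothetical helper for illustration only.
        if infer_dtype == "auto":
            # Assumption: prefer bfloat16 when the GPU supports it, otherwise float16.
            if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
                return torch.bfloat16
            return torch.float16
        return getattr(torch, infer_dtype)  # "float16" / "bfloat16" / "float32"

    print(_resolve_infer_dtype("float32"))  # torch.float32
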
@@ -15,6 +15,7 @@
 import os
+from typing import Dict

 import pytest
 import torch
 from transformers import AutoModelForCausalLM
 from trl import AutoModelForCausalLMWithValueHead

@@ -43,10 +44,14 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
         assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True


-def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
-    state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
-    self.v_head.load_state_dict(state_dict, strict=False)
-    del state_dict
+@pytest.fixture
+def fix_valuehead_cpu_loading():
+    def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
+        state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
+        self.v_head.load_state_dict(state_dict, strict=False)
+        del state_dict
+
+    AutoModelForCausalLMWithValueHead.post_init = post_init


 def test_base():

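Note (not part of the diff): the fixture above replaces AutoModelForCausalLMWithValueHead.post_init so that only value-head weights are loaded, with strict=False, on CPU. The key handling is the same in both test files: keep the keys starting with "v_head." and strip that 7-character prefix. A self-contained sketch of just that step, using a plain torch.nn.Linear as a stand-in for the real value head (an assumption for illustration):

    import torch

    # Stand-in for the value head; the real v_head module differs.
    v_head = torch.nn.Linear(4, 1)

    # Checkpoint-style state dict in which value-head keys carry a "v_head." prefix.
    full_state_dict = {
        "v_head.weight": torch.zeros(1, 4),
        "v_head.bias": torch.zeros(1),
        "backbone.embed.weight": torch.zeros(8, 4),  # filtered out below
    }

    # Same filtering as the fixture: len("v_head.") == 7, hence k[7:].
    v_head_state = {k[7:]: v for k, v in full_state_dict.items() if k.startswith("v_head.")}
    v_head.load_state_dict(v_head_state, strict=False)
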
@@ -60,8 +65,8 @@ def test_base():
     compare_model(model, ref_model)


+@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
 def test_valuehead():
-    AutoModelForCausalLMWithValueHead.post_init = post_init  # patch for CPU test
     model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
     tokenizer_module = load_tokenizer(model_args)
     model = load_model(

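Note (not part of the diff): pytest.mark.usefixtures, as used above, requests a fixture purely for its side effect, without injecting a value into the test function. A tiny standalone example with made-up names:

    import pytest

    _state = {"patched": False}

    @pytest.fixture
    def patch_something():
        # Side-effect-only fixture, analogous to fix_valuehead_cpu_loading.
        _state["patched"] = True

    @pytest.mark.usefixtures("patch_something")
    def test_patch_applied():
        assert _state["patched"] is True
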
@@ -13,8 +13,9 @@
 # limitations under the License.

 import os
-from typing import Sequence
+from typing import Dict, Sequence

 import pytest
 import torch
 from peft import LoraModel, PeftModel
 from transformers import AutoModelForCausalLM

@@ -71,6 +72,16 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
         assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True


+@pytest.fixture
+def fix_valuehead_cpu_loading():
+    def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
+        state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
+        self.v_head.load_state_dict(state_dict, strict=False)
+        del state_dict
+
+    AutoModelForCausalLMWithValueHead.post_init = post_init


 def test_lora_train_qv_modules():
     model_args, _, _, finetuning_args, _ = get_train_args({"lora_target": "q_proj,v_proj", **TRAIN_ARGS})
     tokenizer_module = load_tokenizer(model_args)

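Note (not part of the diff): the patched post_init relies on load_state_dict(..., strict=False), which ignores mismatched keys and reports them instead of raising. A short standalone illustration of that PyTorch behavior:

    import torch

    head = torch.nn.Linear(4, 1)
    result = head.load_state_dict({"weight": torch.zeros(1, 4)}, strict=False)
    print(result.missing_keys)     # ['bias']
    print(result.unexpected_keys)  # []
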
@@ -166,6 +177,7 @@ def test_lora_train_new_adapters():
     )


+@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
 def test_lora_train_valuehead():
     model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
     tokenizer_module = load_tokenizer(model_args)