fix test case
This commit is contained in:
parent
1e9d0aa1e4
commit
c244af0dc3
|
@ -73,7 +73,8 @@ def test_valuehead():
|
||||||
tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False, add_valuehead=True
|
tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False, add_valuehead=True
|
||||||
)
|
)
|
||||||
|
|
||||||
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
ref_model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||||
TINY_LLAMA_VALUEHEAD, torch_dtype=torch.float16, device_map=get_current_device()
|
TINY_LLAMA_VALUEHEAD, torch_dtype=torch.float16, device_map=get_current_device()
|
||||||
)
|
)
|
||||||
|
ref_model.v_head = ref_model.v_head.to(torch.float16)
|
||||||
compare_model(model, ref_model)
|
compare_model(model, ref_model)
|
||||||
|
|
Loading…
Reference in New Issue