From c244af0dc3478532de02271667e7af4ad8f54228 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Tue, 25 Jun 2024 02:51:49 +0800 Subject: [PATCH] fix test case --- tests/model/test_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/model/test_base.py b/tests/model/test_base.py index e1991b20..6431a504 100644 --- a/tests/model/test_base.py +++ b/tests/model/test_base.py @@ -73,7 +73,8 @@ def test_valuehead(): tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False, add_valuehead=True ) - ref_model = AutoModelForCausalLMWithValueHead.from_pretrained( + ref_model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained( TINY_LLAMA_VALUEHEAD, torch_dtype=torch.float16, device_map=get_current_device() ) + ref_model.v_head = ref_model.v_head.to(torch.float16) compare_model(model, ref_model)