fix test case

This commit is contained in:
hiyouga 2024-06-25 02:51:49 +08:00
parent 1e9d0aa1e4
commit c244af0dc3
1 changed files with 2 additions and 1 deletions

View File

@ -73,7 +73,8 @@ def test_valuehead():
tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False, add_valuehead=True tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False, add_valuehead=True
) )
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained( ref_model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
TINY_LLAMA_VALUEHEAD, torch_dtype=torch.float16, device_map=get_current_device() TINY_LLAMA_VALUEHEAD, torch_dtype=torch.float16, device_map=get_current_device()
) )
ref_model.v_head = ref_model.v_head.to(torch.float16)
compare_model(model, ref_model) compare_model(model, ref_model)