This commit is contained in:
hiyouga 2024-06-25 02:34:04 +08:00
parent cc016461e6
commit 1e9d0aa1e4
1 changed files with 5 additions and 0 deletions

View File

@ -14,6 +14,7 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
@ -175,6 +176,10 @@ def load_model(
if not is_trainable: if not is_trainable:
model.requires_grad_(False) model.requires_grad_(False)
for param in model.parameters():
if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
param.data = param.data.to(model_args.compute_dtype)
model.eval() model.eval()
else: else:
model.train() model.train()