From 1e9d0aa1e45fac52614e79a9fe87e8f1d3757333 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Tue, 25 Jun 2024 02:34:04 +0800
Subject: [PATCH] fix #4432

---
 src/llamafactory/model/loader.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 69cccd93..e1015821 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -14,6 +14,7 @@
 
 from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
 
+import torch
 from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
 from trl import AutoModelForCausalLMWithValueHead
 
@@ -175,6 +176,10 @@ def load_model(
 
     if not is_trainable:
         model.requires_grad_(False)
+        for param in model.parameters():
+            if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
+                param.data = param.data.to(model_args.compute_dtype)
+
         model.eval()
     else:
         model.train()