hiyouga 2024-05-15 23:05:02 +08:00
parent a388cadfc0
commit 44cfa9a1cd
1 changed files with 12 additions and 4 deletions

View File

@ -2,7 +2,8 @@ from typing import TYPE_CHECKING
import torch
from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from ..extras.logging import get_logger
from .utils.misc import find_all_linear_modules, find_expanded_modules
@ -41,9 +42,16 @@ def init_adapter(
if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
raise ValueError("You can only use lora for quantized models.")
if deepspeed_config() is not None or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
logger.info("DeepSpeed/FSDP/PureBF16/BAdam detected, remaining trainable params in half precision.")
cast_trainable_params_to_fp32 = False
else:
logger.info("Upcasting trainable params to float32.")
cast_trainable_params_to_fp32 = True
if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full")
if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
if cast_trainable_params_to_fp32:
model = model.float()
if model_args.visual_inputs and hasattr(model, "vision_tower"): # freeze vision model
@ -93,7 +101,7 @@ def init_adapter(
for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers):
if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
if cast_trainable_params_to_fp32:
param.data = param.data.to(torch.float32)
else:
param.requires_grad_(False)
@ -191,7 +199,7 @@ def init_adapter(
)
model = get_peft_model(model, lora_config)
if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
if cast_trainable_params_to_fp32:
for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.float32)