parent a388cadfc0
commit 44cfa9a1cd

@@ -2,7 +2,8 @@ from typing import TYPE_CHECKING
 
 import torch
 from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
-from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled
+from transformers.modeling_utils import is_fsdp_enabled
 
 from ..extras.logging import get_logger
 from .utils.misc import find_all_linear_modules, find_expanded_modules
@@ -41,9 +42,16 @@ def init_adapter(
     if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
         raise ValueError("You can only use lora for quantized models.")
 
+    if deepspeed_config() is not None or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
+        logger.info("DeepSpeed/FSDP/PureBF16/BAdam detected, remaining trainable params in half precision.")
+        cast_trainable_params_to_fp32 = False
+    else:
+        logger.info("Upcasting trainable params to float32.")
+        cast_trainable_params_to_fp32 = True
+
     if finetuning_args.finetuning_type == "full" and is_trainable:
         logger.info("Fine-tuning method: Full")
-        if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
+        if cast_trainable_params_to_fp32:
             model = model.float()
 
         if model_args.visual_inputs and hasattr(model, "vision_tower"):  # freeze vision model
@@ -93,7 +101,7 @@ def init_adapter(
 
         for name, param in model.named_parameters():
             if any(trainable_layer in name for trainable_layer in trainable_layers):
-                if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
+                if cast_trainable_params_to_fp32:
                     param.data = param.data.to(torch.float32)
             else:
                 param.requires_grad_(False)
@@ -191,7 +199,7 @@ def init_adapter(
             )
             model = get_peft_model(model, lora_config)
 
-        if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
+        if cast_trainable_params_to_fp32:
             for param in filter(lambda p: p.requires_grad, model.parameters()):
                 param.data = param.data.to(torch.float32)
 
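The diff replaces the repeated (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam) checks with a single cast_trainable_params_to_fp32 flag, computed once and additionally forced off when a DeepSpeed config or FSDP is active, so trainable parameters stay in half precision whenever the training backend manages precision itself. Below is a minimal standalone sketch of that pattern; the helper name should_cast_to_fp32 and the toy nn.Linear model are illustrative assumptions, not part of the commit.

import torch
from torch import nn


def should_cast_to_fp32(deepspeed_cfg, fsdp_enabled, pure_bf16, use_badam):
    # Same decision as in the diff: skip the fp32 upcast when DeepSpeed, FSDP,
    # pure bf16 training, or BAdam already handles parameter precision.
    return not (deepspeed_cfg is not None or fsdp_enabled or pure_bf16 or use_badam)


model = nn.Linear(8, 8).to(torch.bfloat16)  # toy stand-in for the loaded model
if should_cast_to_fp32(deepspeed_cfg=None, fsdp_enabled=False, pure_bf16=False, use_badam=False):
    for param in filter(lambda p: p.requires_grad, model.parameters()):
        param.data = param.data.to(torch.float32)  # upcast only the trainable weights

print(next(model.parameters()).dtype)  # torch.float32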