forked from p04798526/LLaMA-Factory-Mirror
parent a388cadfc0
commit 44cfa9a1cd
@@ -2,7 +2,8 @@ from typing import TYPE_CHECKING
 
 import torch
 from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
-from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled
+from transformers.modeling_utils import is_fsdp_enabled
 
 from ..extras.logging import get_logger
 from .utils.misc import find_all_linear_modules, find_expanded_modules
|
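The two helpers pulled into the imports both come from transformers and report whether the current process runs under DeepSpeed or FSDP. A minimal probe, for illustration only (outside such a launch both report falsy values):

    from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled
    from transformers.modeling_utils import is_fsdp_enabled

    # deepspeed_config() returns the live DeepSpeed config dict when the Trainer
    # was configured with one, otherwise None.
    print(deepspeed_config())
    # is_deepspeed_zero3_enabled() is True only when that config uses ZeRO stage 3.
    print(is_deepspeed_zero3_enabled())
    # is_fsdp_enabled() is True only when the process was launched with
    # accelerate's FSDP integration enabled.
    print(is_fsdp_enabled())
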
@@ -41,9 +42,16 @@ def init_adapter(
     if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
         raise ValueError("You can only use lora for quantized models.")
 
+    if deepspeed_config() is not None or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
+        logger.info("DeepSpeed/FSDP/PureBF16/BAdam detected, remaining trainable params in half precision.")
+        cast_trainable_params_to_fp32 = False
+    else:
+        logger.info("Upcasting trainable params to float32.")
+        cast_trainable_params_to_fp32 = True
+
     if finetuning_args.finetuning_type == "full" and is_trainable:
         logger.info("Fine-tuning method: Full")
-        if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
+        if cast_trainable_params_to_fp32:
             model = model.float()
 
         if model_args.visual_inputs and hasattr(model, "vision_tower"):  # freeze vision model
|
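The net effect of the new block is a single boolean that later guards every fp32 upcast. A condensed sketch of that predicate (the two boolean arguments stand in for the corresponding FinetuningArguments fields and are not the project's actual signature):

    from transformers.integrations import deepspeed_config
    from transformers.modeling_utils import is_fsdp_enabled


    def should_upcast_to_fp32(pure_bf16: bool, use_badam: bool) -> bool:
        # DeepSpeed/FSDP manage precision themselves, and pure bf16 or BAdam
        # training keeps params in half precision, so skip the upcast there.
        if deepspeed_config() is not None or is_fsdp_enabled() or pure_bf16 or use_badam:
            return False
        return True
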
@@ -93,7 +101,7 @@ def init_adapter(
 
         for name, param in model.named_parameters():
             if any(trainable_layer in name for trainable_layer in trainable_layers):
-                if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
+                if cast_trainable_params_to_fp32:
                     param.data = param.data.to(torch.float32)
             else:
                 param.requires_grad_(False)
|
@@ -191,7 +199,7 @@ def init_adapter(
             )
             model = get_peft_model(model, lora_config)
 
-        if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
+        if cast_trainable_params_to_fp32:
             for param in filter(lambda p: p.requires_grad, model.parameters()):
                 param.data = param.data.to(torch.float32)
 
|
|
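Each call site that previously re-checked pure_bf16/use_badam now consults the shared flag before upcasting. What the guarded upcast itself does can be seen on a toy module (standalone illustration, not the project's code path):

    import torch
    from torch import nn

    model = nn.Linear(4, 4).to(torch.bfloat16)  # stand-in for a base model loaded in bf16
    for param in filter(lambda p: p.requires_grad, model.parameters()):
        param.data = param.data.to(torch.float32)  # upcast only trainable tensors, in place

    print({name: p.dtype for name, p in model.named_parameters()})
    # expected: {'weight': torch.float32, 'bias': torch.float32}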