diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py
index 0ffb91c1..83f9a2d2 100644
--- a/src/llmtuner/model/adapter.py
+++ b/src/llmtuner/model/adapter.py
@@ -2,7 +2,8 @@ from typing import TYPE_CHECKING
 
 import torch
 from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
-from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled
+from transformers.modeling_utils import is_fsdp_enabled
 
 from ..extras.logging import get_logger
 from .utils.misc import find_all_linear_modules, find_expanded_modules
@@ -41,9 +42,16 @@ def init_adapter(
     if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
         raise ValueError("You can only use lora for quantized models.")
 
+    if deepspeed_config() is not None or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
+        logger.info("DeepSpeed/FSDP/PureBF16/BAdam detected, remaining trainable params in half precision.")
+        cast_trainable_params_to_fp32 = False
+    else:
+        logger.info("Upcasting trainable params to float32.")
+        cast_trainable_params_to_fp32 = True
+
     if finetuning_args.finetuning_type == "full" and is_trainable:
         logger.info("Fine-tuning method: Full")
-        if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
+        if cast_trainable_params_to_fp32:
             model = model.float()
 
         if model_args.visual_inputs and hasattr(model, "vision_tower"):  # freeze vision model
@@ -93,7 +101,7 @@ def init_adapter(
 
         for name, param in model.named_parameters():
             if any(trainable_layer in name for trainable_layer in trainable_layers):
-                if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
+                if cast_trainable_params_to_fp32:
                     param.data = param.data.to(torch.float32)
             else:
                 param.requires_grad_(False)
@@ -191,7 +199,7 @@ def init_adapter(
             )
             model = get_peft_model(model, lora_config)
 
-        if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
+        if cast_trainable_params_to_fp32:
             for param in filter(lambda p: p.requires_grad, model.parameters()):
                 param.data = param.data.to(torch.float32)
 
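For context, this change folds the repeated `(not pure_bf16) and (not use_badam)` checks into a single `cast_trainable_params_to_fp32` flag, and additionally switches the upcast off when DeepSpeed or FSDP is active, since those engines (like pure bf16 training and BAdam) manage parameter precision themselves. Below is a minimal standalone sketch of that pattern, not the project's actual helper: it assumes a hypothetical `precision_managed_externally` boolean in place of the real `deepspeed_config()` / `is_fsdp_enabled()` / finetuning-argument probes used in the diff.

import torch
import torch.nn as nn


def upcast_trainable_params(model: nn.Module, precision_managed_externally: bool) -> None:
    """Upcast trainable parameters to float32 unless another component owns precision.

    `precision_managed_externally` is a stand-in for the DeepSpeed/FSDP/pure-bf16/BAdam
    checks in the diff; when it is True, trainable params stay in their current dtype.
    """
    if precision_managed_externally:
        return  # the training engine handles dtypes itself

    for param in filter(lambda p: p.requires_grad, model.parameters()):
        # Only trainable weights are upcast; frozen weights keep their (half-precision)
        # dtype, which is what keeps memory usage low for adapter-style fine-tuning.
        param.data = param.data.to(torch.float32)


if __name__ == "__main__":
    # Tiny demo: a bf16 model with one frozen and one trainable layer.
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2)).to(torch.bfloat16)
    model[0].requires_grad_(False)

    upcast_trainable_params(model, precision_managed_externally=False)
    print(model[0].weight.dtype)  # torch.bfloat16 (frozen, untouched)
    print(model[1].weight.dtype)  # torch.float32  (trainable, upcast)

Centralizing the decision in one flag keeps the three call sites in the diff (full fine-tuning, freeze fine-tuning, and the LoRA path after `get_peft_model`) consistent, so a new precision-owning backend only needs to be added to the single condition at the top of `init_adapter`.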