diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py
index 4ae95a62..e868afd6 100644
--- a/src/llamafactory/model/adapter.py
+++ b/src/llamafactory/model/adapter.py
@@ -44,7 +44,7 @@ def init_adapter(
         raise ValueError("You can only use lora for quantized models.")
 
     if deepspeed_config() is not None or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
-        logger.info("DeepSpeed/FSDP/PureBF16/BAdam detected, remaining trainable params in half precision.")
+        logger.info("DeepSpeed/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
         cast_trainable_params_to_fp32 = False
     else:
         logger.info("Upcasting trainable params to float32.")
@@ -122,6 +122,9 @@ def init_adapter(
             else:
                 param.requires_grad_(False)
 
+        if model_args.visual_inputs and hasattr(model, "vision_tower"):  # freeze vision model
+            model.vision_tower.requires_grad_(False)
+
         logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))
 
     if finetuning_args.finetuning_type == "lora":
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 08cdf17f..49b347d5 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -170,6 +170,7 @@ def load_model(
         )
     else:
         param_stats = "all params: {:d}".format(all_param)
+
     logger.info(param_stats)
 
     if model_args.print_param_status:
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 603e3c9e..9297ef00 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict
 import torch
 from peft import PeftModel
 from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
-from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 
 from ..extras.logging import get_logger
@@ -66,13 +66,16 @@ def patch_config(
     for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
         setattr(config, dtype_name, model_args.compute_dtype == dtype)
 
-    if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn:
-        setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flashattn
+    if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
+        setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
 
-    init_kwargs["torch_dtype"] = model_args.compute_dtype
-    if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled():
-        init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
-        if init_kwargs["low_cpu_mem_usage"]:
+    # deepspeed zero3 is not compatible with low_cpu_mem_usage
+    init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
+
+    if deepspeed_config() is None and not is_fsdp_enabled():  # set dtype and device map if not use deepspeed or fsdp
+        init_kwargs["torch_dtype"] = model_args.compute_dtype
+
+        if init_kwargs["low_cpu_mem_usage"]:  # device map requires low_cpu_mem_usage=True
             if "device_map" not in init_kwargs and model_args.device_map:
                 init_kwargs["device_map"] = model_args.device_map
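
Note for reviewers: the sketch below is not part of the diff; it restates the reworked `init_kwargs` logic from `patcher.py` as a standalone function so the branching is easier to follow. It assumes only the `transformers` helpers already imported above; `build_init_kwargs` and its parameters are hypothetical stand-ins for the values that `patch_config` reads from `model_args`.

```python
from typing import Any, Dict, Optional

import torch
from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled


def build_init_kwargs(
    compute_dtype: torch.dtype,
    low_cpu_mem_usage: bool,
    device_map: Optional[str] = None,
) -> Dict[str, Any]:
    """Hypothetical helper mirroring the from_pretrained kwargs chosen by patch_config."""
    init_kwargs: Dict[str, Any] = {}

    # deepspeed zero3 is not compatible with low_cpu_mem_usage
    init_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())

    # set dtype and device map only when neither deepspeed nor fsdp controls model placement
    if deepspeed_config() is None and not is_fsdp_enabled():
        init_kwargs["torch_dtype"] = compute_dtype

        if init_kwargs["low_cpu_mem_usage"] and device_map is not None:
            # device map requires low_cpu_mem_usage=True
            init_kwargs["device_map"] = device_map

    return init_kwargs
```

Under DeepSpeed or FSDP this returns kwargs without `torch_dtype` or `device_map`, leaving dtype casting and parameter placement to the distributed backend; with ZeRO-3 it additionally forces `low_cpu_mem_usage=False`.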