From 08564838bd02651668845ed74e2e60561e5b6d8c Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Mon, 27 May 2024 20:37:57 +0800 Subject: [PATCH] fix full/freeze tuning for mllm --- data/dataset_info.json | 28 +++++------ src/llamafactory/hparams/finetuning_args.py | 16 ++++++- src/llamafactory/hparams/model_args.py | 4 -- src/llamafactory/model/adapter.py | 50 +++++++++++++------- src/llamafactory/model/utils/misc.py | 30 +++++------- src/llamafactory/model/utils/quantization.py | 2 + src/llamafactory/model/utils/visual.py | 7 +-- 7 files changed, 76 insertions(+), 61 deletions(-) diff --git a/data/dataset_info.json b/data/dataset_info.json index 0a148431..2d9b0c83 100644 --- a/data/dataset_info.json +++ b/data/dataset_info.json @@ -38,20 +38,6 @@ "assistant_tag": "assistant" } }, - "mllm_pt_demo": { - "hf_hub_url": "BUAADreamer/mllm_pt_demo", - "formatting": "sharegpt", - "columns": { - "messages": "messages", - "images": "images" - }, - "tags": { - "role_tag": "role", - "content_tag": "content", - "user_tag": "user", - "assistant_tag": "assistant" - } - }, "alpaca_en": { "hf_hub_url": "llamafactory/alpaca_en", "ms_hub_url": "llamafactory/alpaca_en" @@ -322,6 +308,20 @@ "assistant_tag": "assistant" } }, + "mllm_pt_demo": { + "hf_hub_url": "BUAADreamer/mllm_pt_demo", + "formatting": "sharegpt", + "columns": { + "messages": "messages", + "images": "images" + }, + "tags": { + "role_tag": "role", + "content_tag": "content", + "user_tag": "user", + "assistant_tag": "assistant" + } + }, "oasst_de": { "hf_hub_url": "mayflowergmbh/oasst_de" }, diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 05b246ae..b9322f18 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -311,6 +311,14 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA default=False, metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."}, ) + freeze_vision_tower: bool = field( + default=True, + metadata={"help": "Whether ot not to freeze vision tower in MLLM training."}, + ) + train_mm_proj_only: bool = field( + default=False, + metadata={"help": "Whether or not to train the multimodal projector for MLLM only."}, + ) plot_loss: bool = field( default=False, metadata={"help": "Whether or not to save the training loss curves."}, @@ -328,6 +336,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA self.lora_target = split_arg(self.lora_target) self.additional_target = split_arg(self.additional_target) self.galore_target = split_arg(self.galore_target) + self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method." assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." @@ -345,7 +354,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.") if self.use_llama_pro and self.finetuning_type == "full": - raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA training.") + raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.") if self.finetuning_type == "lora" and (self.use_galore or self.use_badam): raise ValueError("Cannot use LoRA with GaLore or BAdam together.") @@ -354,4 +363,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA raise ValueError("Cannot use GaLore with BAdam together.") if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora": - raise ValueError("`loraplus_lr_ratio` is only valid for the LoRA training.") + raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.") + + if self.train_mm_proj_only and self.finetuning_type != "full": + raise ValueError("`train_mm_proj_only` is only valid for full training.") diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index 781c7e99..650d1c22 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -85,10 +85,6 @@ class ModelArguments: default=False, metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."}, ) - tune_mm_proj: bool = field( - default=False, - metadata={"help": "Whethor or not only finetune mm_projector for MLLM."}, - ) moe_aux_loss_coef: Optional[float] = field( default=None, metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."}, diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py index 015db8a0..a9204ef0 100644 --- a/src/llamafactory/model/adapter.py +++ b/src/llamafactory/model/adapter.py @@ -10,7 +10,6 @@ from ..extras.logging import get_logger from .utils.misc import find_all_linear_modules, find_expanded_modules from .utils.quantization import QuantizationMethod from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model -from .utils.visual import filter_vision_tower_linear if TYPE_CHECKING: @@ -53,21 +52,33 @@ def init_adapter( if finetuning_args.finetuning_type == "full" and is_trainable: logger.info("Fine-tuning method: Full") - if cast_trainable_params_to_fp32: - model = model.float() - if model_args.visual_inputs and hasattr(model, "vision_tower"): # freeze vision model - model.vision_tower.requires_grad_(False) + forbidden_modules = set() + if model_args.visual_inputs and finetuning_args.freeze_vision_tower: + forbidden_modules.add("vision_tower") - if model_args.visual_inputs and hasattr(model, "language_model") and model_args.tune_mm_proj: # freeze language model if only tune mm_proj - model.language_model.requires_grad_(False) + if model_args.visual_inputs and finetuning_args.train_mm_proj_only: + forbidden_modules.add("language_model") + + for name, param in model.named_parameters(): + if not any(forbidden_module in name for forbidden_module in forbidden_modules): + if cast_trainable_params_to_fp32: + param.data = param.data.to(torch.float32) + else: + param.requires_grad_(False) if finetuning_args.finetuning_type == "freeze" and is_trainable: logger.info("Fine-tuning method: Freeze") + + if model_args.visual_inputs: + config = model.config.text_config + else: + config = model.config + num_layers = ( - getattr(model.config, "num_hidden_layers", None) - or getattr(model.config, "num_layers", None) - or getattr(model.config, "n_layer", None) + getattr(config, "num_hidden_layers", None) + or getattr(config, "num_layers", None) + or getattr(config, "n_layer", None) ) if not num_layers: raise ValueError("Current model does not support freeze tuning.") @@ -119,16 +130,19 @@ def init_adapter( trainable_layers.append(module_name) + forbidden_modules = set() + if model_args.visual_inputs and finetuning_args.freeze_vision_tower: + forbidden_modules.add("vision_tower") + for name, param in model.named_parameters(): - if any(trainable_layer in name for trainable_layer in trainable_layers): + if any(trainable_layer in name for trainable_layer in trainable_layers) and not any( + forbidden_module in name for forbidden_module in forbidden_modules + ): if cast_trainable_params_to_fp32: param.data = param.data.to(torch.float32) else: param.requires_grad_(False) - if model_args.visual_inputs and hasattr(model, "vision_tower"): # freeze vision model - model.vision_tower.requires_grad_(False) - logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids)))) if finetuning_args.finetuning_type == "lora": @@ -177,15 +191,15 @@ def init_adapter( if is_trainable and adapter_to_resume is None: # create new lora weights while training if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all": - target_modules = find_all_linear_modules(model) + target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower) else: target_modules = finetuning_args.lora_target if finetuning_args.use_llama_pro: - target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable) + target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers) - if model_args.visual_inputs: - target_modules = filter_vision_tower_linear(target_modules) + if model_args.visual_inputs and finetuning_args.freeze_vision_tower: + target_modules = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules)) if ( finetuning_args.use_dora diff --git a/src/llamafactory/model/utils/misc.py b/src/llamafactory/model/utils/misc.py index eca68866..4851bd29 100644 --- a/src/llamafactory/model/utils/misc.py +++ b/src/llamafactory/model/utils/misc.py @@ -1,9 +1,6 @@ from typing import TYPE_CHECKING, List -import torch - from ...extras.logging import get_logger -from .quantization import QuantizationMethod if TYPE_CHECKING: @@ -13,29 +10,28 @@ if TYPE_CHECKING: logger = get_logger(__name__) -def find_all_linear_modules(model: "PreTrainedModel") -> List[str]: +def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]: r""" Finds all available modules to apply lora or galore. """ - quantization_method = getattr(model, "quantization_method", None) - if quantization_method is None: - linear_cls = torch.nn.Linear - elif quantization_method == QuantizationMethod.BITS_AND_BYTES: - import bitsandbytes as bnb + forbidden_modules = {"lm_head"} - linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt - else: - raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method)) - - output_layer_names = ["lm_head"] if model.config.model_type == "chatglm": - output_layer_names.append("output_layer") + forbidden_modules.add("output_layer") elif model.config.model_type == "internlm2": - output_layer_names.append("output") + forbidden_modules.add("output") + elif model.config.model_type in ["llava", "paligemma"]: + forbidden_modules.add("multi_modal_projector") + + if freeze_vision_tower: + forbidden_modules.add("vision_tower") module_names = set() for name, module in model.named_modules(): - if isinstance(module, linear_cls) and not any(output_layer in name for output_layer in output_layer_names): + if any(forbidden_module in name for forbidden_module in forbidden_modules): + continue + + if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__: module_names.add(name.split(".")[-1]) logger.info("Found linear modules: {}".format(",".join(module_names))) diff --git a/src/llamafactory/model/utils/quantization.py b/src/llamafactory/model/utils/quantization.py index 161ad5aa..02a54f07 100644 --- a/src/llamafactory/model/utils/quantization.py +++ b/src/llamafactory/model/utils/quantization.py @@ -35,6 +35,8 @@ class QuantizationMethod(str, Enum): AWQ = "awq" AQLM = "aqlm" QUANTO = "quanto" + EETQ = "eetq" + HQQ = "hqq" def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]: diff --git a/src/llamafactory/model/utils/visual.py b/src/llamafactory/model/utils/visual.py index a91777ba..c8260b7f 100644 --- a/src/llamafactory/model/utils/visual.py +++ b/src/llamafactory/model/utils/visual.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Tuple, List +from typing import TYPE_CHECKING, Tuple import torch import transformers.models @@ -82,8 +82,3 @@ def configure_visual_model(config: "PretrainedConfig") -> None: if getattr(config, "is_yi_vl_derived_model", None): logger.info("Detected Yi-VL model, applying projector patch.") transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL - - -def filter_vision_tower_linear(target_modules: List[str]) -> str: - target_modules = f"^(?!.*vision_tower).*(?:{'|'.join(target_modules)}).*" - return target_modules