fix full/freeze tuning for mllm

This commit is contained in:
hiyouga 2024-05-27 20:37:57 +08:00
parent 838f2fb3e4
commit 08564838bd
7 changed files with 76 additions and 61 deletions

View File

@ -38,20 +38,6 @@
"assistant_tag": "assistant" "assistant_tag": "assistant"
} }
}, },
"mllm_pt_demo": {
"hf_hub_url": "BUAADreamer/mllm_pt_demo",
"formatting": "sharegpt",
"columns": {
"messages": "messages",
"images": "images"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
}
},
"alpaca_en": { "alpaca_en": {
"hf_hub_url": "llamafactory/alpaca_en", "hf_hub_url": "llamafactory/alpaca_en",
"ms_hub_url": "llamafactory/alpaca_en" "ms_hub_url": "llamafactory/alpaca_en"
@ -322,6 +308,20 @@
"assistant_tag": "assistant" "assistant_tag": "assistant"
} }
}, },
"mllm_pt_demo": {
"hf_hub_url": "BUAADreamer/mllm_pt_demo",
"formatting": "sharegpt",
"columns": {
"messages": "messages",
"images": "images"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
}
},
"oasst_de": { "oasst_de": {
"hf_hub_url": "mayflowergmbh/oasst_de" "hf_hub_url": "mayflowergmbh/oasst_de"
}, },

View File

@ -311,6 +311,14 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
default=False, default=False,
metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."}, metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
) )
freeze_vision_tower: bool = field(
default=True,
metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
)
train_mm_proj_only: bool = field(
default=False,
metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
)
plot_loss: bool = field( plot_loss: bool = field(
default=False, default=False,
metadata={"help": "Whether or not to save the training loss curves."}, metadata={"help": "Whether or not to save the training loss curves."},
@ -328,6 +336,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
self.lora_target = split_arg(self.lora_target) self.lora_target = split_arg(self.lora_target)
self.additional_target = split_arg(self.additional_target) self.additional_target = split_arg(self.additional_target)
self.galore_target = split_arg(self.galore_target) self.galore_target = split_arg(self.galore_target)
self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method." assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
@ -345,7 +354,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.") raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
if self.use_llama_pro and self.finetuning_type == "full": if self.use_llama_pro and self.finetuning_type == "full":
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA training.") raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")
if self.finetuning_type == "lora" and (self.use_galore or self.use_badam): if self.finetuning_type == "lora" and (self.use_galore or self.use_badam):
raise ValueError("Cannot use LoRA with GaLore or BAdam together.") raise ValueError("Cannot use LoRA with GaLore or BAdam together.")
@ -354,4 +363,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
raise ValueError("Cannot use GaLore with BAdam together.") raise ValueError("Cannot use GaLore with BAdam together.")
if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora": if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
raise ValueError("`loraplus_lr_ratio` is only valid for the LoRA training.") raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
if self.train_mm_proj_only and self.finetuning_type != "full":
raise ValueError("`train_mm_proj_only` is only valid for full training.")

View File

@ -85,10 +85,6 @@ class ModelArguments:
default=False, default=False,
metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."}, metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."},
) )
tune_mm_proj: bool = field(
default=False,
metadata={"help": "Whethor or not only finetune mm_projector for MLLM."},
)
moe_aux_loss_coef: Optional[float] = field( moe_aux_loss_coef: Optional[float] = field(
default=None, default=None,
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."}, metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},

View File

@ -10,7 +10,6 @@ from ..extras.logging import get_logger
from .utils.misc import find_all_linear_modules, find_expanded_modules from .utils.misc import find_all_linear_modules, find_expanded_modules
from .utils.quantization import QuantizationMethod from .utils.quantization import QuantizationMethod
from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
from .utils.visual import filter_vision_tower_linear
if TYPE_CHECKING: if TYPE_CHECKING:
@ -53,21 +52,33 @@ def init_adapter(
if finetuning_args.finetuning_type == "full" and is_trainable: if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full") logger.info("Fine-tuning method: Full")
if cast_trainable_params_to_fp32:
model = model.float()
if model_args.visual_inputs and hasattr(model, "vision_tower"): # freeze vision model forbidden_modules = set()
model.vision_tower.requires_grad_(False) if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")
if model_args.visual_inputs and hasattr(model, "language_model") and model_args.tune_mm_proj: # freeze language model if only tune mm_proj if model_args.visual_inputs and finetuning_args.train_mm_proj_only:
model.language_model.requires_grad_(False) forbidden_modules.add("language_model")
for name, param in model.named_parameters():
if not any(forbidden_module in name for forbidden_module in forbidden_modules):
if cast_trainable_params_to_fp32:
param.data = param.data.to(torch.float32)
else:
param.requires_grad_(False)
if finetuning_args.finetuning_type == "freeze" and is_trainable: if finetuning_args.finetuning_type == "freeze" and is_trainable:
logger.info("Fine-tuning method: Freeze") logger.info("Fine-tuning method: Freeze")
if model_args.visual_inputs:
config = model.config.text_config
else:
config = model.config
num_layers = ( num_layers = (
getattr(model.config, "num_hidden_layers", None) getattr(config, "num_hidden_layers", None)
or getattr(model.config, "num_layers", None) or getattr(config, "num_layers", None)
or getattr(model.config, "n_layer", None) or getattr(config, "n_layer", None)
) )
if not num_layers: if not num_layers:
raise ValueError("Current model does not support freeze tuning.") raise ValueError("Current model does not support freeze tuning.")
@ -119,16 +130,19 @@ def init_adapter(
trainable_layers.append(module_name) trainable_layers.append(module_name)
forbidden_modules = set()
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers): if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
forbidden_module in name for forbidden_module in forbidden_modules
):
if cast_trainable_params_to_fp32: if cast_trainable_params_to_fp32:
param.data = param.data.to(torch.float32) param.data = param.data.to(torch.float32)
else: else:
param.requires_grad_(False) param.requires_grad_(False)
if model_args.visual_inputs and hasattr(model, "vision_tower"): # freeze vision model
model.vision_tower.requires_grad_(False)
logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids)))) logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))
if finetuning_args.finetuning_type == "lora": if finetuning_args.finetuning_type == "lora":
@ -177,15 +191,15 @@ def init_adapter(
if is_trainable and adapter_to_resume is None: # create new lora weights while training if is_trainable and adapter_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all": if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
target_modules = find_all_linear_modules(model) target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
else: else:
target_modules = finetuning_args.lora_target target_modules = finetuning_args.lora_target
if finetuning_args.use_llama_pro: if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable) target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
if model_args.visual_inputs: if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
target_modules = filter_vision_tower_linear(target_modules) target_modules = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
if ( if (
finetuning_args.use_dora finetuning_args.use_dora

View File

@ -1,9 +1,6 @@
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING, List
import torch
from ...extras.logging import get_logger from ...extras.logging import get_logger
from .quantization import QuantizationMethod
if TYPE_CHECKING: if TYPE_CHECKING:
@ -13,29 +10,28 @@ if TYPE_CHECKING:
logger = get_logger(__name__) logger = get_logger(__name__)
def find_all_linear_modules(model: "PreTrainedModel") -> List[str]: def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
r""" r"""
Finds all available modules to apply lora or galore. Finds all available modules to apply lora or galore.
""" """
quantization_method = getattr(model, "quantization_method", None) forbidden_modules = {"lm_head"}
if quantization_method is None:
linear_cls = torch.nn.Linear
elif quantization_method == QuantizationMethod.BITS_AND_BYTES:
import bitsandbytes as bnb
linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
else:
raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method))
output_layer_names = ["lm_head"]
if model.config.model_type == "chatglm": if model.config.model_type == "chatglm":
output_layer_names.append("output_layer") forbidden_modules.add("output_layer")
elif model.config.model_type == "internlm2": elif model.config.model_type == "internlm2":
output_layer_names.append("output") forbidden_modules.add("output")
elif model.config.model_type in ["llava", "paligemma"]:
forbidden_modules.add("multi_modal_projector")
if freeze_vision_tower:
forbidden_modules.add("vision_tower")
module_names = set() module_names = set()
for name, module in model.named_modules(): for name, module in model.named_modules():
if isinstance(module, linear_cls) and not any(output_layer in name for output_layer in output_layer_names): if any(forbidden_module in name for forbidden_module in forbidden_modules):
continue
if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
module_names.add(name.split(".")[-1]) module_names.add(name.split(".")[-1])
logger.info("Found linear modules: {}".format(",".join(module_names))) logger.info("Found linear modules: {}".format(",".join(module_names)))

View File

@ -35,6 +35,8 @@ class QuantizationMethod(str, Enum):
AWQ = "awq" AWQ = "awq"
AQLM = "aqlm" AQLM = "aqlm"
QUANTO = "quanto" QUANTO = "quanto"
EETQ = "eetq"
HQQ = "hqq"
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]: def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Tuple, List from typing import TYPE_CHECKING, Tuple
import torch import torch
import transformers.models import transformers.models
@ -82,8 +82,3 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
if getattr(config, "is_yi_vl_derived_model", None): if getattr(config, "is_yi_vl_derived_model", None):
logger.info("Detected Yi-VL model, applying projector patch.") logger.info("Detected Yi-VL model, applying projector patch.")
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
def filter_vision_tower_linear(target_modules: List[str]) -> str:
target_modules = f"^(?!.*vision_tower).*(?:{'|'.join(target_modules)}).*"
return target_modules