forked from p04798526/LLaMA-Factory-Mirror
fix full/freeze tuning for mllm
This commit is contained in:
parent 838f2fb3e4
commit 08564838bd
@@ -38,20 +38,6 @@
         "assistant_tag": "assistant"
       }
     },
-  "mllm_pt_demo": {
-    "hf_hub_url": "BUAADreamer/mllm_pt_demo",
-    "formatting": "sharegpt",
-    "columns": {
-      "messages": "messages",
-      "images": "images"
-    },
-    "tags": {
-      "role_tag": "role",
-      "content_tag": "content",
-      "user_tag": "user",
-      "assistant_tag": "assistant"
-    }
-  },
   "alpaca_en": {
     "hf_hub_url": "llamafactory/alpaca_en",
     "ms_hub_url": "llamafactory/alpaca_en"
@@ -322,6 +308,20 @@
         "assistant_tag": "assistant"
       }
     },
+  "mllm_pt_demo": {
+    "hf_hub_url": "BUAADreamer/mllm_pt_demo",
+    "formatting": "sharegpt",
+    "columns": {
+      "messages": "messages",
+      "images": "images"
+    },
+    "tags": {
+      "role_tag": "role",
+      "content_tag": "content",
+      "user_tag": "user",
+      "assistant_tag": "assistant"
+    }
+  },
   "oasst_de": {
     "hf_hub_url": "mayflowergmbh/oasst_de"
   },
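For context, a record in the ShareGPT-style multimodal layout that this entry declares pairs a "messages" list (whose items use the "role"/"content" keys mapped by the tags above) with an "images" list. A hypothetical record, sketched in Python with made-up text and image path:

# Hypothetical example of one record matching the "columns"/"tags" mapping above;
# the conversation text and image path are invented for illustration.
sample = {
    "messages": [
        {"role": "user", "content": "<image>What is shown in this picture?"},
        {"role": "assistant", "content": "A cat sitting on a windowsill."},
    ],
    "images": ["images/0001.jpg"],
}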
@@ -311,6 +311,14 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         default=False,
         metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
     )
+    freeze_vision_tower: bool = field(
+        default=True,
+        metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
+    )
+    train_mm_proj_only: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
+    )
     plot_loss: bool = field(
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
@@ -328,6 +336,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         self.lora_target = split_arg(self.lora_target)
         self.additional_target = split_arg(self.additional_target)
         self.galore_target = split_arg(self.galore_target)
+        self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
 
         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
         assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
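A minimal sketch of how the two new arguments interact after __post_init__, using a stripped-down stand-in dataclass rather than the real FinetuningArguments (only the two field names and the or-expression come from the diff):

from dataclasses import dataclass, field


@dataclass
class MllmTuningArgs:  # illustrative stand-in, not the real FinetuningArguments
    freeze_vision_tower: bool = field(default=True)
    train_mm_proj_only: bool = field(default=False)

    def __post_init__(self):
        # training only the projector implies the vision tower must stay frozen
        self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only


args = MllmTuningArgs(freeze_vision_tower=False, train_mm_proj_only=True)
print(args.freeze_vision_tower)  # True: projector-only training forces the vision tower to stay frozen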
@@ -345,7 +354,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
             raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
 
         if self.use_llama_pro and self.finetuning_type == "full":
-            raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA training.")
+            raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")
 
         if self.finetuning_type == "lora" and (self.use_galore or self.use_badam):
             raise ValueError("Cannot use LoRA with GaLore or BAdam together.")
@@ -354,4 +363,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
             raise ValueError("Cannot use GaLore with BAdam together.")
 
         if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
-            raise ValueError("`loraplus_lr_ratio` is only valid for the LoRA training.")
+            raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
+
+        if self.train_mm_proj_only and self.finetuning_type != "full":
+            raise ValueError("`train_mm_proj_only` is only valid for full training.")
@@ -85,10 +85,6 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."},
     )
-    tune_mm_proj: bool = field(
-        default=False,
-        metadata={"help": "Whethor or not only finetune mm_projector for MLLM."},
-    )
     moe_aux_loss_coef: Optional[float] = field(
         default=None,
         metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
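The removed tune_mm_proj model argument is superseded by the train_mm_proj_only finetuning argument added above, which additionally validates that it is only used together with full fine-tuning.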
@@ -10,7 +10,6 @@ from ..extras.logging import get_logger
 from .utils.misc import find_all_linear_modules, find_expanded_modules
 from .utils.quantization import QuantizationMethod
 from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
-from .utils.visual import filter_vision_tower_linear
 
 
 if TYPE_CHECKING:
@@ -53,21 +52,33 @@ def init_adapter(
 
     if finetuning_args.finetuning_type == "full" and is_trainable:
         logger.info("Fine-tuning method: Full")
-        if cast_trainable_params_to_fp32:
-            model = model.float()
-
-        if model_args.visual_inputs and hasattr(model, "vision_tower"):  # freeze vision model
-            model.vision_tower.requires_grad_(False)
-
-        if model_args.visual_inputs and hasattr(model, "language_model") and model_args.tune_mm_proj:  # freeze language model if only tune mm_proj
-            model.language_model.requires_grad_(False)
+        forbidden_modules = set()
+        if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
+            forbidden_modules.add("vision_tower")
+
+        if model_args.visual_inputs and finetuning_args.train_mm_proj_only:
+            forbidden_modules.add("language_model")
+
+        for name, param in model.named_parameters():
+            if not any(forbidden_module in name for forbidden_module in forbidden_modules):
+                if cast_trainable_params_to_fp32:
+                    param.data = param.data.to(torch.float32)
+            else:
+                param.requires_grad_(False)
 
     if finetuning_args.finetuning_type == "freeze" and is_trainable:
         logger.info("Fine-tuning method: Freeze")
 
+        if model_args.visual_inputs:
+            config = model.config.text_config
+        else:
+            config = model.config
+
         num_layers = (
-            getattr(model.config, "num_hidden_layers", None)
-            or getattr(model.config, "num_layers", None)
-            or getattr(model.config, "n_layer", None)
+            getattr(config, "num_hidden_layers", None)
+            or getattr(config, "num_layers", None)
+            or getattr(config, "n_layer", None)
         )
         if not num_layers:
             raise ValueError("Current model does not support freeze tuning.")
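A self-contained sketch of the name-based freezing pattern used in the full branch above; the toy module below only mimics LLaVA-style submodule names and is not part of the codebase:

import torch
from torch import nn


class ToyMllm(nn.Module):
    # toy stand-in with LLaVA-style submodule names, for illustration only
    def __init__(self):
        super().__init__()
        self.vision_tower = nn.Linear(8, 8)
        self.multi_modal_projector = nn.Linear(8, 8)
        self.language_model = nn.Linear(8, 8)


model = ToyMllm()
forbidden_modules = {"vision_tower"}  # train_mm_proj_only would additionally add "language_model"

for name, param in model.named_parameters():
    if not any(forbidden_module in name for forbidden_module in forbidden_modules):
        param.data = param.data.to(torch.float32)  # cast trainable params, as in the hunk above
    else:
        param.requires_grad_(False)

print([name for name, param in model.named_parameters() if param.requires_grad])
# ['multi_modal_projector.weight', 'multi_modal_projector.bias',
#  'language_model.weight', 'language_model.bias']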
@@ -119,16 +130,19 @@ def init_adapter(
 
                 trainable_layers.append(module_name)
 
+        forbidden_modules = set()
+        if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
+            forbidden_modules.add("vision_tower")
+
         for name, param in model.named_parameters():
-            if any(trainable_layer in name for trainable_layer in trainable_layers):
+            if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
+                forbidden_module in name for forbidden_module in forbidden_modules
+            ):
                 if cast_trainable_params_to_fp32:
                     param.data = param.data.to(torch.float32)
             else:
                 param.requires_grad_(False)
-
-        if model_args.visual_inputs and hasattr(model, "vision_tower"):  # freeze vision model
-            model.vision_tower.requires_grad_(False)
 
         logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))
 
     if finetuning_args.finetuning_type == "lora":
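In the freeze branch the same forbidden-module filter is combined with the layer selection: a parameter stays trainable only if it sits in a selected layer and outside the vision tower. A quick illustration with hypothetical parameter names:

trainable_layers = [".30.", ".31."]  # e.g. the last two decoder layers
forbidden_modules = {"vision_tower"}

names = [
    "language_model.model.layers.31.mlp.down_proj.weight",  # selected layer, not forbidden -> trainable
    "language_model.model.layers.0.mlp.down_proj.weight",   # not a selected layer -> frozen
    "vision_tower.encoder.layers.31.mlp.fc1.weight",        # forbidden module -> frozen
]
for name in names:
    trainable = any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
        forbidden_module in name for forbidden_module in forbidden_modules
    )
    print(name, "->", "trainable" if trainable else "frozen")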
@@ -177,15 +191,15 @@ def init_adapter(
 
     if is_trainable and adapter_to_resume is None:  # create new lora weights while training
         if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
-            target_modules = find_all_linear_modules(model)
+            target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
         else:
             target_modules = finetuning_args.lora_target
 
         if finetuning_args.use_llama_pro:
-            target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
+            target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
 
-        if model_args.visual_inputs:
-            target_modules = filter_vision_tower_linear(target_modules)
+        if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
+            target_modules = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
 
         if (
             finetuning_args.use_dora
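The inlined pattern replaces the removed filter_vision_tower_linear helper: when target_modules is passed to PEFT as a single string it is treated as a regular expression, and the negative lookahead excludes any module whose path goes through the vision tower. A quick check of the pattern on hypothetical module paths:

import re

target_modules = ["q_proj", "v_proj"]
pattern = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))

print(bool(re.match(pattern, "language_model.model.layers.0.self_attn.q_proj")))  # True
print(bool(re.match(pattern, "vision_tower.vision_model.encoder.layers.0.self_attn.q_proj")))  # False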
@@ -1,9 +1,6 @@
 from typing import TYPE_CHECKING, List
 
-import torch
-
 from ...extras.logging import get_logger
-from .quantization import QuantizationMethod
 
 
 if TYPE_CHECKING:
@@ -13,29 +10,28 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
+def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
     r"""
     Finds all available modules to apply lora or galore.
     """
-    quantization_method = getattr(model, "quantization_method", None)
-    if quantization_method is None:
-        linear_cls = torch.nn.Linear
-    elif quantization_method == QuantizationMethod.BITS_AND_BYTES:
-        import bitsandbytes as bnb
-
-        linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
-    else:
-        raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method))
-
-    output_layer_names = ["lm_head"]
+    forbidden_modules = {"lm_head"}
     if model.config.model_type == "chatglm":
-        output_layer_names.append("output_layer")
+        forbidden_modules.add("output_layer")
     elif model.config.model_type == "internlm2":
-        output_layer_names.append("output")
+        forbidden_modules.add("output")
+    elif model.config.model_type in ["llava", "paligemma"]:
+        forbidden_modules.add("multi_modal_projector")
+
+    if freeze_vision_tower:
+        forbidden_modules.add("vision_tower")
 
     module_names = set()
     for name, module in model.named_modules():
-        if isinstance(module, linear_cls) and not any(output_layer in name for output_layer in output_layer_names):
+        if any(forbidden_module in name for forbidden_module in forbidden_modules):
+            continue
+
+        if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
             module_names.add(name.split(".")[-1])
 
     logger.info("Found linear modules: {}".format(",".join(module_names)))
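The rewrite drops the bitsandbytes-specific isinstance check in favor of matching on the class name, which also picks up quantized linear layers (Linear4bit, Linear8bitLt, and similar) without importing them, while the forbidden-module set keeps output heads and, for MLLMs, the projector and vision tower out of the LoRA targets. A minimal sketch of the matching rule on a toy module:

from torch import nn

model = nn.ModuleDict({
    "embed_tokens": nn.Embedding(10, 8),
    "q_proj": nn.Linear(8, 8),
    "lm_head": nn.Linear(8, 10),
})
forbidden_modules = {"lm_head"}

module_names = set()
for name, module in model.named_modules():
    if any(forbidden_module in name for forbidden_module in forbidden_modules):
        continue
    if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
        module_names.add(name.split(".")[-1])

print(module_names)  # {'q_proj'}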
@@ -35,6 +35,8 @@ class QuantizationMethod(str, Enum):
     AWQ = "awq"
     AQLM = "aqlm"
     QUANTO = "quanto"
+    EETQ = "eetq"
+    HQQ = "hqq"
 
 
 def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
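Because QuantizationMethod is a str-mixin Enum, a quant_method string taken from a quantization config maps directly onto the new members; the enum below is re-declared with only a few members purely for illustration:

from enum import Enum


class QuantizationMethod(str, Enum):  # re-declared here only for the example
    BITS_AND_BYTES = "bitsandbytes"
    EETQ = "eetq"
    HQQ = "hqq"


print(QuantizationMethod("eetq") is QuantizationMethod.EETQ)  # True
print(QuantizationMethod.HQQ == "hqq")  # True, thanks to the str mixin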
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Tuple, List
+from typing import TYPE_CHECKING, Tuple
 
 import torch
 import transformers.models
@@ -82,8 +82,3 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
     if getattr(config, "is_yi_vl_derived_model", None):
         logger.info("Detected Yi-VL model, applying projector patch.")
         transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
-
-
-def filter_vision_tower_linear(target_modules: List[str]) -> str:
-    target_modules = f"^(?!.*vision_tower).*(?:{'|'.join(target_modules)}).*"
-    return target_modules