Merge pull request #3835 from BUAADreamer/main

fix some features in llava-style training
hoshi-hiyouga 2024-05-27 20:23:45 +08:00 committed by GitHub
commit 838f2fb3e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 31 additions and 1 deletion

data/dataset_info.json

@@ -38,6 +38,20 @@
       "assistant_tag": "assistant"
     }
   },
+  "mllm_pt_demo": {
+    "hf_hub_url": "BUAADreamer/mllm_pt_demo",
+    "formatting": "sharegpt",
+    "columns": {
+      "messages": "messages",
+      "images": "images"
+    },
+    "tags": {
+      "role_tag": "role",
+      "content_tag": "content",
+      "user_tag": "user",
+      "assistant_tag": "assistant"
+    }
+  },
   "alpaca_en": {
     "hf_hub_url": "llamafactory/alpaca_en",
     "ms_hub_url": "llamafactory/alpaca_en"

src/llamafactory/hparams/model_args.py

@@ -85,6 +85,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether or not to use a multimodal LLM that accepts visual inputs."},
     )
+    tune_mm_proj: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to fine-tune only the multimodal projector for MLLM."},
+    )
     moe_aux_loss_coef: Optional[float] = field(
         default=None,
         metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},

src/llamafactory/model/adapter.py

@@ -10,6 +10,7 @@ from ..extras.logging import get_logger
 from .utils.misc import find_all_linear_modules, find_expanded_modules
 from .utils.quantization import QuantizationMethod
 from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
+from .utils.visual import filter_vision_tower_linear
 
 if TYPE_CHECKING:
@@ -58,6 +59,9 @@ def init_adapter(
     if model_args.visual_inputs and hasattr(model, "vision_tower"):  # freeze vision model
         model.vision_tower.requires_grad_(False)
 
+    if model_args.visual_inputs and hasattr(model, "language_model") and model_args.tune_mm_proj:
+        model.language_model.requires_grad_(False)  # freeze language model if only tuning mm_proj
+
     if finetuning_args.finetuning_type == "freeze" and is_trainable:
         logger.info("Fine-tuning method: Freeze")
         num_layers = (
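
The net effect of the two freezes above: with both visual_inputs and tune_mm_proj set, only the multimodal projector stays trainable. A sketch, assuming a transformers LlavaForConditionalGeneration, which exposes vision_tower, language_model, and multi_modal_projector submodules:

from transformers import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
model.vision_tower.requires_grad_(False)    # always frozen when visual_inputs is set
model.language_model.requires_grad_(False)  # frozen when tune_mm_proj is set

# Only multi_modal_projector.* parameters keep requires_grad=True.
trainable = {name.split(".")[0] for name, p in model.named_parameters() if p.requires_grad}
print(trainable)  # {'multi_modal_projector'}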
@@ -180,6 +184,9 @@ def init_adapter(
         if finetuning_args.use_llama_pro:
             target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
 
+        if model_args.visual_inputs:
+            target_modules = filter_vision_tower_linear(target_modules)
+
         if (
             finetuning_args.use_dora
             and getattr(model, "quantization_method", None) is not None

src/llamafactory/model/utils/visual.py

@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Tuple, List
 
 import torch
 import transformers.models
@@ -82,3 +82,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
     if getattr(config, "is_yi_vl_derived_model", None):
         logger.info("Detected Yi-VL model, applying projector patch.")
         transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
+
+
+def filter_vision_tower_linear(target_modules: List[str]) -> str:
+    target_modules = f"^(?!.*vision_tower).*(?:{'|'.join(target_modules)}).*"
+    return target_modules
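
A quick check of the pattern this helper builds: PEFT treats a string target_modules as a regex matched against full module names with re.fullmatch, so the negative lookahead keeps every matching linear layer except those under the vision tower.

import re

pattern = filter_vision_tower_linear(["q_proj", "v_proj"])
# -> "^(?!.*vision_tower).*(?:q_proj|v_proj).*"

print(bool(re.fullmatch(pattern, "language_model.model.layers.0.self_attn.q_proj")))  # True
print(bool(re.fullmatch(pattern, "vision_tower.vision_model.encoder.layers.0.self_attn.q_proj")))  # False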