add regex of only tune lm and mm_proj

This commit is contained in:
BUAADreamer 2024-05-27 18:59:00 +08:00
parent 4bc7c10c00
commit 57eb13b75d
6 changed files with 151 additions and 6 deletions

49
sites/paligemma-pt.yaml Normal file
View File

@ -0,0 +1,49 @@
# model
model_name_or_path: google/paligemma-3b-mix-448
visual_inputs: true
tune_mm_proj: true
#print_param_status: true
# method
stage: sft
do_train: true
finetuning_type: full
# ddp
ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z2_offload_config.json
# dataset
dataset: mllm_pt_demo
dataset_dir: data
template: gemma
cutoff_len: 2048
max_samples: 3
#val_size: 0.0001
overwrite_cache: true
preprocessing_num_workers: 16
# output
output_dir: saves/paligemma/full/sft_llava_pt_test
logging_steps: 1
save_steps: 50
plot_loss: true
overwrite_output_dir: true
#save_strategy: epoch
#save_total_limit: 2
# train
per_device_train_batch_size: 1
gradient_accumulation_steps: 16
learning_rate: 0.00001
num_train_epochs: 100
lr_scheduler_type: cosine
warmup_steps: 0.1
#bf16: true
pure_bf16: true
# eval
do_eval: false
#per_device_eval_batch_size: 1
#evaluation_strategy: steps
#eval_steps: 500

49
sites/paligemma.yaml Normal file
View File

@ -0,0 +1,49 @@
# model
model_name_or_path: google/paligemma-3b-mix-448
visual_inputs: true
#print_param_status: true
use_fast_tokenizer: false
# method
stage: sft
do_train: true
finetuning_type: full
# ddp
ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z2_offload_config.json
# dataset
dataset: mllm_demo
dataset_dir: data
template: gemma
cutoff_len: 2048
max_samples: 3
#val_size: 0.0001
overwrite_cache: true
preprocessing_num_workers: 16
# output
output_dir: saves/paligemma/full/sft_llava_1k
logging_steps: 1
save_steps: 50
plot_loss: true
overwrite_output_dir: true
#save_strategy: epoch
#save_total_limit: 2
# train
per_device_train_batch_size: 1
gradient_accumulation_steps: 16
learning_rate: 0.00001
num_train_epochs: 100
lr_scheduler_type: cosine
warmup_steps: 0.1
#bf16: true
pure_bf16: true
# eval
do_eval: false
#per_device_eval_batch_size: 1
#evaluation_strategy: steps
#eval_steps: 500

40
sites/paligemma_lora.yaml Normal file
View File

@ -0,0 +1,40 @@
### model
model_name_or_path: google/paligemma-3b-mix-448
visual_inputs: true
use_fast_tokenizer: false
### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: q_proj,v_proj
### dataset
dataset: mllm_demo
template: gemma
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: saves/paligemma/lora/sft_mllm
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 0.0001
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_steps: 0.1
fp16: true
### eval
val_size: 0.1
per_device_eval_batch_size: 1
evaluation_strategy: steps
eval_steps: 500

View File

@ -10,6 +10,7 @@ from ..extras.logging import get_logger
from .utils.misc import find_all_linear_modules, find_expanded_modules
from .utils.quantization import QuantizationMethod
from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
from .utils.visual import filter_vision_tower_linear
if TYPE_CHECKING:
@ -58,6 +59,9 @@ def init_adapter(
if model_args.visual_inputs and hasattr(model, "vision_tower"): # freeze vision model
model.vision_tower.requires_grad_(False)
if model_args.visual_inputs and hasattr(model, "language_model") and model_args.tune_mm_proj: # freeze language model if only tune mm_proj
model.language_model.requires_grad_(False)
if finetuning_args.finetuning_type == "freeze" and is_trainable:
logger.info("Fine-tuning method: Freeze")
num_layers = (
@ -180,6 +184,9 @@ def init_adapter(
if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
if model_args.visual_inputs:
target_modules = filter_vision_tower_linear(target_modules)
if (
finetuning_args.use_dora
and getattr(model, "quantization_method", None) is not None

View File

@ -163,11 +163,6 @@ def load_model(
else:
model.train()
if model_args.visual_inputs and model_args.tune_mm_proj:
lm_params = [param for name, param in model.named_parameters() if "language_model" in name]
for param in lm_params:
param.requires_grad_(False)
trainable_params, all_param = count_parameters(model)
if is_trainable:
param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Tuple, List
import torch
import transformers.models
@ -82,3 +82,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
if getattr(config, "is_yi_vl_derived_model", None):
logger.info("Detected Yi-VL model, applying projector patch.")
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
def filter_vision_tower_linear(target_modules: List[str]) -> str:
target_modules = f"^(?!.*vision_tower).*(?:{'|'.join(target_modules)}).*"
return target_modules