diff --git a/sites/paligemma-pt.yaml b/sites/paligemma-pt.yaml
new file mode 100644
index 00000000..4305cf5f
--- /dev/null
+++ b/sites/paligemma-pt.yaml
@@ -0,0 +1,49 @@
+# model
+model_name_or_path: google/paligemma-3b-mix-448
+visual_inputs: true
+tune_mm_proj: true
+#print_param_status: true
+
+# method
+stage: sft
+do_train: true
+finetuning_type: full
+
+# ddp
+ddp_timeout: 180000000
+deepspeed: examples/deepspeed/ds_z2_offload_config.json
+
+# dataset
+dataset: mllm_pt_demo
+dataset_dir: data
+template: gemma
+cutoff_len: 2048
+max_samples: 3
+#val_size: 0.0001
+overwrite_cache: true
+preprocessing_num_workers: 16
+
+# output
+output_dir: saves/paligemma/full/sft_llava_pt_test
+logging_steps: 1
+save_steps: 50
+plot_loss: true
+overwrite_output_dir: true
+#save_strategy: epoch
+#save_total_limit: 2
+
+# train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 16
+learning_rate: 0.00001
+num_train_epochs: 100
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+#bf16: true
+pure_bf16: true
+
+# eval
+do_eval: false
+#per_device_eval_batch_size: 1
+#evaluation_strategy: steps
+#eval_steps: 500
diff --git a/sites/paligemma.yaml b/sites/paligemma.yaml
new file mode 100644
index 00000000..f3257cfc
--- /dev/null
+++ b/sites/paligemma.yaml
@@ -0,0 +1,49 @@
+# model
+model_name_or_path: google/paligemma-3b-mix-448
+visual_inputs: true
+#print_param_status: true
+use_fast_tokenizer: false
+
+# method
+stage: sft
+do_train: true
+finetuning_type: full
+
+# ddp
+ddp_timeout: 180000000
+deepspeed: examples/deepspeed/ds_z2_offload_config.json
+
+# dataset
+dataset: mllm_demo
+dataset_dir: data
+template: gemma
+cutoff_len: 2048
+max_samples: 3
+#val_size: 0.0001
+overwrite_cache: true
+preprocessing_num_workers: 16
+
+# output
+output_dir: saves/paligemma/full/sft_llava_1k
+logging_steps: 1
+save_steps: 50
+plot_loss: true
+overwrite_output_dir: true
+#save_strategy: epoch
+#save_total_limit: 2
+
+# train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 16
+learning_rate: 0.00001
+num_train_epochs: 100
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+#bf16: true
+pure_bf16: true
+
+# eval
+do_eval: false
+#per_device_eval_batch_size: 1
+#evaluation_strategy: steps
+#eval_steps: 500
diff --git a/sites/paligemma_lora.yaml b/sites/paligemma_lora.yaml
new file mode 100644
index 00000000..0693a6ae
--- /dev/null
+++ b/sites/paligemma_lora.yaml
@@ -0,0 +1,40 @@
+### model
+model_name_or_path: google/paligemma-3b-mix-448
+visual_inputs: true
+use_fast_tokenizer: false
+
+### method
+stage: sft
+do_train: true
+finetuning_type: lora
+lora_target: q_proj,v_proj
+
+### dataset
+dataset: mllm_demo
+template: gemma
+cutoff_len: 1024
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+
+### output
+output_dir: saves/paligemma/lora/sft_mllm
+logging_steps: 10
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+
+### train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+learning_rate: 0.0001
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+fp16: true
+
+### eval
+val_size: 0.1
+per_device_eval_batch_size: 1
+evaluation_strategy: steps
+eval_steps: 500
diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py
index f37f3bbb..015db8a0 100644
--- a/src/llamafactory/model/adapter.py
+++ b/src/llamafactory/model/adapter.py
@@ -10,6 +10,7 @@ from ..extras.logging import get_logger
 from .utils.misc import find_all_linear_modules, find_expanded_modules
 from .utils.quantization import QuantizationMethod
 from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
+from .utils.visual import filter_vision_tower_linear
 
 
 if TYPE_CHECKING:
@@ -58,6 +59,9 @@ def init_adapter(
     if model_args.visual_inputs and hasattr(model, "vision_tower"):  # freeze vision model
         model.vision_tower.requires_grad_(False)
 
+    if model_args.visual_inputs and hasattr(model, "language_model") and model_args.tune_mm_proj:  # freeze the language model when only tuning mm_proj
+        model.language_model.requires_grad_(False)
+
     if finetuning_args.finetuning_type == "freeze" and is_trainable:
         logger.info("Fine-tuning method: Freeze")
         num_layers = (
@@ -180,6 +184,9 @@ def init_adapter(
         if finetuning_args.use_llama_pro:
             target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
 
+        if model_args.visual_inputs:
+            target_modules = filter_vision_tower_linear(target_modules)
+
         if (
             finetuning_args.use_dora
             and getattr(model, "quantization_method", None) is not None
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index d9784593..49b347d5 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -163,11 +163,6 @@ def load_model(
     else:
         model.train()
 
-    if model_args.visual_inputs and model_args.tune_mm_proj:
-        lm_params = [param for name, param in model.named_parameters() if "language_model" in name]
-        for param in lm_params:
-            param.requires_grad_(False)
-
     trainable_params, all_param = count_parameters(model)
     if is_trainable:
         param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
diff --git a/src/llamafactory/model/utils/visual.py b/src/llamafactory/model/utils/visual.py
index c8260b7f..a91777ba 100644
--- a/src/llamafactory/model/utils/visual.py
+++ b/src/llamafactory/model/utils/visual.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, List, Tuple
 
 import torch
 import transformers.models
@@ -82,3 +82,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
     if getattr(config, "is_yi_vl_derived_model", None):
         logger.info("Detected Yi-VL model, applying projector patch.")
         transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
+
+
+def filter_vision_tower_linear(target_modules: List[str]) -> str:
+    # Collapse the module list into one regex; the negative lookahead keeps LoRA off the vision tower.
+    return f"^(?!.*vision_tower).*(?:{'|'.join(target_modules)}).*"
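
The new helper turns the list of LoRA target modules into a single regex string, which PEFT accepts for `target_modules`: any module name containing one of the targets matches, unless it sits under the vision tower, which the `(?!.*vision_tower)` lookahead excludes. A minimal sketch of the matching behavior (the module names below are hypothetical, chosen only to illustrate the three cases):

    import re

    from llamafactory.model.utils.visual import filter_vision_tower_linear

    # Hypothetical module names, for illustration only.
    names = [
        "language_model.model.layers.0.self_attn.q_proj",
        "vision_tower.vision_model.encoder.layers.0.self_attn.q_proj",
        "multi_modal_projector.linear",
    ]

    pattern = filter_vision_tower_linear(["q_proj", "v_proj"])
    # pattern == "^(?!.*vision_tower).*(?:q_proj|v_proj).*"
    print([name for name in names if re.match(pattern, name)])
    # ['language_model.model.layers.0.self_attn.q_proj']

With that filter in place, the LoRA path could presumably be exercised end to end with something like `llamafactory-cli train sites/paligemma_lora.yaml`.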