diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py
index f716102f..b20c9203 100644
--- a/src/llmtuner/data/template.py
+++ b/src/llmtuner/data/template.py
@@ -856,6 +856,21 @@ _register_template(
 )
 
 
+_register_template(
+    name="yi_vl",
+    format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    default_system=(
+        "This is a chat between an inquisitive human and an AI assistant. "
+        "Assume the role of the AI assistant. Read all the images carefully, "
+        "and respond to the human's questions with informative, helpful, detailed and polite answers. "
+        "这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。"
+        "仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。\n"
+    ),
+    stop_words=["###"],
+)
+
+
 _register_template(
     name="yuan",
     format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 8625f3e1..603e3c9e 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -18,7 +18,7 @@ from .utils.moe import add_z3_leaf_module, configure_moe
 from .utils.quantization import configure_quantization
 from .utils.rope import configure_rope
 from .utils.valuehead import prepare_valuehead_model
-from .utils.visual import autocast_projector_dtype, configure_hidden_size
+from .utils.visual import autocast_projector_dtype, configure_visual_model
 
 
 if TYPE_CHECKING:
@@ -55,7 +55,7 @@ def patch_config(
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)
-    configure_hidden_size(config)
+    configure_visual_model(config)
 
     if model_args.use_cache and not is_trainable:
         setattr(config, "use_cache", True)
diff --git a/src/llmtuner/model/utils/visual.py b/src/llmtuner/model/utils/visual.py
index b29a9ba5..9a5134ff 100644
--- a/src/llmtuner/model/utils/visual.py
+++ b/src/llmtuner/model/utils/visual.py
@@ -1,12 +1,14 @@
 from typing import TYPE_CHECKING, Tuple
 
 import torch
+import transformers.models
+from transformers.activations import ACT2FN
 
 from ...extras.logging import get_logger
 
 
 if TYPE_CHECKING:
-    from transformers import PretrainedConfig, PreTrainedModel
+    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
 
     from ...hparams import ModelArguments
 
@@ -14,9 +16,23 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def configure_hidden_size(config: "PretrainedConfig") -> None:
-    if getattr(config, "model_type", None) == "llava":
-        setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
+class LlavaMultiModalProjector(torch.nn.Module):
+    def __init__(self, config: "LlavaConfig"):
+        super().__init__()
+
+        self.linear_1 = torch.nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
+        self.linear_2 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
+        self.linear_3 = torch.nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
+        self.linear_4 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
+        self.act = ACT2FN[config.projector_hidden_act]
+
+    def forward(self, image_features):
+        hidden_states = self.linear_1(image_features)
+        hidden_states = self.linear_2(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.linear_3(hidden_states)
+        hidden_states = self.linear_4(hidden_states)
+        return hidden_states
 
 
 def autocast_projector_dtype(
@@ -31,3 +47,11 @@ def autocast_projector_dtype(
         logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
         mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
         mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
+
+
+def configure_visual_model(config: "PretrainedConfig") -> None:
+    if getattr(config, "model_type", None) == "llava":
+        setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
+
+    if getattr(config, "is_yi_vl_derived_model", None):
+        transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjector
diff --git a/src/llmtuner/train/sft/trainer.py b/src/llmtuner/train/sft/trainer.py
index def427fd..35671e1b 100644
--- a/src/llmtuner/train/sft/trainer.py
+++ b/src/llmtuner/train/sft/trainer.py
@@ -13,6 +13,7 @@ from ..utils import create_custom_optimzer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
+    from transformers import ProcessorMixin
     from transformers.trainer import PredictionOutput
 
     from ...hparams import FinetuningArguments
@@ -26,9 +27,12 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
     """
 
-    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
+    def __init__(
+        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
+    ) -> None:
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
+        self.processor = processor
 
         if finetuning_args.use_badam:
             from badam import clip_grad_norm_for_sparse_tensor
@@ -45,6 +49,12 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)
 
+    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
+        super()._save(output_dir, state_dict)
+        if self.processor is not None:
+            output_dir = output_dir if output_dir is not None else self.args.output_dir
+            getattr(self.processor, "image_processor").save_pretrained(output_dir)
+
     def prediction_step(
         self,
         model: "torch.nn.Module",
diff --git a/src/llmtuner/train/sft/workflow.py b/src/llmtuner/train/sft/workflow.py
index 4a9775b4..d9d7c8e9 100644
--- a/src/llmtuner/train/sft/workflow.py
+++ b/src/llmtuner/train/sft/workflow.py
@@ -55,10 +55,10 @@ def run_sft(
         model=model,
         args=training_args,
         finetuning_args=finetuning_args,
-        tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
         compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
+        **tokenizer_module,
         **split_dataset(dataset, data_args, training_args),
     )
 
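For context, a minimal usage sketch (not part of the patch) of how the pieces above are expected to fit together: the checkpoint path below is a placeholder, and the `is_yi_vl_derived_model` flag is assumed to be set in that checkpoint's config.json, so `configure_visual_model` replaces transformers' LlavaMultiModalProjector with the LayerNorm-based variant before the model is instantiated.

# Sketch only: "path/to/Yi-VL-6B-hf" is a hypothetical converted checkpoint, not a file in this PR.
from transformers import AutoConfig, AutoModelForVision2Seq

from llmtuner.model.utils.visual import configure_visual_model

config = AutoConfig.from_pretrained("path/to/Yi-VL-6B-hf")
configure_visual_model(config)  # aligns hidden_size for llava configs; for Yi-VL-derived
                                # configs, monkey-patches in the projector defined above
model = AutoModelForVision2Seq.from_pretrained("path/to/Yi-VL-6B-hf", config=config)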