add yi-vl

This commit is contained in:
BUAADreamer 2024-05-13 23:28:28 +08:00
parent b033232aea
commit 64dac4085e
5 changed files with 53 additions and 4 deletions

View File

@ -856,6 +856,20 @@ _register_template(
)
_register_template(
name="yi-vl",
format_user=StringFormatter(slots=["### Human:\n{{content}}\n### Assistant: "]),
stop_words=["###"],
default_system=(
"This is a chat between an inquisitive human and an AI assistant. "
"Assume the role of the AI assistant. "
"Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers."
"这是一个好奇的人类和一个人工智能助手之间的对话。"
"假设你扮演这个AI助手的角色。仔细阅读所有的图像并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。"
),
)
_register_template(
name="yuan",
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),

View File

@ -16,7 +16,7 @@ from .utils.moe import add_z3_leaf_module, configure_moe
from .utils.quantization import configure_quantization
from .utils.rope import configure_rope
from .utils.valuehead import prepare_valuehead_model
from .utils.visual import autocast_projector_dtype, configure_hidden_size
from .utils.visual import autocast_projector_dtype, configure_hidden_size, configure_visual
if TYPE_CHECKING:
@ -50,6 +50,7 @@ def patch_config(
configure_quantization(config, tokenizer, model_args, init_kwargs)
configure_moe(config, model_args, is_trainable)
configure_hidden_size(config)
configure_visual(config, model_args)
if model_args.use_cache and not is_trainable:
setattr(config, "use_cache", True)

View File

@ -1,12 +1,14 @@
from typing import TYPE_CHECKING, Tuple
import torch
import transformers
from torch import nn
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from transformers import PretrainedConfig, PreTrainedModel, LlavaConfig
from ...hparams import ModelArguments
@ -31,3 +33,25 @@ def autocast_projector_dtype(
logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
class LlavaMultiModalProjectorYiVL(nn.Module):
def __init__(self, config: "LlavaConfig"):
super().__init__()
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
self.linear_2 = nn.LayerNorm(config.text_config.hidden_size, bias=True)
self.linear_3 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
self.linear_4 = nn.LayerNorm(config.text_config.hidden_size, bias=True)
self.act = nn.GELU()
self.proj = nn.Sequential(*[self.linear_1, self.linear_2, self.act, self.linear_3, self.linear_4])
def forward(self, image_features):
hidden_states = self.proj(image_features)
return hidden_states
def configure_visual(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
logger = get_logger(__name__)
if model_args.visual_inputs and "Yi" in getattr(config.text_config, "_name_or_path", None):
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorYiVL
logger.info("Patched Multimodal Projector for Yi-VL.")

View File

@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import Seq2SeqTrainer
from transformers import Seq2SeqTrainer, ProcessorMixin
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
@ -26,9 +26,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
"""
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
def __init__(self, finetuning_args: "FinetuningArguments", processor: "ProcessorMixin", **kwargs) -> None:
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
self.processor = processor
if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor
@ -120,3 +121,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
for label, pred in zip(decoded_labels, decoded_preds):
res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
writer.write("\n".join(res))
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
super().save_model(output_dir, _internal_call)
if self.processor is not None:
if output_dir is None:
output_dir = self.args.output_dir
getattr(self.processor, "image_processor").save_pretrained(output_dir)

View File

@ -30,6 +30,7 @@ def run_sft(
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
processor = tokenizer_module["processor"]
dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
@ -55,6 +56,7 @@ def run_sft(
model=model,
args=training_args,
finetuning_args=finetuning_args,
processor=processor,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,