fix llava qlora

This commit is contained in:
hiyouga 2024-04-26 18:00:23 +08:00
parent cd3a960f81
commit fc67b736ba
6 changed files with 55 additions and 10 deletions

View File

@ -1,4 +1,5 @@
#!/bin/bash
# NEED TO run `merge.sh` before using this script
CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
--model_name_or_path ../../models/llama2-7b-sft \

View File

@ -14,10 +14,23 @@ if TYPE_CHECKING:
from .parser import DatasetAttr
def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
outputs = []
if dataset_attr.load_from in ["script", "file"]:
for image in images:
if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
outputs.append(os.path.join(data_args.dataset_dir, image))
else:
outputs.append(image)
return outputs
def convert_alpaca(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
for i in range(len(examples[dataset_attr.prompt])):
prompt = []
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
@ -47,11 +60,7 @@ def convert_alpaca(
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append("")
outputs["images"].append(
[os.path.join(data_args.dataset_dir, path) for path in examples[dataset_attr.images][i]]
if dataset_attr.images
else []
)
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
return outputs
@ -60,6 +69,7 @@ def convert_sharegpt(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
tag_mapping = {
dataset_attr.user_tag: Role.USER.value,
dataset_attr.assistant_tag: Role.ASSISTANT.value,
@ -94,11 +104,7 @@ def convert_sharegpt(
outputs["response"].append(aligned_messages[-1:])
outputs["system"].append(system)
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(
[os.path.join(data_args.dataset_dir, path) for path in examples[dataset_attr.images][i]]
if dataset_attr.images
else []
)
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
return outputs

View File

@ -182,6 +182,9 @@ class ModelArguments:
if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
if self.visual_inputs and self.use_unsloth:
raise ValueError("Unsloth does not support MLLM yet. Stay tuned.")
if self.adapter_name_or_path is not None: # support merging multiple lora weights
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]

View File

@ -323,6 +323,9 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
if model_args.visual_inputs:
raise ValueError("vLLM engine does not support MLLM yet. Stay tuned.")
if finetuning_args.stage == "rm" and model_args.visual_inputs:
raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
_verify_model_args(model_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)

View File

@ -15,6 +15,7 @@ from .utils.longlora import configure_longlora
from .utils.moe import add_z3_leaf_module, configure_moe
from .utils.quantization import configure_quantization
from .utils.rope import configure_rope
from .utils.visual import autocast_projector_dtype
if TYPE_CHECKING:
@ -92,6 +93,9 @@ def patch_model(
if model_args.resize_vocab:
resize_embedding_layer(model, tokenizer)
if model_args.visual_inputs:
autocast_projector_dtype(model, model_args)
if is_trainable:
prepare_model_for_training(model, model_args)
add_z3_leaf_module(model)

View File

@ -0,0 +1,28 @@
from typing import TYPE_CHECKING, Tuple
import torch
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedModel
from ...hparams import ModelArguments
logger = get_logger(__name__)
def autocast_projector_dtype(
model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
) -> None:
def _mm_projector_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(model_args.compute_dtype)
if hasattr(model, mm_projector_name):
logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)