fix llava qlora
This commit is contained in:
parent
cd3a960f81
commit
fc67b736ba
|
@ -1,4 +1,5 @@
|
|||
#!/bin/bash
|
||||
# NEED TO run `merge.sh` before using this script
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
|
||||
--model_name_or_path ../../models/llama2-7b-sft \
|
||||
|
|
|
@ -14,10 +14,23 @@ if TYPE_CHECKING:
|
|||
from .parser import DatasetAttr
|
||||
|
||||
|
||||
def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
|
||||
outputs = []
|
||||
if dataset_attr.load_from in ["script", "file"]:
|
||||
for image in images:
|
||||
if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
|
||||
outputs.append(os.path.join(data_args.dataset_dir, image))
|
||||
else:
|
||||
outputs.append(image)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def convert_alpaca(
|
||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
||||
) -> Dict[str, List[Any]]:
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||
for i in range(len(examples[dataset_attr.prompt])):
|
||||
prompt = []
|
||||
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
|
||||
|
@ -47,11 +60,7 @@ def convert_alpaca(
|
|||
outputs["response"].append(response)
|
||||
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
||||
outputs["tools"].append("")
|
||||
outputs["images"].append(
|
||||
[os.path.join(data_args.dataset_dir, path) for path in examples[dataset_attr.images][i]]
|
||||
if dataset_attr.images
|
||||
else []
|
||||
)
|
||||
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
|
||||
|
||||
return outputs
|
||||
|
||||
|
@ -60,6 +69,7 @@ def convert_sharegpt(
|
|||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
||||
) -> Dict[str, List[Any]]:
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||
tag_mapping = {
|
||||
dataset_attr.user_tag: Role.USER.value,
|
||||
dataset_attr.assistant_tag: Role.ASSISTANT.value,
|
||||
|
@ -94,11 +104,7 @@ def convert_sharegpt(
|
|||
outputs["response"].append(aligned_messages[-1:])
|
||||
outputs["system"].append(system)
|
||||
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
||||
outputs["images"].append(
|
||||
[os.path.join(data_args.dataset_dir, path) for path in examples[dataset_attr.images][i]]
|
||||
if dataset_attr.images
|
||||
else []
|
||||
)
|
||||
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
|
||||
|
||||
return outputs
|
||||
|
||||
|
|
|
@ -182,6 +182,9 @@ class ModelArguments:
|
|||
if self.split_special_tokens and self.use_fast_tokenizer:
|
||||
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
|
||||
|
||||
if self.visual_inputs and self.use_unsloth:
|
||||
raise ValueError("Unsloth does not support MLLM yet. Stay tuned.")
|
||||
|
||||
if self.adapter_name_or_path is not None: # support merging multiple lora weights
|
||||
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
|
||||
|
||||
|
|
|
@ -323,6 +323,9 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
|||
if model_args.visual_inputs:
|
||||
raise ValueError("vLLM engine does not support MLLM yet. Stay tuned.")
|
||||
|
||||
if finetuning_args.stage == "rm" and model_args.visual_inputs:
|
||||
raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
|
||||
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
_check_extra_dependencies(model_args, finetuning_args)
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ from .utils.longlora import configure_longlora
|
|||
from .utils.moe import add_z3_leaf_module, configure_moe
|
||||
from .utils.quantization import configure_quantization
|
||||
from .utils.rope import configure_rope
|
||||
from .utils.visual import autocast_projector_dtype
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -92,6 +93,9 @@ def patch_model(
|
|||
if model_args.resize_vocab:
|
||||
resize_embedding_layer(model, tokenizer)
|
||||
|
||||
if model_args.visual_inputs:
|
||||
autocast_projector_dtype(model, model_args)
|
||||
|
||||
if is_trainable:
|
||||
prepare_model_for_training(model, model_args)
|
||||
add_z3_leaf_module(model)
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def autocast_projector_dtype(
|
||||
model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
|
||||
) -> None:
|
||||
def _mm_projector_forward_post_hook(
|
||||
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
|
||||
) -> "torch.Tensor":
|
||||
return output.to(model_args.compute_dtype)
|
||||
|
||||
if hasattr(model, mm_projector_name):
|
||||
logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
|
||||
mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
|
||||
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
|
Loading…
Reference in New Issue