fix llava qlora

parent cd3a960f81
commit fc67b736ba
@@ -1,4 +1,5 @@
 #!/bin/bash
+# NEED TO run `merge.sh` before using this script
 
 CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \
     --model_name_or_path ../../models/llama2-7b-sft \
@@ -14,10 +14,23 @@ if TYPE_CHECKING:
     from .parser import DatasetAttr
 
 
+def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
+    outputs = []
+    if dataset_attr.load_from in ["script", "file"]:
+        for image in images:
+            if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
+                outputs.append(os.path.join(data_args.dataset_dir, image))
+            else:
+                outputs.append(image)
+
+    return outputs
+
+
 def convert_alpaca(
     examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
 ) -> Dict[str, List[Any]]:
     outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
+    convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
     for i in range(len(examples[dataset_attr.prompt])):
         prompt = []
         if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
@@ -47,11 +60,7 @@ def convert_alpaca(
         outputs["response"].append(response)
         outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
         outputs["tools"].append("")
-        outputs["images"].append(
-            [os.path.join(data_args.dataset_dir, path) for path in examples[dataset_attr.images][i]]
-            if dataset_attr.images
-            else []
-        )
+        outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
 
     return outputs
 
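
Note on the hunk above: the old code unconditionally prefixed every entry of the images column with dataset_dir, while the new _convert_images helper only prefixes a path when the dataset is loaded from a local script or file and the file actually exists on disk, passing everything else through untouched. A minimal, self-contained sketch of that behavior follows; the SimpleNamespace objects stand in for the real DatasetAttr and DataArguments, and the example paths are made up.

import os
from functools import partial
from types import SimpleNamespace

# Stand-ins for DatasetAttr and DataArguments; only the fields used here exist.
dataset_attr = SimpleNamespace(load_from="file", images="images")
data_args = SimpleNamespace(dataset_dir="data")


def _convert_images(images, dataset_attr, data_args):
    # Prefix dataset_dir only for local datasets and only when the file exists.
    outputs = []
    if dataset_attr.load_from in ["script", "file"]:
        for image in images:
            if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
                outputs.append(os.path.join(data_args.dataset_dir, image))
            else:
                outputs.append(image)
    return outputs


convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
# "llava/cat.jpg" becomes "data/llava/cat.jpg" only if that file exists on disk;
# otherwise (as with the URL) the value is kept as-is instead of being blindly prefixed.
print(convert_images(["llava/cat.jpg", "http://example.com/dog.png"]))
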
@@ -60,6 +69,7 @@ def convert_sharegpt(
     examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
 ) -> Dict[str, List[Any]]:
     outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
+    convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
     tag_mapping = {
         dataset_attr.user_tag: Role.USER.value,
         dataset_attr.assistant_tag: Role.ASSISTANT.value,
@@ -94,11 +104,7 @@ def convert_sharegpt(
         outputs["response"].append(aligned_messages[-1:])
         outputs["system"].append(system)
         outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
-        outputs["images"].append(
-            [os.path.join(data_args.dataset_dir, path) for path in examples[dataset_attr.images][i]]
-            if dataset_attr.images
-            else []
-        )
+        outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
 
     return outputs
 
@@ -182,6 +182,9 @@ class ModelArguments:
         if self.split_special_tokens and self.use_fast_tokenizer:
             raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
 
+        if self.visual_inputs and self.use_unsloth:
+            raise ValueError("Unsloth does not support MLLM yet. Stay tuned.")
+
         if self.adapter_name_or_path is not None:  # support merging multiple lora weights
             self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
 
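
Sketch of the new fail-fast guard, assuming a stripped-down stand-in for ModelArguments that keeps only the two flags involved (the real dataclass has many more fields): combining multimodal inputs with Unsloth is now rejected at argument-validation time instead of failing later during training.

from dataclasses import dataclass


# Hypothetical stand-in for ModelArguments; only the two relevant flags are kept.
@dataclass
class _Args:
    visual_inputs: bool = False
    use_unsloth: bool = False

    def __post_init__(self):
        if self.visual_inputs and self.use_unsloth:
            raise ValueError("Unsloth does not support MLLM yet. Stay tuned.")


_Args(visual_inputs=True)  # fine: visual inputs without Unsloth

try:
    _Args(visual_inputs=True, use_unsloth=True)
except ValueError as err:
    print(err)  # Unsloth does not support MLLM yet. Stay tuned.
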
@@ -323,6 +323,9 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
         if model_args.visual_inputs:
             raise ValueError("vLLM engine does not support MLLM yet. Stay tuned.")
 
+    if finetuning_args.stage == "rm" and model_args.visual_inputs:
+        raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
+
     _verify_model_args(model_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
 
@@ -15,6 +15,7 @@ from .utils.longlora import configure_longlora
 from .utils.moe import add_z3_leaf_module, configure_moe
 from .utils.quantization import configure_quantization
 from .utils.rope import configure_rope
+from .utils.visual import autocast_projector_dtype
 
 
 if TYPE_CHECKING:
@@ -92,6 +93,9 @@ def patch_model(
     if model_args.resize_vocab:
         resize_embedding_layer(model, tokenizer)
 
+    if model_args.visual_inputs:
+        autocast_projector_dtype(model, model_args)
+
     if is_trainable:
         prepare_model_for_training(model, model_args)
         add_z3_leaf_module(model)
@@ -0,0 +1,28 @@
+from typing import TYPE_CHECKING, Tuple
+
+import torch
+
+from ...extras.logging import get_logger
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+
+    from ...hparams import ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+def autocast_projector_dtype(
+    model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
+) -> None:
+    def _mm_projector_forward_post_hook(
+        module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
+    ) -> "torch.Tensor":
+        return output.to(model_args.compute_dtype)
+
+    if hasattr(model, mm_projector_name):
+        logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
+        mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
+        mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
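
The new util registers a forward hook on the model's multi_modal_projector so that its output is recast to model_args.compute_dtype, presumably so downstream layers see activations in the expected dtype when the quantized base model is used. Below is a minimal, self-contained sketch of the same hook mechanism; the plain torch.nn.Linear and the bfloat16 compute dtype are illustrative stand-ins, not the project's real projector or configuration.

import torch

# Toy stand-in for the multimodal projector; dtype values are illustrative.
projector = torch.nn.Linear(4, 4).to(torch.float32)
compute_dtype = torch.bfloat16  # plays the role of model_args.compute_dtype


def _cast_output_hook(module, args, output):
    # Runs after every forward pass and replaces the output with a recast copy.
    return output.to(compute_dtype)


projector.register_forward_hook(_cast_output_hook)

x = torch.randn(2, 4, dtype=torch.float32)
print(projector(x).dtype)  # torch.bfloat16: the hook recast the float32 output
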