From fc67b736ba3abd1a321c3498548d58de48921626 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Fri, 26 Apr 2024 18:00:23 +0800 Subject: [PATCH] fix llava qlora --- examples/merge_lora/quantize.sh | 1 + src/llmtuner/data/aligner.py | 26 ++++++++++++++++---------- src/llmtuner/hparams/model_args.py | 3 +++ src/llmtuner/hparams/parser.py | 3 +++ src/llmtuner/model/patcher.py | 4 ++++ src/llmtuner/model/utils/visual.py | 28 ++++++++++++++++++++++++++++ 6 files changed, 55 insertions(+), 10 deletions(-) create mode 100644 src/llmtuner/model/utils/visual.py diff --git a/examples/merge_lora/quantize.sh b/examples/merge_lora/quantize.sh index 143bce50..aeedbe66 100644 --- a/examples/merge_lora/quantize.sh +++ b/examples/merge_lora/quantize.sh @@ -1,4 +1,5 @@ #!/bin/bash +# NEED TO run `merge.sh` before using this script CUDA_VISIBLE_DEVICES=0 python ../../src/export_model.py \ --model_name_or_path ../../models/llama2-7b-sft \ diff --git a/src/llmtuner/data/aligner.py b/src/llmtuner/data/aligner.py index dc1de865..6bd12aad 100644 --- a/src/llmtuner/data/aligner.py +++ b/src/llmtuner/data/aligner.py @@ -14,10 +14,23 @@ if TYPE_CHECKING: from .parser import DatasetAttr +def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]: + outputs = [] + if dataset_attr.load_from in ["script", "file"]: + for image in images: + if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)): + outputs.append(os.path.join(data_args.dataset_dir, image)) + else: + outputs.append(image) + + return outputs + + def convert_alpaca( examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments" ) -> Dict[str, List[Any]]: outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []} + convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args) for i in range(len(examples[dataset_attr.prompt])): prompt = [] if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list): @@ -47,11 +60,7 @@ def convert_alpaca( outputs["response"].append(response) outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "") outputs["tools"].append("") - outputs["images"].append( - [os.path.join(data_args.dataset_dir, path) for path in examples[dataset_attr.images][i]] - if dataset_attr.images - else [] - ) + outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else []) return outputs @@ -60,6 +69,7 @@ def convert_sharegpt( examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments" ) -> Dict[str, List[Any]]: outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []} + convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args) tag_mapping = { dataset_attr.user_tag: Role.USER.value, dataset_attr.assistant_tag: Role.ASSISTANT.value, @@ -94,11 +104,7 @@ def convert_sharegpt( outputs["response"].append(aligned_messages[-1:]) outputs["system"].append(system) outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "") - outputs["images"].append( - [os.path.join(data_args.dataset_dir, path) for path in examples[dataset_attr.images][i]] - if dataset_attr.images - else [] - ) + outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else []) return outputs diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py index be65cd27..ac70bb3c 100644 --- a/src/llmtuner/hparams/model_args.py +++ b/src/llmtuner/hparams/model_args.py @@ -182,6 +182,9 @@ class ModelArguments: if self.split_special_tokens and self.use_fast_tokenizer: raise ValueError("`split_special_tokens` is only supported for slow tokenizers.") + if self.visual_inputs and self.use_unsloth: + raise ValueError("Unsloth does not support MLLM yet. Stay tuned.") + if self.adapter_name_or_path is not None: # support merging multiple lora weights self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")] diff --git a/src/llmtuner/hparams/parser.py b/src/llmtuner/hparams/parser.py index 715b8f95..aa046837 100644 --- a/src/llmtuner/hparams/parser.py +++ b/src/llmtuner/hparams/parser.py @@ -323,6 +323,9 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: if model_args.visual_inputs: raise ValueError("vLLM engine does not support MLLM yet. Stay tuned.") + if finetuning_args.stage == "rm" and model_args.visual_inputs: + raise ValueError("Reward server does not support MLLM yet. Stay tuned.") + _verify_model_args(model_args, finetuning_args) _check_extra_dependencies(model_args, finetuning_args) diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index 5c3c31b3..94d99644 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -15,6 +15,7 @@ from .utils.longlora import configure_longlora from .utils.moe import add_z3_leaf_module, configure_moe from .utils.quantization import configure_quantization from .utils.rope import configure_rope +from .utils.visual import autocast_projector_dtype if TYPE_CHECKING: @@ -92,6 +93,9 @@ def patch_model( if model_args.resize_vocab: resize_embedding_layer(model, tokenizer) + if model_args.visual_inputs: + autocast_projector_dtype(model, model_args) + if is_trainable: prepare_model_for_training(model, model_args) add_z3_leaf_module(model) diff --git a/src/llmtuner/model/utils/visual.py b/src/llmtuner/model/utils/visual.py new file mode 100644 index 00000000..cb51301b --- /dev/null +++ b/src/llmtuner/model/utils/visual.py @@ -0,0 +1,28 @@ +from typing import TYPE_CHECKING, Tuple + +import torch + +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PreTrainedModel + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +def autocast_projector_dtype( + model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector" +) -> None: + def _mm_projector_forward_post_hook( + module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor" + ) -> "torch.Tensor": + return output.to(model_args.compute_dtype) + + if hasattr(model, mm_projector_name): + logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype)) + mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name) + mm_projector.register_forward_hook(_mm_projector_forward_post_hook)