From efc345c4b0095ec959ea23bbe54c344278780cbe Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Mon, 15 Apr 2024 15:32:58 +0800
Subject: [PATCH] fix #3273

---
 src/llmtuner/hparams/model_args.py |  4 ++++
 src/llmtuner/hparams/parser.py     |  6 +++---
 src/llmtuner/model/adapter.py      | 12 +++++++++---
 src/llmtuner/model/patcher.py      |  4 ++--
 4 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py
index be71d32f..514c8714 100644
--- a/src/llmtuner/hparams/model_args.py
+++ b/src/llmtuner/hparams/model_args.py
@@ -129,6 +129,10 @@ class ModelArguments:
         default=1,
         metadata={"help": "The file shard size (in GB) of the exported model."},
     )
+    export_device: str = field(
+        default="cpu",
+        metadata={"help": "The device used in model export."},
+    )
     export_quantization_bit: Optional[int] = field(
         default=None,
         metadata={"help": "The number of bits to quantize the exported model."},
diff --git a/src/llmtuner/hparams/parser.py b/src/llmtuner/hparams/parser.py
index 4abd3f03..1865ff17 100644
--- a/src/llmtuner/hparams/parser.py
+++ b/src/llmtuner/hparams/parser.py
@@ -10,7 +10,7 @@ from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import is_torch_bf16_gpu_available
 
 from ..extras.logging import get_logger
-from ..extras.misc import check_dependencies
+from ..extras.misc import check_dependencies, get_current_device
 from ..extras.packages import is_unsloth_available
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
@@ -235,6 +235,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     elif training_args.fp16:
         model_args.compute_dtype = torch.float16
 
+    model_args.device_map = {"": get_current_device()}
     model_args.model_max_length = data_args.cutoff_len
     data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
 
@@ -278,8 +279,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     _verify_model_args(model_args, finetuning_args)
 
     if model_args.export_dir is not None:
-        model_args.device_map = {"": "cpu"}
-        model_args.compute_dtype = torch.float32
+        model_args.device_map = {"": torch.device(model_args.export_device)}
     else:
         model_args.device_map = "auto"
 
diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py
index eb6d3878..4bb4057d 100644
--- a/src/llmtuner/model/adapter.py
+++ b/src/llmtuner/model/adapter.py
@@ -32,6 +32,9 @@ def init_adapter(
         logger.info("Adapter is not found at evaluation, load the base model.")
         return model
 
+    if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
+        raise ValueError("You can only use lora for quantized models.")
+
     if finetuning_args.finetuning_type == "full" and is_trainable:
         logger.info("Fine-tuning method: Full")
         if not finetuning_args.pure_bf16:
@@ -129,9 +132,12 @@ def init_adapter(
             if finetuning_args.use_llama_pro:
                 target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
 
-            if finetuning_args.use_dora and getattr(model, "quantization_method", None) is not None:
-                if getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES:
-                    raise ValueError("DoRA is not compatible with PTQ-quantized models.")
+            if (
+                finetuning_args.use_dora
+                and getattr(model, "quantization_method", None) is not None
+                and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
+            ):
+                raise ValueError("DoRA is not compatible with PTQ-quantized models.")
 
             peft_kwargs = {
                 "r": finetuning_args.lora_rank,
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index a23d0ef3..fe707af7 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -323,8 +323,8 @@ def patch_config(
     if not is_deepspeed_zero3_enabled():
         init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
         if init_kwargs["low_cpu_mem_usage"]:
-            if "device_map" not in init_kwargs:
-                init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
+            if "device_map" not in init_kwargs and model_args.device_map:
+                init_kwargs["device_map"] = model_args.device_map
 
             if init_kwargs["device_map"] == "auto":
                 init_kwargs["offload_folder"] = model_args.offload_folder