From c0be617195f43d972681dd59727857b1247eeb7e Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Thu, 29 Feb 2024 18:32:54 +0800
Subject: [PATCH] fix #2642

---
 src/llmtuner/model/loader.py  | 9 +--------
 src/llmtuner/model/patcher.py | 7 +++++--
 2 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 29d213a7..9d453637 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -1,7 +1,6 @@
 from typing import TYPE_CHECKING, Optional, Tuple
 
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
-from transformers.integrations import is_deepspeed_zero3_enabled
 from trl import AutoModelForCausalLMWithValueHead
 
 from ..extras.logging import get_logger
@@ -77,13 +76,7 @@ def load_model_and_tokenizer(
             logger.warning("Unsloth does not support loading adapters.")
 
     if model is None:
-        model = AutoModelForCausalLM.from_pretrained(
-            model_args.model_name_or_path,
-            config=config,
-            torch_dtype=model_args.compute_dtype,
-            low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
-            **config_kwargs,
-        )
+        model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config, **config_kwargs)
 
     patch_model(model, tokenizer, model_args, is_trainable)
     register_autoclass(config, model, tokenizer)
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index c1d14c91..054c7de7 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -163,7 +163,6 @@ def _configure_quantization(
         if is_deepspeed_zero3_enabled():
             raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
 
-        config_kwargs["device_map"] = {"": get_current_device()}
         quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
         if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4:
             quantization_config["use_exllama"] = False  # disable exllama
@@ -214,7 +213,6 @@ def _configure_quantization(
                 bnb_4bit_quant_type=model_args.quantization_type,
             )
 
-        config_kwargs["device_map"] = {"": get_current_device()}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
 
@@ -284,6 +282,11 @@ def patch_config(
 
     _configure_quantization(config, tokenizer, model_args, config_kwargs)
 
+    config_kwargs["torch_dtype"] = model_args.compute_dtype
+    if not is_deepspeed_zero3_enabled():
+        config_kwargs["device_map"] = {"": get_current_device()}
+        config_kwargs["low_cpu_mem_usage"] = True
+
 
 def patch_model(
     model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
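
Note on the fix: transformers rejects low_cpu_mem_usage=True and explicit device_map arguments while DeepSpeed ZeRO-3 is active, because ZeRO-3 initializes and shards the weights itself. Moving these kwargs out of loader.py and the quantization branches into patch_config() sets them in exactly one place and skips them on the ZeRO-3 path. A minimal sketch of the resulting load flow follows; the wrapper name load_sketch and the device index 0 (standing in for llmtuner's get_current_device()) are illustrative assumptions, not part of the patch.

# Sketch only, assuming a transformers version where
# is_deepspeed_zero3_enabled is importable from transformers.integrations,
# as in the patched code.
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.integrations import is_deepspeed_zero3_enabled


def load_sketch(model_name_or_path: str, compute_dtype: torch.dtype):
    config = AutoConfig.from_pretrained(model_name_or_path)
    config_kwargs = {}

    # patch_config() now sets these kwargs once, for every load path:
    config_kwargs["torch_dtype"] = compute_dtype
    if not is_deepspeed_zero3_enabled():
        # Under ZeRO-3, a pinned device_map or low_cpu_mem_usage=True
        # conflicts with DeepSpeed's own weight initialization, so both
        # are applied only on the non-ZeRO-3 path.
        config_kwargs["device_map"] = {"": 0}  # stand-in for get_current_device()
        config_kwargs["low_cpu_mem_usage"] = True

    # loader.py then forwards everything in a single call:
    return AutoModelForCausalLM.from_pretrained(model_name_or_path, config=config, **config_kwargs)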