diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py index 87bad577..530869d5 100644 --- a/src/llmtuner/model/loader.py +++ b/src/llmtuner/model/loader.py @@ -146,22 +146,25 @@ def load_model_and_tokenizer( else: logger.warning("Current model does not support shift short attention.") + # Quantization configurations (using gptq or awq) + if getattr(config, "quantization_config", None): + if model_args.quantization_bit is not None: # remove bnb quantization + model_args.quantization_bit = None + config_kwargs["device_map"] = {"": get_current_device()} + quantization_config = getattr(config, "quantization_config", None) + logger.info("Loading {}-bit quantized model.".format(quantization_config.get("bits", -1))) + # Quantization configurations (using bitsandbytes library) if model_args.quantization_bit is not None: - if getattr(config, "quantization_config", None): - raise ValueError("Remove `quantization_bit` if you are using a quantized model.") - if is_deepspeed_zero3_enabled(): raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") if model_args.quantization_bit == 8: require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") - config_kwargs["load_in_8bit"] = True config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) if model_args.quantization_bit == 4: require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") - config_kwargs["load_in_4bit"] = True config_kwargs["quantization_config"] = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=model_args.compute_dtype, diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py index 08eea563..1eab538d 100644 --- a/src/llmtuner/model/utils.py +++ b/src/llmtuner/model/utils.py @@ -22,7 +22,11 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel": Dispatches a pre-trained model to GPUs with balanced memory. Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803 """ - if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing + if ( + getattr(model, "is_loaded_in_8bit", False) # bnb + or getattr(model, "is_loaded_in_4bit", False) # bnb + or getattr(model.config, "quantization_config", None) # gptq or awq + ): # already set on current device return model if torch.cuda.device_count() > 1: