fix gptq model inference

This commit is contained in:
hiyouga 2023-12-01 23:34:14 +08:00
parent 0cb260f453
commit 01e6c539b0
2 changed files with 13 additions and 6 deletions

View File

@ -146,22 +146,25 @@ def load_model_and_tokenizer(
else:
logger.warning("Current model does not support shift short attention.")
# Quantization configurations (using gptq or awq)
if getattr(config, "quantization_config", None):
if model_args.quantization_bit is not None: # remove bnb quantization
model_args.quantization_bit = None
config_kwargs["device_map"] = {"": get_current_device()}
quantization_config = getattr(config, "quantization_config", None)
logger.info("Loading {}-bit quantized model.".format(quantization_config.get("bits", -1)))
# Quantization configurations (using bitsandbytes library)
if model_args.quantization_bit is not None:
if getattr(config, "quantization_config", None):
raise ValueError("Remove `quantization_bit` if you are using a quantized model.")
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["load_in_8bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
if model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
config_kwargs["load_in_4bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,

View File

@ -22,7 +22,11 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
Dispatches a pre-trained model to GPUs with balanced memory.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
"""
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
if (
getattr(model, "is_loaded_in_8bit", False) # bnb
or getattr(model, "is_loaded_in_4bit", False) # bnb
or getattr(model.config, "quantization_config", None) # gptq or awq
): # already set on current device
return model
if torch.cuda.device_count() > 1: