improve quantization

This commit is contained in:
hiyouga 2023-12-20 18:27:16 +08:00
parent c4a3977ad7
commit 624cc21281
1 changed files with 24 additions and 22 deletions

View File

@ -47,33 +47,18 @@ def configure_quantization(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments"
):
r"""
Priority: Pre-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
"""
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
if getattr(config, "quantization_config", None): # gptq or awq
model_args.quantization_bit = None # remove bnb quantization
config_kwargs["device_map"] = {"": get_current_device()}
quantization_config = getattr(config, "quantization_config", None)
logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1)))
if model_args.quantization_bit is not None: # bnb
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
if model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type
)
config_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
if finetuning_args.export_quantization_bit is not None: # gptq
elif finetuning_args.export_quantization_bit is not None: # gptq
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory
@ -90,6 +75,23 @@ def configure_quantization(
config_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit.".format(finetuning_args.export_quantization_bit))
elif model_args.quantization_bit is not None: # bnb
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type
)
config_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
if model_args.rope_scaling is not None: