add quant checks

This commit is contained in:
hiyouga 2024-06-27 01:12:25 +08:00
parent d417e63f92
commit 96a5044394
1 changed files with 11 additions and 2 deletions

View File

@ -108,8 +108,11 @@ def configure_quantization(
Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
"""
if getattr(config, "quantization_config", None): # ptq
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with PTQ-quantized models.")
if model_args.quantization_bit is not None:
logger.warning("`quantization_bit` will not affect on the PTQ-quantized models.")
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
@ -182,6 +185,9 @@ def configure_quantization(
if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
require_version("hqq", "To fix: pip install hqq")
init_kwargs["quantization_config"] = HqqConfig(
nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
@ -191,6 +197,9 @@ def configure_quantization(
if model_args.quantization_bit != 8:
raise ValueError("EETQ only accepts 8-bit quantization.")
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
require_version("eetq", "To fix: pip install eetq")
init_kwargs["quantization_config"] = EetqConfig()
logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit))