improve autogptq integration

This commit is contained in:
hiyouga 2024-06-26 22:11:44 +08:00
parent 8d6cd69ac4
commit addca926de
2 changed files with 27 additions and 16 deletions

View File

@ -42,7 +42,7 @@ extra_require = {
"vllm": ["vllm>=0.4.3"], "vllm": ["vllm>=0.4.3"],
"galore": ["galore-torch"], "galore": ["galore-torch"],
"badam": ["badam>=1.2.1"], "badam": ["badam>=1.2.1"],
"gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"], "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"], "awq": ["autoawq"],
"aqlm": ["aqlm[gpu]>=1.1.0"], "aqlm": ["aqlm[gpu]>=1.1.0"],
"qwen": ["transformers_stream_generator"], "qwen": ["transformers_stream_generator"],

View File

@ -57,9 +57,9 @@ class QuantizationMethod(str, Enum):
HQQ = "hqq" HQQ = "hqq"
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]: def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]:
r""" r"""
TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600 Prepares the dataset to perform AutoGPTQ.
""" """
if os.path.isfile(model_args.export_quantization_dataset): if os.path.isfile(model_args.export_quantization_dataset):
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None) data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
@ -68,20 +68,32 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
data_path = model_args.export_quantization_dataset data_path = model_args.export_quantization_dataset
data_files = None data_files = None
dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir) dataset = load_dataset(
maxlen = model_args.export_quantization_maxlen path=data_path,
data_files=data_files,
split="train",
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
)
samples = [] samples = []
maxlen = model_args.export_quantization_maxlen
for _ in range(model_args.export_quantization_nsamples): for _ in range(model_args.export_quantization_nsamples):
n_try = 0
while True: while True:
if n_try > 100:
raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.")
sample_idx = random.randint(0, len(dataset) - 1) sample_idx = random.randint(0, len(dataset) - 1)
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt") sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
if sample["input_ids"].size(1) >= maxlen: n_try += 1
if sample["input_ids"].size(1) > maxlen:
break # TODO: fix large maxlen break # TODO: fix large maxlen
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1) word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen] input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True)) attention_mask = sample["attention_mask"][:, word_idx : word_idx + maxlen]
samples.append({"input_ids": input_ids, "attention_mask": attention_mask})
return samples return samples
@ -119,21 +131,20 @@ def configure_quantization(
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper())) logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
elif model_args.export_quantization_bit is not None: # auto-gptq elif model_args.export_quantization_bit is not None: # auto-gptq
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0") require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm": if getattr(config, "model_type", None) == "chatglm":
raise ValueError("ChatGLM model is not supported.") raise ValueError("ChatGLM model is not supported yet.")
init_kwargs["quantization_config"] = GPTQConfig( init_kwargs["quantization_config"] = GPTQConfig(
bits=model_args.export_quantization_bit, bits=model_args.export_quantization_bit,
tokenizer=tokenizer,
dataset=_get_quantization_dataset(tokenizer, model_args), dataset=_get_quantization_dataset(tokenizer, model_args),
) )
init_kwargs["device_map"] = "auto" init_kwargs["device_map"] = "auto"
init_kwargs["max_memory"] = get_max_memory() init_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit)) logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit))
elif model_args.quantization_bit is not None: # bnb elif model_args.quantization_bit is not None: # bnb
if model_args.quantization_bit == 8: if model_args.quantization_bit == 8:
@ -150,9 +161,9 @@ def configure_quantization(
bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora
) )
# assign device map if: # Do not assign device map if:
# 1. not deepspeed zero3 and not fsdp # 1. deepspeed zero3 or fsdp (train)
# 2. not auto quantization device map # 2. auto quantization device map (inference)
if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto": if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
if model_args.quantization_bit != 4: if model_args.quantization_bit != 4:
raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.") raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
@ -161,4 +172,4 @@ def configure_quantization(
else: else:
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit))