From addca926de42f91366185a47eb8e777ed44a8e77 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Wed, 26 Jun 2024 22:11:44 +0800 Subject: [PATCH] improve autogptq integration --- setup.py | 2 +- .../model/model_utils/quantization.py | 41 ++++++++++++------- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/setup.py b/setup.py index 64f50a87..8254b6d4 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ extra_require = { "vllm": ["vllm>=0.4.3"], "galore": ["galore-torch"], "badam": ["badam>=1.2.1"], - "gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"], + "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"], "awq": ["autoawq"], "aqlm": ["aqlm[gpu]>=1.1.0"], "qwen": ["transformers_stream_generator"], diff --git a/src/llamafactory/model/model_utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py index 5251f84f..fab61cb8 100644 --- a/src/llamafactory/model/model_utils/quantization.py +++ b/src/llamafactory/model/model_utils/quantization.py @@ -57,9 +57,9 @@ class QuantizationMethod(str, Enum): HQQ = "hqq" -def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]: +def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]: r""" - TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600 + Prepares the dataset to perform AutoGPTQ. """ if os.path.isfile(model_args.export_quantization_dataset): data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None) @@ -68,20 +68,32 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod data_path = model_args.export_quantization_dataset data_files = None - dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir) - maxlen = model_args.export_quantization_maxlen + dataset = load_dataset( + path=data_path, + data_files=data_files, + split="train", + cache_dir=model_args.cache_dir, + token=model_args.hf_hub_token, + ) samples = [] + maxlen = model_args.export_quantization_maxlen for _ in range(model_args.export_quantization_nsamples): + n_try = 0 while True: + if n_try > 100: + raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.") + sample_idx = random.randint(0, len(dataset) - 1) - sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt") - if sample["input_ids"].size(1) >= maxlen: + sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt") + n_try += 1 + if sample["input_ids"].size(1) > maxlen: break # TODO: fix large maxlen word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1) input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen] - samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True)) + attention_mask = sample["attention_mask"][:, word_idx : word_idx + maxlen] + samples.append({"input_ids": input_ids, "attention_mask": attention_mask}) return samples @@ -119,21 +131,20 @@ def configure_quantization( logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper())) elif model_args.export_quantization_bit is not None: # auto-gptq - require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0") + require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0") require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") from accelerate.utils import get_max_memory if getattr(config, "model_type", None) == "chatglm": - raise ValueError("ChatGLM model is not supported.") + raise ValueError("ChatGLM model is not supported yet.") init_kwargs["quantization_config"] = GPTQConfig( bits=model_args.export_quantization_bit, - tokenizer=tokenizer, dataset=_get_quantization_dataset(tokenizer, model_args), ) init_kwargs["device_map"] = "auto" init_kwargs["max_memory"] = get_max_memory() - logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit)) + logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit)) elif model_args.quantization_bit is not None: # bnb if model_args.quantization_bit == 8: @@ -150,9 +161,9 @@ def configure_quantization( bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora ) - # assign device map if: - # 1. not deepspeed zero3 and not fsdp - # 2. not auto quantization device map + # Do not assign device map if: + # 1. deepspeed zero3 or fsdp (train) + # 2. auto quantization device map (inference) if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto": if model_args.quantization_bit != 4: raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.") @@ -161,4 +172,4 @@ def configure_quantization( else: init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference - logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) + logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit))