From e2665e71c7428014d46d91542b01a58c1064d05a Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Mon, 17 Jun 2024 18:17:48 +0800
Subject: [PATCH] fix #4326

---
 src/llamafactory/model/adapter.py             | 16 +++++++++++++---
 .../model/model_utils/quantization.py         | 18 ++++++++----------
 src/llamafactory/model/patcher.py             |  5 ++++-
 3 files changed, 25 insertions(+), 14 deletions(-)

diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py
index a8f3a256..34518878 100644
--- a/src/llamafactory/model/adapter.py
+++ b/src/llamafactory/model/adapter.py
@@ -281,12 +281,22 @@ def init_adapter(
 
     Note that the trainable parameters must be cast to float32.
     """
-    if is_trainable and getattr(model, "quantization_method", None) and finetuning_args.finetuning_type != "lora":
-        raise ValueError("Quantized models can only be used for the LoRA tuning.")
+    if is_trainable and getattr(model, "quantization_method", None) is not None:
+        if finetuning_args.finetuning_type != "lora":
+            raise ValueError("Quantized models can only be used for the LoRA tuning.")
+        if finetuning_args.pissa_init:
+            raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
+
+    # cast trainable parameters to float32 if:
+    # 1. is_trainable and quantization_bit is not None (qlora)
+    # 2. is_trainable and not deepspeed zero3 and not fsdp (zero3 or fsdp already in float32)
+    # 3. is_trainable and not pure_bf16 and not badam
     if not is_trainable:
         cast_trainable_params_to_fp32 = False
-    elif is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
+    elif model_args.quantization_bit is None and (
+        is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam
+    ):
         logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
         cast_trainable_params_to_fp32 = False
     else:
diff --git a/src/llamafactory/model/model_utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py
index 0a0fca34..5251f84f 100644
--- a/src/llamafactory/model/model_utils/quantization.py
+++ b/src/llamafactory/model/model_utils/quantization.py
@@ -1,6 +1,7 @@
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
-# This code is inspired by the HuggingFace's Optimum library.
+# This code is inspired by the HuggingFace's Transformers and Optimum library.
+# https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/utils/quantization_config.py
 # https://github.com/huggingface/optimum/blob/v1.20.0/optimum/gptq/data.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -96,10 +97,7 @@ def configure_quantization(
     """
     if getattr(config, "quantization_config", None):  # ptq
         if is_deepspeed_zero3_enabled():
-            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
-
-        if model_args.quantization_device_map != "auto":
-            init_kwargs["device_map"] = {"": get_current_device()}
+            raise ValueError("DeepSpeed ZeRO-3 is incompatible with PTQ-quantized models.")
 
         quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
         quant_method = quantization_config.get("quant_method", "")
@@ -152,15 +150,15 @@
                 bnb_4bit_quant_storage=model_args.compute_dtype,  # crucial for fsdp+qlora
             )
 
+        # assign device map if:
+        # 1. not deepspeed zero3 and not fsdp
+        # 2. not auto quantization device map
         if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
             if model_args.quantization_bit != 4:
-                raise ValueError("Only 4-bit quantized model can use auto device map.")
+                raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
 
-            require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
-            require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
             require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
-            init_kwargs["torch_dtype"] = model_args.compute_dtype  # fsdp+qlora requires same dtype
         else:
-            init_kwargs["device_map"] = {"": get_current_device()}
+            init_kwargs["device_map"] = {"": get_current_device()}  # change auto device map for inference
 
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 053516e4..8fa17d08 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -89,7 +89,10 @@ def patch_config(
     # deepspeed zero3 is not compatible with low_cpu_mem_usage
     init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
 
-    if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled():  # cast dtype and device if not use zero3 or fsdp
+    # cast data type of the model if:
+    # 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32)
+    # 2. fsdp + qlora
+    if model_args.quantization_bit is not None or (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()):
         init_kwargs["torch_dtype"] = model_args.compute_dtype
 
         if init_kwargs["low_cpu_mem_usage"]:  # device map requires low_cpu_mem_usage=True
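
For reference, the precision decision introduced in src/llamafactory/model/adapter.py can be read as a small standalone predicate. The sketch below is illustrative only and is not part of the patch or of LLaMA-Factory: the Args dataclass and the helper name should_cast_trainable_params_to_fp32 are hypothetical stand-ins for the real model_args / finetuning_args objects, and the DeepSpeed/FSDP checks are passed in as plain booleans. It reproduces the patched condition under which trainable parameters are cast to float32.

from dataclasses import dataclass
from typing import Optional


@dataclass
class Args:
    # Hypothetical stand-ins for model_args.quantization_bit,
    # finetuning_args.pure_bf16 and finetuning_args.use_badam.
    quantization_bit: Optional[int] = None
    pure_bf16: bool = False
    use_badam: bool = False


def should_cast_trainable_params_to_fp32(
    is_trainable: bool,
    args: Args,
    zero3_enabled: bool,
    fsdp_enabled: bool,
) -> bool:
    """Mirror of the patched condition in init_adapter() (sketch, not library code)."""
    if not is_trainable:
        return False
    # After the fix, ZeRO-3 / FSDP / pure_bf16 / BAdam only skip the cast when the
    # model is not bnb-quantized; trainable QLoRA parameters are always cast to float32.
    if args.quantization_bit is None and (
        zero3_enabled or fsdp_enabled or args.pure_bf16 or args.use_badam
    ):
        return False
    return True


if __name__ == "__main__":
    # fsdp + qlora: trainable adapter weights are still cast to float32
    print(should_cast_trainable_params_to_fp32(True, Args(quantization_bit=4), False, True))  # True
    # fsdp without quantization: keep the original precision
    print(should_cast_trainable_params_to_fp32(True, Args(), False, True))  # False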