This commit is contained in:
hiyouga 2024-06-17 18:17:48 +08:00
parent 72471ee046
commit e2665e71c7
3 changed files with 25 additions and 14 deletions

View File

@ -281,12 +281,22 @@ def init_adapter(
Note that the trainable parameters must be cast to float32.
"""
if is_trainable and getattr(model, "quantization_method", None) and finetuning_args.finetuning_type != "lora":
if is_trainable and getattr(model, "quantization_method", None) is not None:
if finetuning_args.finetuning_type != "lora":
raise ValueError("Quantized models can only be used for the LoRA tuning.")
if finetuning_args.pissa_init:
raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
# cast trainable parameters to float32 if:
# 1. is_trainable and quantization_bit is not None (qlora)
# 2. is_trainable and not deepspeed zero3 and not fsdp (zero3 or fsdp already in float32)
# 3. is_trainable and not pure_bf16 and not badam
if not is_trainable:
cast_trainable_params_to_fp32 = False
elif is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
elif model_args.quantization_bit is None and (
is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam
):
logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
cast_trainable_params_to_fp32 = False
else:

View File

@ -1,6 +1,7 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Optimum library.
# This code is inspired by the HuggingFace's Transformers and Optimum library.
# https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/utils/quantization_config.py
# https://github.com/huggingface/optimum/blob/v1.20.0/optimum/gptq/data.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -96,10 +97,7 @@ def configure_quantization(
"""
if getattr(config, "quantization_config", None): # ptq
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
if model_args.quantization_device_map != "auto":
init_kwargs["device_map"] = {"": get_current_device()}
raise ValueError("DeepSpeed ZeRO-3 is incompatible with PTQ-quantized models.")
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
@ -152,15 +150,15 @@ def configure_quantization(
bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora
)
# assign device map if:
# 1. not deepspeed zero3 and not fsdp
# 2. not auto quantization device map
if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
if model_args.quantization_bit != 4:
raise ValueError("Only 4-bit quantized model can use auto device map.")
raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
init_kwargs["torch_dtype"] = model_args.compute_dtype # fsdp+qlora requires same dtype
else:
init_kwargs["device_map"] = {"": get_current_device()}
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

View File

@ -89,7 +89,10 @@ def patch_config(
# deepspeed zero3 is not compatible with low_cpu_mem_usage
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled(): # cast dtype and device if not use zero3 or fsdp
# cast data type of the model if:
# 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32)
# 2. fsdp + qlora
if model_args.quantization_bit is not None or (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()):
init_kwargs["torch_dtype"] = model_args.compute_dtype
if init_kwargs["low_cpu_mem_usage"]: # device map requires low_cpu_mem_usage=True