fix #4326
This commit is contained in:
parent 72471ee046
commit e2665e71c7
@@ -281,12 +281,22 @@ def init_adapter(
     Note that the trainable parameters must be cast to float32.
     """
-    if is_trainable and getattr(model, "quantization_method", None) and finetuning_args.finetuning_type != "lora":
-        raise ValueError("Quantized models can only be used for the LoRA tuning.")
+    if is_trainable and getattr(model, "quantization_method", None) is not None:
+        if finetuning_args.finetuning_type != "lora":
+            raise ValueError("Quantized models can only be used for the LoRA tuning.")
+
+        if finetuning_args.pissa_init:
+            raise ValueError("Cannot initialize PiSSA adapter on quantized models.")

+    # cast trainable parameters to float32 if:
+    # 1. is_trainable and quantization_bit is not None (qlora)
+    # 2. is_trainable and not deepspeed zero3 and not fsdp (zero3 or fsdp already in float32)
+    # 3. is_trainable and not pure_bf16 and not badam
     if not is_trainable:
         cast_trainable_params_to_fp32 = False
-    elif is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
+    elif model_args.quantization_bit is None and (
+        is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam
+    ):
         logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
         cast_trainable_params_to_fp32 = False
     else:
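Taken together, the new branch upcasts QLoRA adapters to float32 even under ZeRO-3/FSDP, which the old `elif` skipped. A minimal sketch of the resulting decision rule; the helper name and standalone flags are illustrative, not part of the codebase:

from typing import Optional

def should_cast_to_fp32(
    is_trainable: bool,
    quantization_bit: Optional[int],  # e.g. 4 for QLoRA, None otherwise
    zero3_or_fsdp: bool,              # ZeRO-3/FSDP already hold params in float32
    pure_bf16: bool,
    use_badam: bool,
) -> bool:
    if not is_trainable:
        return False  # inference: keep the original precision
    if quantization_bit is not None:
        return True   # qlora: always upcast, even under zero3/fsdp (the fix)
    if zero3_or_fsdp or pure_bf16 or use_badam:
        return False  # precision is already managed elsewhere
    return True

# fsdp + qlora now upcasts; plain fsdp does not
assert should_cast_to_fp32(True, 4, True, False, False) is True
assert should_cast_to_fp32(True, None, True, False, False) is False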
@@ -1,6 +1,7 @@
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
-# This code is inspired by the HuggingFace's Optimum library.
+# This code is inspired by the HuggingFace's Transformers and Optimum library.
+# https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/utils/quantization_config.py
 # https://github.com/huggingface/optimum/blob/v1.20.0/optimum/gptq/data.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -96,10 +97,7 @@ def configure_quantization(
     """
     if getattr(config, "quantization_config", None):  # ptq
         if is_deepspeed_zero3_enabled():
-            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
-
-        if model_args.quantization_device_map != "auto":
-            init_kwargs["device_map"] = {"": get_current_device()}
+            raise ValueError("DeepSpeed ZeRO-3 is incompatible with PTQ-quantized models.")

         quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
         quant_method = quantization_config.get("quant_method", "")
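The PTQ branch keys off the quantization_config shipped inside the checkpoint itself, so the ZeRO-3 error now fires only for post-training-quantized models. A hedged sketch of that detection; the checkpoint name is just an example of a GPTQ model and requires hub access:

from typing import Any, Dict

from transformers import AutoConfig

# GPTQ/AWQ checkpoints ship a quantization_config entry in config.json.
config = AutoConfig.from_pretrained("TheBloke/Llama-2-7B-GPTQ")
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
if quantization_config is not None:
    # For a GPTQ checkpoint, quant_method would be "gptq".
    print("PTQ method:", quantization_config.get("quant_method", ""))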
@@ -152,15 +150,15 @@ def configure_quantization(
                 bnb_4bit_quant_storage=model_args.compute_dtype,  # crucial for fsdp+qlora
             )

+        # assign device map if:
+        # 1. not deepspeed zero3 and not fsdp
+        # 2. not auto quantization device map
         if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
             if model_args.quantization_bit != 4:
-                raise ValueError("Only 4-bit quantized model can use auto device map.")
+                raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")

             require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
             require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
             require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
+            init_kwargs["torch_dtype"] = model_args.compute_dtype  # fsdp+qlora requires same dtype
         else:
-            init_kwargs["device_map"] = {"": get_current_device()}
+            init_kwargs["device_map"] = {"": get_current_device()}  # change auto device map for inference

         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
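The fsdp+qlora path hinges on keeping every dtype aligned: bnb_4bit_quant_storage, the compute dtype, and the torch_dtype the model is loaded with. A minimal sketch of such a bitsandbytes 4-bit setup (transformers>=4.39.0 per the requirement above; the nf4 quant type is an assumption, not shown in this diff):

import torch
from transformers import BitsAndBytesConfig

compute_dtype = torch.bfloat16
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_quant_type="nf4",             # assumed; any 4-bit quant type works here
    bnb_4bit_quant_storage=compute_dtype,  # crucial for fsdp+qlora
)
# FSDP flattens parameters into shared buffers, so the quantized storage dtype
# must match the model's torch_dtype, e.g.:
# AutoModelForCausalLM.from_pretrained(..., quantization_config=quant_config,
#                                      torch_dtype=compute_dtype)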
@@ -89,7 +89,10 @@ def patch_config(
     # deepspeed zero3 is not compatible with low_cpu_mem_usage
     init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())

-    if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled():  # cast dtype and device if not use zero3 or fsdp
+    # cast data type of the model if:
+    # 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32)
+    # 2. fsdp + qlora
+    if model_args.quantization_bit is not None or (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()):
         init_kwargs["torch_dtype"] = model_args.compute_dtype

     if init_kwargs["low_cpu_mem_usage"]:  # device map requires low_cpu_mem_usage=True
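The widened condition means torch_dtype is now also set when quantization is requested under ZeRO-3/FSDP, which is exactly the fsdp+qlora case. A small sketch of the before/after rule; the helper name is illustrative:

from typing import Optional

def sets_torch_dtype(quantization_bit: Optional[int], zero3: bool, fsdp: bool) -> bool:
    # old rule: cast only outside zero3/fsdp
    # new rule: additionally, cast whenever quantization_bit is set (fsdp + qlora)
    return quantization_bit is not None or (not zero3 and not fsdp)

assert sets_torch_dtype(4, zero3=False, fsdp=True)         # fsdp+qlora: now casts
assert not sets_torch_dtype(None, zero3=True, fsdp=False)  # plain zero3: stays float32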