diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index b28a23d0..8625f3e1 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -6,6 +6,7 @@
 import torch
 from peft import PeftModel
 from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
 from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.modeling_utils import is_fsdp_enabled
 from ..extras.logging import get_logger
 from ..extras.misc import infer_optim_dtype
@@ -69,7 +70,7 @@ def patch_config(
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flashattn
 
     init_kwargs["torch_dtype"] = model_args.compute_dtype
-    if not is_deepspeed_zero3_enabled():
+    if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled():
         init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
         if init_kwargs["low_cpu_mem_usage"]:
             if "device_map" not in init_kwargs and model_args.device_map:
diff --git a/src/llmtuner/model/utils/quantization.py b/src/llmtuner/model/utils/quantization.py
index 3cf159c1..95412e7c 100644
--- a/src/llmtuner/model/utils/quantization.py
+++ b/src/llmtuner/model/utils/quantization.py
@@ -7,6 +7,7 @@
 import torch
 from datasets import load_dataset
 from transformers import BitsAndBytesConfig, GPTQConfig
 from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.modeling_utils import is_fsdp_enabled
 from transformers.utils.versions import require_version
 from ...extras.constants import FILEEXT2TYPE
@@ -133,7 +134,7 @@ def configure_quantization(
                 bnb_4bit_quant_storage=model_args.compute_dtype,  # crucial for fsdp qlora
             )
 
-        if is_deepspeed_zero3_enabled() or model_args.quantization_device_map == "auto":
+        if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
             if model_args.quantization_bit != 4:
                 raise ValueError("Only 4-bit quantized model can use auto device map.")
 
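
For context, a minimal sketch (not part of the patch) of what the new guard in patch_config amounts to: when either DeepSpeed ZeRO-3 or FSDP is active, the distributed backend owns parameter placement, so low_cpu_mem_usage and device_map are no longer injected into init_kwargs. The helper build_init_kwargs and its parameters are hypothetical stand-ins for the branch inside patch_config, which reads these values from model_args; is_deepspeed_zero3_enabled and is_fsdp_enabled are the real transformers helpers used by the patch.

# Illustrative sketch only, assuming a simplified stand-in for model_args.
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled


def build_init_kwargs(compute_dtype, low_cpu_mem_usage=True, device_map=None):
    """Hypothetical mirror of the patched branch in patch_config."""
    init_kwargs = {"torch_dtype": compute_dtype}
    # Under DeepSpeed ZeRO-3 *or* FSDP, skip low_cpu_mem_usage / device_map,
    # since sharded loading conflicts with them.
    if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled():
        init_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
        if low_cpu_mem_usage and device_map is not None:
            init_kwargs["device_map"] = device_map
    return init_kwargs

The second hunk applies the same idea to configure_quantization: an FSDP run is treated like a ZeRO-3 run when validating that only 4-bit quantized models may use the auto device map.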