fix fsdp model loading

Author: hiyouga
Date:   2024-05-15 16:32:28 +08:00
Parent: 11bf282dcc
Commit: 008e3b3b10
2 changed files with 4 additions and 2 deletions


@@ -6,6 +6,7 @@ import torch
 from peft import PeftModel
 from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
 from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.modeling_utils import is_fsdp_enabled
 
 from ..extras.logging import get_logger
 from ..extras.misc import infer_optim_dtype
@@ -69,7 +70,7 @@ def patch_config(
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flashattn
 
     init_kwargs["torch_dtype"] = model_args.compute_dtype
-    if not is_deepspeed_zero3_enabled():
+    if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled():
         init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
         if init_kwargs["low_cpu_mem_usage"]:
             if "device_map" not in init_kwargs and model_args.device_map:


@@ -7,6 +7,7 @@ import torch
 from datasets import load_dataset
 from transformers import BitsAndBytesConfig, GPTQConfig
 from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.modeling_utils import is_fsdp_enabled
 from transformers.utils.versions import require_version
 
 from ...extras.constants import FILEEXT2TYPE
@@ -133,7 +134,7 @@ def configure_quantization(
             bnb_4bit_quant_storage=model_args.compute_dtype,  # crucial for fsdp qlora
         )
 
-        if is_deepspeed_zero3_enabled() or model_args.quantization_device_map == "auto":
+        if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
             if model_args.quantization_bit != 4:
                 raise ValueError("Only 4-bit quantized model can use auto device map.")
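
Note: the same detection gates the quantized path, where ZeRO-3, FSDP, and
auto device maps are only accepted with 4-bit bitsandbytes quantization
whose storage dtype matches the compute dtype (the "crucial for fsdp qlora"
line above), since FSDP can only flatten and shard the quantized weights
when they are stored as ordinary tensors of that dtype. A minimal sketch of
such a config; the concrete dtype and flags are illustrative assumptions,
not values taken from this commit:

import torch
from transformers import BitsAndBytesConfig

# Sketch of an FSDP-compatible QLoRA config, assuming bitsandbytes is
# installed and bfloat16 compute; it mirrors the hunk above, not exact code.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    # crucial for fsdp qlora: store quantized weights in the compute dtype
    # so FSDP can wrap and shard them like regular parameters
    bnb_4bit_quant_storage=torch.bfloat16,
)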