better dtype handling in loading
commit d9f190ff1e (parent ddec9e1b84)
@@ -44,7 +44,7 @@ def init_adapter(
         raise ValueError("You can only use lora for quantized models.")
 
     if deepspeed_config() is not None or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
-        logger.info("DeepSpeed/FSDP/PureBF16/BAdam detected, remaining trainable params in half precision.")
+        logger.info("DeepSpeed/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
         cast_trainable_params_to_fp32 = False
     else:
         logger.info("Upcasting trainable params to float32.")
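For reference, a minimal sketch (not the project's verbatim code; the helper name is hypothetical) of how a flag like cast_trainable_params_to_fp32 is typically consumed once adapters are attached: trainable parameters are upcast to float32 only when no engine such as DeepSpeed, FSDP, pure bf16, or BAdam already manages precision.

import torch


def upcast_trainable_params(model: torch.nn.Module, cast_to_fp32: bool) -> None:
    # When DeepSpeed/FSDP/PureBF16/BAdam is active, keep trainable params
    # in their original precision; otherwise upcast them to float32.
    if not cast_to_fp32:
        return
    for param in filter(lambda p: p.requires_grad, model.parameters()):
        param.data = param.data.to(torch.float32)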
@@ -122,6 +122,9 @@ def init_adapter(
             else:
                 param.requires_grad_(False)
 
+        if model_args.visual_inputs and hasattr(model, "vision_tower"):  # freeze vision model
+            model.vision_tower.requires_grad_(False)
+
         logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))
 
     if finetuning_args.finetuning_type == "lora":
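The new hasattr guard keeps the freeze step safe for text-only models that have no vision_tower attribute. A quick, illustrative check (helper name assumed, not part of this diff) that the vision tower ends up frozen while the selected language-model layers stay trainable:

def summarize_trainable(model) -> None:
    # Count parameters by requires_grad to confirm the vision tower is frozen.
    trainable, frozen = 0, 0
    for _, param in model.named_parameters():
        if param.requires_grad:
            trainable += param.numel()
        else:
            frozen += param.numel()
    print("trainable params: {:d} || frozen params: {:d}".format(trainable, frozen))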
@@ -170,6 +170,7 @@ def load_model(
         )
     else:
         param_stats = "all params: {:d}".format(all_param)
 
     logger.info(param_stats)
 
     if model_args.print_param_status:
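The else branch covers non-trainable loads, where only the total parameter count is reported. A hedged sketch of the kind of counting helper that produces trainable_params and all_param (the project's own helper may differ, e.g. by handling ZeRO-3 partitioned or 4-bit parameters):

from typing import Tuple


def count_parameters(model) -> Tuple[int, int]:
    # Return (trainable_params, all_param) by summing parameter element counts.
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    return trainable_params, all_param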
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict
 import torch
 from peft import PeftModel
 from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
-from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 
 from ..extras.logging import get_logger
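deepspeed_config() complements is_deepspeed_zero3_enabled(): it returns the active Hugging Face DeepSpeed config (any ZeRO stage) or None when DeepSpeed is not configured, whereas the ZeRO-3 check only detects stage-3 sharding. A small sketch of the distinction (the function name here is illustrative):

from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled


def externally_managed_loading() -> bool:
    # True when DeepSpeed (any stage) or FSDP owns dtype and device placement,
    # in which case the loader should not set torch_dtype or device_map itself.
    return deepspeed_config() is not None or is_fsdp_enabled()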
@@ -66,13 +66,16 @@ def patch_config(
     for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
         setattr(config, dtype_name, model_args.compute_dtype == dtype)
 
-    if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn:
+    if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
 
-    init_kwargs["torch_dtype"] = model_args.compute_dtype
-    if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled():
-        init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
-        if init_kwargs["low_cpu_mem_usage"]:
+    # deepspeed zero3 is not compatible with low_cpu_mem_usage
+    init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
+
+    if deepspeed_config() is None and not is_fsdp_enabled():  # set dtype and device map if not use deepspeed or fsdp
+        init_kwargs["torch_dtype"] = model_args.compute_dtype
+
+        if init_kwargs["low_cpu_mem_usage"]:  # device map requires low_cpu_mem_usage=True
             if "device_map" not in init_kwargs and model_args.device_map:
                 init_kwargs["device_map"] = model_args.device_map
 
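A standalone sketch of the new loading policy (simplified, not the project's exact function; build_init_kwargs is an assumed name): only set torch_dtype and device_map when neither DeepSpeed nor FSDP controls model loading, and only pass device_map when low_cpu_mem_usage is enabled, since device_map requires low_cpu_mem_usage=True.

from typing import Any, Dict, Optional

import torch
from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled


def build_init_kwargs(
    compute_dtype: torch.dtype,
    low_cpu_mem_usage: bool,
    device_map: Optional[Any],
) -> Dict[str, Any]:
    init_kwargs: Dict[str, Any] = {}
    # deepspeed zero3 is not compatible with low_cpu_mem_usage
    init_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())

    if deepspeed_config() is None and not is_fsdp_enabled():
        # Only manage dtype and device placement ourselves outside DeepSpeed/FSDP.
        init_kwargs["torch_dtype"] = compute_dtype
        if init_kwargs["low_cpu_mem_usage"] and device_map is not None:
            init_kwargs["device_map"] = device_map

    return init_kwargs


# Example (no DeepSpeed/FSDP active): a plain single-process bfloat16 load.
# build_init_kwargs(torch.bfloat16, True, "auto")
# -> {"low_cpu_mem_usage": True, "torch_dtype": torch.bfloat16, "device_map": "auto"}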