fix #1550
parent 999bc0ed93
commit 1bbc1be95e
@@ -168,12 +168,17 @@ def load_model_and_tokenizer(
         config_kwargs["device_map"] = {"": get_current_device()}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
+    if is_deepspeed_zero3_enabled() or getattr(config, "model_type", None) == "chatglm":
+        low_cpu_mem_usage = False
+    else:
+        low_cpu_mem_usage = True
+
     # Load pre-trained models (without valuehead)
     model = AutoModelForCausalLM.from_pretrained(
         model_to_load,
         config=config,
         torch_dtype=model_args.compute_dtype,
-        low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
+        low_cpu_mem_usage=low_cpu_mem_usage,
         **config_kwargs
     )
 
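In plain terms, the hunk above stops deriving low_cpu_mem_usage inline from is_deepspeed_zero3_enabled() and instead precomputes a flag that also disables the low-memory loading path for ChatGLM models (the subject of #1550). A minimal standalone sketch of that decision, assuming a transformers version that exports is_deepspeed_zero3_enabled from transformers.integrations; the helper name resolve_low_cpu_mem_usage and the example checkpoint are illustrative, not part of the repository:

from transformers import AutoConfig, AutoModelForCausalLM
from transformers.integrations import is_deepspeed_zero3_enabled

def resolve_low_cpu_mem_usage(config) -> bool:
    # DeepSpeed ZeRO-3 partitions weights across ranks at load time, and
    # ChatGLM's custom modeling code presumably breaks on the low-memory
    # loading path (see #1550), so both cases return False.
    if is_deepspeed_zero3_enabled() or getattr(config, "model_type", None) == "chatglm":
        return False
    return True

config = AutoConfig.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "THUDM/chatglm3-6b",
    config=config,
    low_cpu_mem_usage=resolve_low_cpu_mem_usage(config),  # False for ChatGLM
    trust_remote_code=True,
)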
@@ -140,7 +140,7 @@ def prepare_model_for_training(
         model.get_input_embeddings().register_forward_hook(neftune_forward_hook)
         logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha))
 
-    if use_gradient_checkpointing:
+    if use_gradient_checkpointing and getattr(model, "supports_gradient_checkpointing", False):
         if hasattr(model, "enable_input_require_grads"):
             model.enable_input_require_grads()
         else:
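The hunk above adds a capability check: gradient checkpointing is now attempted only when the model class declares supports_gradient_checkpointing, rather than unconditionally. The diff truncates at else:; the fallback hook and the enabling calls below follow the standard transformers pattern and are an assumption about the surrounding file, not shown in this diff. The wrapper name is hypothetical:

from transformers import PreTrainedModel

def maybe_enable_gradient_checkpointing(model: PreTrainedModel) -> None:
    # Skip models whose architecture never declares checkpointing support,
    # instead of enabling it unconditionally and failing at runtime.
    if not getattr(model, "supports_gradient_checkpointing", False):
        return

    if hasattr(model, "enable_input_require_grads"):
        # Make the embedding output require grad: checkpointed segments need
        # at least one input with requires_grad=True to receive gradients.
        model.enable_input_require_grads()
    else:
        # Fallback for older models: hook the input embeddings directly.
        def make_inputs_require_grad(module, args, output):
            output.requires_grad_(True)

        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    model.gradient_checkpointing_enable()
    model.config.use_cache = False  # caching is incompatible with checkpointing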