forked from p04798526/LLaMA-Factory-Mirror

fix #2067

commit 368b31f6b7
parent 1696698eb9
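The diff below threads the trainer's save_safetensors setting through _save_model_with_valuehead, so value-head checkpoints can be written in the safetensors format, and adds an isinstance guard so the wrapped model's config is only exported when it is a PreTrainedModel or PeftModel.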
@@ -4,9 +4,10 @@ import time
 from typing import TYPE_CHECKING
 from datetime import timedelta
 
-from transformers import TrainerCallback
+from transformers import PreTrainedModel, TrainerCallback
 from transformers.modeling_utils import custom_object_save, unwrap_model
 from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
+from peft import PeftModel
 
 from llmtuner.extras.constants import LOG_FILE_NAME
 from llmtuner.extras.logging import get_logger
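PreTrainedModel and PeftModel are imported here for the new isinstance guard introduced in the next hunk.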
@@ -19,14 +20,20 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def _save_model_with_valuehead(model: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
-    model.pretrained_model.config.save_pretrained(output_dir)
-    if model.pretrained_model.can_generate():
-        model.pretrained_model.generation_config.save_pretrained(output_dir)
+def _save_model_with_valuehead(
+    model: "AutoModelForCausalLMWithValueHead",
+    output_dir: str,
+    safe_serialization: bool
+) -> None:
+    if isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
+        model.pretrained_model.config.save_pretrained(output_dir)
+        if model.pretrained_model.can_generate():
+            model.pretrained_model.generation_config.save_pretrained(output_dir)
+
     if getattr(model, "is_peft_model", False):
-        model.pretrained_model.save_pretrained(output_dir)
+        model.pretrained_model.save_pretrained(output_dir, safe_serialization=safe_serialization)
     elif getattr(model.pretrained_model, "_auto_class", None): # must not be a peft model
         custom_object_save(model.pretrained_model, output_dir, config=model.pretrained_model.config)
 
 
 class SavePeftModelCallback(TrainerCallback):
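A minimal sketch of how the reworked helper behaves, assuming it can be imported from the patched callbacks module; the model name and output path are placeholders, not part of the commit:

    # Sketch only: "gpt2" and "outputs/demo" are illustrative placeholders.
    from trl import AutoModelForCausalLMWithValueHead

    model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
    # The isinstance guard passes because pretrained_model is a PreTrainedModel,
    # so config.json (and generation_config.json for generative models) is saved.
    # safe_serialization only takes effect when model.is_peft_model is True, in
    # which case the adapter is written as adapter_model.safetensors instead of
    # adapter_model.bin.
    _save_model_with_valuehead(model, "outputs/demo", safe_serialization=True)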
@@ -38,7 +45,8 @@ class SavePeftModelCallback(TrainerCallback):
         if args.should_save:
             _save_model_with_valuehead(
                 model=unwrap_model(kwargs.pop("model")),
-                output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
+                output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
+                safe_serialization=args.save_safetensors
             )
 
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
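The save_safetensors flag read here is the standard transformers.TrainingArguments option rather than anything LLaMA-Factory-specific; a sketch of setting it explicitly:

    # Sketch: output_dir is a placeholder. With save_safetensors=True, both the
    # trainer's own checkpoints and (after this commit) the value-head callback's
    # saves use the safetensors format.
    from transformers import TrainingArguments

    args = TrainingArguments(output_dir="outputs", save_safetensors=True)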
@@ -46,7 +54,11 @@ class SavePeftModelCallback(TrainerCallback):
         Event called at the end of training.
         """
         if args.should_save:
-            _save_model_with_valuehead(model=unwrap_model(kwargs.pop("model")), output_dir=args.output_dir)
+            _save_model_with_valuehead(
+                model=unwrap_model(kwargs.pop("model")),
+                output_dir=args.output_dir,
+                safe_serialization=args.save_safetensors
+            )
 
 
 class LogCallback(TrainerCallback):
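To check what the callback wrote, a saved adapter can be opened with the safetensors API; a sketch assuming a LoRA checkpoint exists at the path shown:

    # Sketch only: the checkpoint path is a placeholder.
    from safetensors.torch import load_file

    state_dict = load_file("outputs/checkpoint-100/adapter_model.safetensors")
    print(sorted(state_dict)[:5])  # first few tensor names in the adapter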