This commit is contained in:
hiyouga 2024-01-04 22:53:03 +08:00
parent 1696698eb9
commit 368b31f6b7
1 changed file with 23 additions and 11 deletions

View File

@ -4,9 +4,10 @@ import time
from typing import TYPE_CHECKING
from datetime import timedelta
from transformers import TrainerCallback
from transformers import PreTrainedModel, TrainerCallback
from transformers.modeling_utils import custom_object_save, unwrap_model
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
from peft import PeftModel
from llmtuner.extras.constants import LOG_FILE_NAME
from llmtuner.extras.logging import get_logger
@ -19,14 +20,20 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
def _save_model_with_valuehead(model: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
model.pretrained_model.config.save_pretrained(output_dir)
if model.pretrained_model.can_generate():
model.pretrained_model.generation_config.save_pretrained(output_dir)
if getattr(model, "is_peft_model", False):
model.pretrained_model.save_pretrained(output_dir)
elif getattr(model.pretrained_model, "_auto_class", None): # must not a peft model
custom_object_save(model.pretrained_model, output_dir, config=model.pretrained_model.config)
def _save_model_with_valuehead(
    model: "AutoModelForCausalLMWithValueHead",
    output_dir: str,
    safe_serialization: bool
) -> None:
    r"""
    Save the backbone of a value-head wrapped model to ``output_dir``.

    Args:
        model: value-head wrapper whose ``pretrained_model`` is serialized.
        output_dir: destination directory for configs and weights.
        safe_serialization: forwarded to ``save_pretrained`` to pick
            safetensors vs. pickle format.
    """
    backbone = model.pretrained_model
    # Configs can only be exported from genuine transformers / peft models.
    if isinstance(backbone, (PreTrainedModel, PeftModel)):
        backbone.config.save_pretrained(output_dir)
        if backbone.can_generate():
            backbone.generation_config.save_pretrained(output_dir)

    # NOTE(review): branch nesting reconstructed from a diff with stripped
    # indentation — assumed to sit at function level, outside the isinstance
    # guard; confirm against the repository.
    if getattr(model, "is_peft_model", False):
        backbone.save_pretrained(output_dir, safe_serialization=safe_serialization)
    elif getattr(backbone, "_auto_class", None):  # only reached for non-peft models
        custom_object_save(backbone, output_dir, config=backbone.config)
class SavePeftModelCallback(TrainerCallback):
@ -38,7 +45,8 @@ class SavePeftModelCallback(TrainerCallback):
if args.should_save:
_save_model_with_valuehead(
model=unwrap_model(kwargs.pop("model")),
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
safe_serialization=args.save_safetensors
)
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
    r"""
    Event called at the end of training.

    Saves the (unwrapped) value-head model into the final output directory,
    honoring the trainer's safetensors preference. Only the process flagged
    by ``args.should_save`` performs the write.
    """
    if not args.should_save:
        return
    _save_model_with_valuehead(
        model=unwrap_model(kwargs.pop("model")),
        output_dir=args.output_dir,
        safe_serialization=args.save_safetensors
    )
class LogCallback(TrainerCallback):