diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py
index 7398d424..1d9e0051 100644
--- a/src/llmtuner/extras/callbacks.py
+++ b/src/llmtuner/extras/callbacks.py
@@ -12,6 +12,7 @@ from llmtuner.extras.logging import get_logger
 
 if TYPE_CHECKING:
     from transformers import TrainingArguments, TrainerState, TrainerControl
+    from trl import AutoModelForCausalLMWithValueHead
 
 
 logger = get_logger(__name__)
@@ -25,16 +26,22 @@ class SavePeftModelCallback(TrainerCallback):
         """
         if args.should_save:
             output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
-            model = kwargs.pop("model")
+            model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
+            model.pretrained_model.config.save_pretrained(output_dir)
+            if model.pretrained_model.can_generate():
+                model.pretrained_model.generation_config.save_pretrained(output_dir)
             if getattr(model, "is_peft_model", False):
-                getattr(model, "pretrained_model").save_pretrained(output_dir)
+                model.pretrained_model.save_pretrained(output_dir)
 
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
        Event called at the end of training.
         """
         if args.should_save:
-            model = kwargs.pop("model")
+            model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
+            model.pretrained_model.config.save_pretrained(args.output_dir)
+            if model.pretrained_model.can_generate():
+                model.pretrained_model.generation_config.save_pretrained(args.output_dir)
             if getattr(model, "is_peft_model", False):
                 getattr(model, "pretrained_model").save_pretrained(args.output_dir)
 
diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py
index ca304761..cfdc8b24 100644
--- a/src/llmtuner/hparams/finetuning_args.py
+++ b/src/llmtuner/hparams/finetuning_args.py
@@ -166,7 +166,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
         if self.stage == "ppo" and self.reward_model is None:
             raise ValueError("Reward model is necessary for PPO training.")
 
-        if self.reward_model_type == "lora" and self.finetuning_type != "lora":
+        if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
             raise ValueError("Lora reward model only supports lora training.")
 
     def save_to_json(self, json_path: str):
diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index f9d37c03..df43599c 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -203,12 +203,13 @@ def load_model_and_tokenizer(
     if stage in ["rm", "ppo"]:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
         reset_logging()
-        if model_args.checkpoint_dir is not None: # load valuehead weights if exists
-            logger.warning("Only the last checkpoint containing valuehead will be loaded.")
-            vhead_params = load_valuehead_params(model_args.checkpoint_dir[-1], model_args)
-            if vhead_params is not None:
-                model.load_state_dict(vhead_params, strict=False)
-                logger.info("Loaded valuehead from checkpoint: {}".format(model_args.checkpoint_dir[-1]))
+        vhead_path = (
+            model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
+        )
+        vhead_params = load_valuehead_params(vhead_path, model_args)
+        if vhead_params is not None:
+            model.load_state_dict(vhead_params, strict=False)
+            logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
 
     # Prepare model for inference
     if not is_trainable:
diff --git a/src/llmtuner/train/ppo/trainer.py b/src/llmtuner/train/ppo/trainer.py
index b97037a1..949e2ce8 100644
--- a/src/llmtuner/train/ppo/trainer.py
+++ b/src/llmtuner/train/ppo/trainer.py
@@ -45,6 +45,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.args = training_args
         self.model_args = model_args
         self.finetuning_args = finetuning_args
+        self.reward_model = reward_model
 
         self.generation_config = GenerationConfig(
             pad_token_id=self.tokenizer.pad_token_id,
@@ -72,8 +73,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 self.reward_model = self._prepare_deepspeed(self.reward_model)
             else:
                 self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
-        else:
-            self.reward_model = None
 
     def ppo_train(self) -> None:
         r"""
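
A minimal sketch of the value-head loading fallback introduced in loader.py above, assuming `load_valuehead_params` returns a state dict or `None` as in the diff; the helper names `resolve_vhead_path` and `maybe_load_valuehead` are hypothetical and only restate the patch's logic.

```python
# Sketch only: mirrors the fallback added in load_model_and_tokenizer.
# model_args.checkpoint_dir, model_args.model_name_or_path, and
# load_valuehead_params come from the diff; everything else is illustrative.

def resolve_vhead_path(model_args) -> str:
    # Prefer the last provided checkpoint; otherwise fall back to the base model,
    # so value-head weights bundled with the model itself can still be picked up.
    if model_args.checkpoint_dir is not None:
        return model_args.checkpoint_dir[-1]
    return model_args.model_name_or_path


def maybe_load_valuehead(model, model_args, load_valuehead_params, logger) -> None:
    vhead_path = resolve_vhead_path(model_args)
    vhead_params = load_valuehead_params(vhead_path, model_args)
    if vhead_params is not None:
        # strict=False because vhead_params contains only the value-head tensors.
        model.load_state_dict(vhead_params, strict=False)
        logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
```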