From 0a46313ccaee91b51bec9f9f92e3111a4a04ce2e Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 28 Jun 2023 01:55:12 +0800 Subject: [PATCH] fix loading best model --- src/utils/peft_trainer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/utils/peft_trainer.py b/src/utils/peft_trainer.py index 5652aa99..4a22d5af 100644 --- a/src/utils/peft_trainer.py +++ b/src/utils/peft_trainer.py @@ -126,12 +126,14 @@ class PeftTrainer(Seq2SeqTrainer): logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).") model = unwrap_model(self.model) + backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model + if self.finetuning_args.finetuning_type == "lora": - model.load_adapter(self.state.best_model_checkpoint, getattr(model, "active_adapter")) + backbone_model.load_adapter(self.state.best_model_checkpoint, getattr(backbone_model, "active_adapter")) if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint): model.v_head.load_state_dict({ "summary.weight": getattr(model, "reward_head_weight"), "summary.bias": getattr(model, "reward_head_bias") }) else: # freeze/full-tuning - load_trainable_params(model, self.state.best_model_checkpoint) + load_trainable_params(backbone_model, self.state.best_model_checkpoint)