fix loading best model
commit 0a46313cca
parent 7826a8ca77
@@ -126,12 +126,14 @@ class PeftTrainer(Seq2SeqTrainer):
         logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
 
         model = unwrap_model(self.model)
+        backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model
+
         if self.finetuning_args.finetuning_type == "lora":
-            model.load_adapter(self.state.best_model_checkpoint, getattr(model, "active_adapter"))
+            backbone_model.load_adapter(self.state.best_model_checkpoint, getattr(backbone_model, "active_adapter"))
             if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
                 model.v_head.load_state_dict({
                     "summary.weight": getattr(model, "reward_head_weight"),
                     "summary.bias": getattr(model, "reward_head_bias")
                 })
         else: # freeze/full-tuning
-            load_trainable_params(model, self.state.best_model_checkpoint)
+            load_trainable_params(backbone_model, self.state.best_model_checkpoint)
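
The change routes adapter and trainable-parameter loading through the unwrapped PEFT backbone rather than the value-head wrapper, so the best-checkpoint weights land on the object that actually owns the adapters. Below is a minimal sketch of that resolution step, not the repository's code: it assumes a trl-style wrapper (e.g. AutoModelForCausalLMWithValueHead) that stores the underlying PEFT model in a `pretrained_model` attribute; the helper name `load_best_adapter` and its arguments are illustrative.

    from transformers.modeling_utils import unwrap_model


    def load_best_adapter(wrapped_model, checkpoint_dir: str):
        """Load LoRA adapter weights from `checkpoint_dir` onto the PEFT backbone."""
        model = unwrap_model(wrapped_model)                    # strip data-parallel wrappers
        backbone = getattr(model, "pretrained_model", model)   # value-head wrapper -> PEFT model
        # PeftModel.load_adapter takes (model_id, adapter_name); calling it on the
        # backbone ensures the checkpoint's adapter weights replace the active adapter.
        backbone.load_adapter(checkpoint_dir, backbone.active_adapter)
        return model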