From 30fb721f1222d1b56a2712519960a63655c20360 Mon Sep 17 00:00:00 2001
From: mmbwf <96833021+mmbwf@users.noreply.github.com>
Date: Thu, 14 Sep 2023 15:38:04 +0800
Subject: [PATCH] Update utils.py

Fix parameter loading error.
---
 src/llmtuner/tuner/ppo/utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/llmtuner/tuner/ppo/utils.py b/src/llmtuner/tuner/ppo/utils.py
index 1b48c8c0..7c4ac997 100644
--- a/src/llmtuner/tuner/ppo/utils.py
+++ b/src/llmtuner/tuner/ppo/utils.py
@@ -10,8 +10,8 @@ if TYPE_CHECKING:
 def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
     if target == "reward": # save default head temporarily
         valuehead_state_dict = model.v_head.state_dict()
-        setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
-        setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"])
+        setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].clone())
+        setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].clone())
 
     model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
     model.v_head.load_state_dict({
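
Why the .clone() calls fix the bug: in PyTorch, state_dict() returns tensors that share storage with the module's live parameters, and load_state_dict() copies new values into those parameters in place. Saving the raw references therefore means the "default" head backup is silently overwritten the moment replace_model() loads the reward head into v_head; cloning gives the backup its own storage. Below is a minimal, self-contained sketch of the aliasing, using a plain nn.Linear as a stand-in for the value head's summary layer (the names head, saved_ref, and saved_copy are illustrative, not from the patch):

import torch
import torch.nn as nn

head = nn.Linear(4, 1)  # stand-in for the value head's "summary" layer

saved_ref = head.state_dict()["weight"]           # shares storage with the live parameter
saved_copy = head.state_dict()["weight"].clone()  # independent copy, as in the patch

# Simulate replace_model() loading the reward head's weights into v_head:
# load_state_dict() writes into the existing parameter tensors in place.
head.load_state_dict({"weight": torch.zeros(1, 4), "bias": torch.zeros(1)})

print(torch.equal(saved_ref, torch.zeros(1, 4)))   # True: the un-cloned "backup" was clobbered
print(torch.equal(saved_copy, torch.zeros(1, 4)))  # False: the cloned backup survives

Restoring the default head from the un-cloned reference would load the reward head's values back instead of the originals, which is the parameter loading error this patch addresses.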