From bfdee1608f53a6334d8e73c48dbeb4160969d783 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Thu, 14 Dec 2023 20:15:20 +0800
Subject: [PATCH] fix valuehead model

---
 src/llmtuner/model/loader.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 249f4734..c961e9b0 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -203,6 +203,9 @@ def load_model_and_tokenizer(
     # Prepare model with valuehead for RLHF
     if add_valuehead:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+        def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
+            return self.pretrained_model.get_input_embeddings()
+        setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
         ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
         setattr(model, "_keys_to_ignore_on_save", ignore_modules)
         setattr(model, "tie_weights", MethodType(lambda _: None, model))  # use empty method