diff --git a/src/llmtuner/model/utils/valuehead.py b/src/llmtuner/model/utils/valuehead.py index a192dcfa..a6180753 100644 --- a/src/llmtuner/model/utils/valuehead.py +++ b/src/llmtuner/model/utils/valuehead.py @@ -57,3 +57,7 @@ def prepare_valuehead_model(model: "PreTrainedModel") -> None: if getattr(model.config, "model_type", None) == "chatglm": setattr(model, "lm_head", model.transformer.output_layer) setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) + + if getattr(model.config, "model_type", None) == "internlm2": + setattr(model, "lm_head", model.output) + setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])