From 3b040e8e0f78dbb6bc1409a1b2b788e1affc7458 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Wed, 19 Jun 2024 21:27:00 +0800
Subject: [PATCH] update patcher

---
 src/llamafactory/model/model_utils/checkpointing.py | 10 ++++------
 src/llamafactory/model/patcher.py                   |  5 +++++
 tests/model/model_utils/test_checkpointing.py       |  2 +-
 3 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/src/llamafactory/model/model_utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py
index f5314125..f4f3d8a5 100644
--- a/src/llamafactory/model/model_utils/checkpointing.py
+++ b/src/llamafactory/model/model_utils/checkpointing.py
@@ -78,9 +78,7 @@ def _fp32_forward_post_hook(
     return output.to(torch.float32)
 
 
-def prepare_model_for_training(
-    model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
-) -> None:
+def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
     r"""
     Includes:
         (1) cast the layernorm in fp32
@@ -104,8 +102,8 @@ def prepare_model_for_training(
         setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
         logger.info("Gradient checkpointing enabled.")
 
-    if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
-        logger.info("Upcasting lm_head outputs in float32.")
-        output_layer = getattr(model, output_layer_name)
+    if model_args.upcast_lmhead_output:
+        output_layer = model.get_output_embeddings()
         if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
+            logger.info("Upcasting lm_head outputs in float32.")
             output_layer.register_forward_hook(_fp32_forward_post_hook)
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 8fa17d08..a53fde98 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -152,6 +152,10 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
         if isinstance(self.pretrained_model, PreTrainedModel):
             return self.pretrained_model.get_input_embeddings()
 
+    def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
+        if isinstance(self.pretrained_model, PreTrainedModel):
+            return self.pretrained_model.get_output_embeddings()
+
     def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
         if isinstance(self.pretrained_model, PeftModel):
             self.pretrained_model.create_or_update_model_card(output_dir)
@@ -160,4 +164,5 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
     setattr(model, "_keys_to_ignore_on_save", ignore_modules)
     setattr(model, "tie_weights", MethodType(tie_weights, model))
     setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
+    setattr(model, "get_output_embeddings", MethodType(get_output_embeddings, model))
     setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))
diff --git a/tests/model/model_utils/test_checkpointing.py b/tests/model/model_utils/test_checkpointing.py
index 670e693d..9b6dfc9e 100644
--- a/tests/model/model_utils/test_checkpointing.py
+++ b/tests/model/model_utils/test_checkpointing.py
@@ -70,5 +70,5 @@ def test_upcast_lmhead_output():
     tokenizer_module = load_tokenizer(model_args)
     model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
     inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device())
-    outputs: "torch.Tensor" = model.lm_head(inputs)
+    outputs: "torch.Tensor" = model.get_output_embeddings()(inputs)
     assert outputs.dtype == torch.float32
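
Illustration (not part of the patch): the test now goes through model.get_output_embeddings() because the upcast hook is attached via that generic accessor, which the patched AutoModelForCausalLMWithValueHead also exposes. Below is a minimal, hypothetical sketch of the same pattern; TinyLM and fp32_post_hook are made-up names, the hook body mirrors _fp32_forward_post_hook above, and register_forward_hook is standard PyTorch behavior. bfloat16 is used only so the sketch runs on CPU; the same logic applies to float16.

# Minimal standalone sketch (assumptions: TinyLM is a toy stand-in for a causal LM;
# get_output_embeddings() follows the same contract the patch relies on).
import torch


class TinyLM(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.backbone = torch.nn.Linear(16, 16, dtype=torch.bfloat16)
        self.lm_head = torch.nn.Linear(16, 32, dtype=torch.bfloat16)

    def get_output_embeddings(self) -> torch.nn.Module:
        # Return the output projection layer, as PreTrainedModel does.
        return self.lm_head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lm_head(self.backbone(x))


def fp32_post_hook(module: torch.nn.Module, args: tuple, output: torch.Tensor) -> torch.Tensor:
    # Returning a tensor from a forward hook replaces the module's output.
    return output.to(torch.float32)


model = TinyLM()
output_layer = model.get_output_embeddings()
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
    output_layer.register_forward_hook(fp32_post_hook)

logits = model(torch.randn(1, 16, dtype=torch.bfloat16))
assert logits.dtype == torch.float32  # lm_head output was upcast by the hook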