update patcher
This commit is contained in:
parent
42e69a3c63
commit
3b040e8e0f
|
@ -78,9 +78,7 @@ def _fp32_forward_post_hook(
|
|||
return output.to(torch.float32)
|
||||
|
||||
|
||||
def prepare_model_for_training(
|
||||
model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
|
||||
) -> None:
|
||||
def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
|
||||
r"""
|
||||
Includes:
|
||||
(1) cast the layernorm in fp32
|
||||
|
@ -104,8 +102,8 @@ def prepare_model_for_training(
|
|||
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
|
||||
logger.info("Gradient checkpointing enabled.")
|
||||
|
||||
if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
|
||||
logger.info("Upcasting lm_head outputs in float32.")
|
||||
output_layer = getattr(model, output_layer_name)
|
||||
if model_args.upcast_lmhead_output:
|
||||
output_layer = model.get_output_embeddings()
|
||||
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
|
||||
logger.info("Upcasting lm_head outputs in float32.")
|
||||
output_layer.register_forward_hook(_fp32_forward_post_hook)
|
||||
|
|
|
@ -152,6 +152,10 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
|
|||
if isinstance(self.pretrained_model, PreTrainedModel):
|
||||
return self.pretrained_model.get_input_embeddings()
|
||||
|
||||
def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
|
||||
if isinstance(self.pretrained_model, PreTrainedModel):
|
||||
return self.pretrained_model.get_output_embeddings()
|
||||
|
||||
def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
|
||||
if isinstance(self.pretrained_model, PeftModel):
|
||||
self.pretrained_model.create_or_update_model_card(output_dir)
|
||||
|
@ -160,4 +164,5 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
|
|||
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
|
||||
setattr(model, "tie_weights", MethodType(tie_weights, model))
|
||||
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
|
||||
setattr(model, "get_output_embeddings", MethodType(get_output_embeddings, model))
|
||||
setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))
|
||||
|
|
|
@ -70,5 +70,5 @@ def test_upcast_lmhead_output():
|
|||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||
inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device())
|
||||
outputs: "torch.Tensor" = model.lm_head(inputs)
|
||||
outputs: "torch.Tensor" = model.get_output_embeddings()(inputs)
|
||||
assert outputs.dtype == torch.float32
|
||||
|
|
Loading…
Reference in New Issue