From e3aaef7d4a37e4aa388a9158c382db8239843a5e Mon Sep 17 00:00:00 2001 From: hiyouga Date: Fri, 2 Jun 2023 17:30:01 +0800 Subject: [PATCH] fix layer norm name in PPO --- src/utils/ppo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils/ppo.py b/src/utils/ppo.py index b782d1e6..10c80d22 100644 --- a/src/utils/ppo.py +++ b/src/utils/ppo.py @@ -41,7 +41,7 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def def cast_layernorm_dtype( model: AutoModelForCausalLMWithValueHead, - layer_norm_names: List[str] = ["layernorm"], # for chatglm setting + layer_norm_names: List[str] = ["norm", "ln_f"], # for LLaMA and BLOOM setting layer_norm_params: Optional[Dict[str, torch.Tensor]] = None ) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]: