fix layer norm name in PPO

This commit is contained in:
hiyouga 2023-06-02 17:30:01 +08:00
parent bd565af370
commit e3aaef7d4a
1 changed files with 1 additions and 1 deletions

View File

@ -41,7 +41,7 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def
def cast_layernorm_dtype(
model: AutoModelForCausalLMWithValueHead,
layer_norm_names: List[str] = ["layernorm"], # for chatglm setting
layer_norm_names: List[str] = ["norm", "ln_f"], # for LLaMA and BLOOM setting
layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]: