fix layer norm name in PPO
This commit is contained in:
parent
bd565af370
commit
e3aaef7d4a
|
@@ -41,7 +41,7 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def
|
||||||
|
|
||||||
def cast_layernorm_dtype(
|
def cast_layernorm_dtype(
|
||||||
model: AutoModelForCausalLMWithValueHead,
|
model: AutoModelForCausalLMWithValueHead,
|
||||||
layer_norm_names: List[str] = ["layernorm"], # for chatglm setting
|
layer_norm_names: List[str] = ["norm", "ln_f"], # for LLaMA and BLOOM setting
|
||||||
layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
|
layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
|
||||||
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
|
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue