fix ppo trainer #551

This commit is contained in:
hiyouga 2023-08-20 14:07:11 +08:00
parent 290be836b7
commit 0676497104
2 changed files with 27 additions and 2 deletions

View File

@ -12,7 +12,7 @@ from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
from llmtuner.tuner.core.trainer import PeftTrainer
from llmtuner.tuner.ppo.utils import replace_model
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
@ -152,8 +152,10 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
if length_sampler is not None:
generation_kwargs["max_new_tokens"] = length_sampler()
self.model, layer_norm_params = cast_layernorm_dtype(self.model, self.compute_dtype)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)
self.model, _ = cast_layernorm_dtype(self.model, self.compute_dtype, layer_norm_params)
# Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
# Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273

View File

@ -1,4 +1,7 @@
from typing import TYPE_CHECKING, Literal
import torch
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
from llmtuner.extras.constants import LAYERNORM_NAMES
if TYPE_CHECKING:
from trl import AutoModelForCausalLMWithValueHead
@ -15,3 +18,23 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
"summary.weight": getattr(model, "{}_head_weight".format(target)),
"summary.bias": getattr(model, "{}_head_bias".format(target))
})
def cast_layernorm_dtype(
model: "AutoModelForCausalLMWithValueHead",
compute_dtype: torch.dtype,
layer_norm_params: Optional[Dict[str, torch.Tensor]] = None,
layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
) -> Tuple["AutoModelForCausalLMWithValueHead", Dict[str, torch.Tensor]]:
layer_norm_state_dict = {}
for name, param in model.named_parameters():
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
if layer_norm_params is None:
layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
param.data = param.data.to(compute_dtype)
else:
param.data = layer_norm_params[name] # restore float32 weights
return model, layer_norm_state_dict