diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py index 9358adf0..cd34c531 100644 --- a/src/llmtuner/tuner/ppo/trainer.py +++ b/src/llmtuner/tuner/ppo/trainer.py @@ -12,7 +12,7 @@ from trl.core import LengthSampler, PPODecorators, logprobs_from_logits from llmtuner.extras.logging import get_logger from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor from llmtuner.tuner.core.trainer import PeftTrainer -from llmtuner.tuner.ppo.utils import replace_model +from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments @@ -152,8 +152,10 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): if length_sampler is not None: generation_kwargs["max_new_tokens"] = length_sampler() + self.model, layer_norm_params = cast_layernorm_dtype(self.model, self.compute_dtype) unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs) + self.model, _ = cast_layernorm_dtype(self.model, self.compute_dtype, layer_norm_params) # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273 diff --git a/src/llmtuner/tuner/ppo/utils.py b/src/llmtuner/tuner/ppo/utils.py index 3215ee8a..1b48c8c0 100644 --- a/src/llmtuner/tuner/ppo/utils.py +++ b/src/llmtuner/tuner/ppo/utils.py @@ -1,4 +1,7 @@ -from typing import TYPE_CHECKING, Literal +import torch +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple + +from llmtuner.extras.constants import LAYERNORM_NAMES if TYPE_CHECKING: from trl import AutoModelForCausalLMWithValueHead @@ -15,3 +18,23 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d "summary.weight": getattr(model, "{}_head_weight".format(target)), "summary.bias": getattr(model, "{}_head_bias".format(target)) }) + + +def cast_layernorm_dtype( + model: "AutoModelForCausalLMWithValueHead", + compute_dtype: torch.dtype, + layer_norm_params: Optional[Dict[str, torch.Tensor]] = None, + layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES +) -> Tuple["AutoModelForCausalLMWithValueHead", Dict[str, torch.Tensor]]: + + layer_norm_state_dict = {} + + for name, param in model.named_parameters(): + if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): + if layer_norm_params is None: + layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability + param.data = param.data.to(compute_dtype) + else: + param.data = layer_norm_params[name] # restore float32 weights + + return model, layer_norm_state_dict