diff --git a/src/utils/ppo.py b/src/utils/ppo.py index 8a068876..7a69c43b 100644 --- a/src/utils/ppo.py +++ b/src/utils/ppo.py @@ -157,8 +157,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer): stats = self.step(queries, responses, rewards) - loss_meter.update(stats["ppo/loss/total"]) - reward_meter.update(torch.tensor(rewards).sum().item(), n=len(rewards)) + loss_meter.update(stats["ppo/loss/total"], n=len(rewards)) + reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards)) if steps_trained == len_dataloader: dataiter = iter(self.dataloader)