From 166c837b95d42513f7b977b189822b5c7980606d Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sun, 28 May 2023 21:48:33 +0800 Subject: [PATCH] tiny fix --- src/utils/ppo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils/ppo.py b/src/utils/ppo.py index 8a068876..7a69c43b 100644 --- a/src/utils/ppo.py +++ b/src/utils/ppo.py @@ -157,8 +157,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer): stats = self.step(queries, responses, rewards) - loss_meter.update(stats["ppo/loss/total"]) - reward_meter.update(torch.tensor(rewards).sum().item(), n=len(rewards)) + loss_meter.update(stats["ppo/loss/total"], n=len(rewards)) + reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards)) if steps_trained == len_dataloader: dataiter = iter(self.dataloader)