From c13ae2df19ed4cdc849bef55d04225e1a98c19b5 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Tue, 2 Jul 2024 22:32:05 +0800
Subject: [PATCH] upcast logits

---
 src/llamafactory/train/ppo/trainer.py | 2 +-
 src/llamafactory/train/rm/trainer.py  | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index 2f9978a5..1c401938 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -407,7 +407,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             values = torch.transpose(values, 0, 1)
 
         rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
-        return rewards.to(torch.float32).detach().cpu()  # use fp32 type
+        return rewards.float().detach()  # use fp32 type
 
     @PPODecorators.empty_device_cache()
     def batched_forward_pass(
diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py
index f7160cfc..267e88e2 100644
--- a/src/llamafactory/train/rm/trainer.py
+++ b/src/llamafactory/train/rm/trainer.py
@@ -99,7 +99,7 @@ class PairwiseTrainer(Trainer):
         chosen_scores = chosen_rewards.gather(dim=-1, index=(chosen_masks.sum(dim=-1, keepdim=True) - 1))
         rejected_scores = rejected_rewards.gather(dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1))
         chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()
-        loss = -torch.nn.functional.logsigmoid(chosen_scores - rejected_scores).mean()
+        loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
         if return_outputs:
             return loss, (loss, chosen_scores, rejected_scores)
         else:
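
Why the upcast matters, as a minimal standalone sketch (illustrative values only, not part of the patch): when the trainer runs in bf16, the score margin, the logsigmoid, and the batch mean are all rounded to bf16's roughly three significant digits, whereas upcasting the scores to fp32 before the arithmetic keeps the loss in full single precision. The same reasoning applies to the PPO change, which additionally drops the .cpu() call so the fp32 rewards stay on device.

import torch
import torch.nn.functional as F

# Reward-model scores as they might come out of a bf16 forward pass.
chosen = torch.tensor([2.53125], dtype=torch.bfloat16)
rejected = torch.tensor([2.40625], dtype=torch.bfloat16)

# Pre-patch style: every step of the loss runs in bf16, so the
# logsigmoid output is rounded to bf16 resolution (~2^-9 near 0.6).
loss_bf16 = -F.logsigmoid(chosen - rejected).mean()

# Post-patch style: upcast before the arithmetic so the margin,
# the logsigmoid, and the reduction are all computed in fp32.
loss_fp32 = -F.logsigmoid(chosen.float() - rejected.float()).mean()

print(loss_bf16.item())  # ~0.6328125 (bf16-rounded)
print(loss_fp32.item())  # ~0.6326    (fp32)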