From 5907216a1cc7a75a43d681ede410c2fba7fb7b92 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Sun, 31 Mar 2024 19:43:48 +0800
Subject: [PATCH] fix plots

---
 src/llmtuner/train/dpo/trainer.py   | 5 ++---
 src/llmtuner/train/dpo/workflow.py  | 2 +-
 src/llmtuner/train/orpo/workflow.py | 2 +-
 src/llmtuner/train/rm/workflow.py   | 2 +-
 4 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/src/llmtuner/train/dpo/trainer.py b/src/llmtuner/train/dpo/trainer.py
index c7e385da..7582e16f 100644
--- a/src/llmtuner/train/dpo/trainer.py
+++ b/src/llmtuner/train/dpo/trainer.py
@@ -142,11 +142,10 @@ class CustomDPOTrainer(DPOTrainer):
             reference_chosen_logps,
             reference_rejected_logps,
         )
-        batch_loss = losses.mean()
         if self.ftx_gamma > 1e-6:
             batch_size = batch["input_ids"].size(0) // 2
             chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
-            batch_loss += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels).mean()
+            losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels)

         reward_accuracies = (chosen_rewards > rejected_rewards).float()

@@ -160,4 +159,4 @@ class CustomDPOTrainer(DPOTrainer):
         metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().cpu().mean()
         metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().cpu().mean()

-        return batch_loss, metrics
+        return losses.mean(), metrics
diff --git a/src/llmtuner/train/dpo/workflow.py b/src/llmtuner/train/dpo/workflow.py
index 4a1e867e..929dd029 100644
--- a/src/llmtuner/train/dpo/workflow.py
+++ b/src/llmtuner/train/dpo/workflow.py
@@ -63,7 +63,7 @@ def run_dpo(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "accuracy"])
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/accuracies"])

     # Evaluation
     if training_args.do_eval:
diff --git a/src/llmtuner/train/orpo/workflow.py b/src/llmtuner/train/orpo/workflow.py
index 1d549d28..5a2fd36c 100644
--- a/src/llmtuner/train/orpo/workflow.py
+++ b/src/llmtuner/train/orpo/workflow.py
@@ -56,7 +56,7 @@ def run_orpo(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "accuracy"])
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/accuracies", "sft_loss"])

     # Evaluation
     if training_args.do_eval:
diff --git a/src/llmtuner/train/rm/workflow.py b/src/llmtuner/train/rm/workflow.py
index f260f82e..42bf1ce6 100644
--- a/src/llmtuner/train/rm/workflow.py
+++ b/src/llmtuner/train/rm/workflow.py
@@ -55,7 +55,7 @@ def run_rm(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])

     # Evaluation
     if training_args.do_eval:
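
Note (not part of the patch): a minimal PyTorch sketch of the loss-aggregation change in CustomDPOTrainer, assuming only what the hunks above show. The ftx_gamma-weighted SFT term is now added to the per-example DPO losses and a single .mean() is taken at return, instead of folding an averaged SFT term into an already-averaged scalar. All tensor values below are invented for illustration.

import torch

ftx_gamma = 0.1  # weight of the auxiliary SFT term, as in the patch above
dpo_losses = torch.tensor([0.7, 0.4, 0.9, 0.5])  # per-example DPO losses (invented values)
sft_losses = torch.tensor([2.1, 1.8, 2.5, 1.9])  # per-example SFT losses on the chosen responses (invented values)

# Old flow: average first, then add the averaged SFT term; only a scalar survives.
old_batch_loss = dpo_losses.mean() + ftx_gamma * sft_losses.mean()

# New flow: add the SFT term element-wise, keep per-example losses, average once at return.
losses = dpo_losses + ftx_gamma * sft_losses
new_batch_loss = losses.mean()

# The scalar batch loss is unchanged; the difference is that `losses` stays per-example
# until the final .mean(), mirroring `return losses.mean(), metrics` in the patch.
assert torch.allclose(old_batch_loss, new_batch_loss)
print(new_batch_loss.item())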