fix plots

This commit is contained in:
hiyouga 2024-03-31 19:43:48 +08:00
parent 68aaa4904b
commit 5907216a1c
4 changed files with 5 additions and 6 deletions

View File

@ -142,11 +142,10 @@ class CustomDPOTrainer(DPOTrainer):
reference_chosen_logps,
reference_rejected_logps,
)
batch_loss = losses.mean()
if self.ftx_gamma > 1e-6:
batch_size = batch["input_ids"].size(0) // 2
chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
batch_loss += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels).mean()
losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels)
reward_accuracies = (chosen_rewards > rejected_rewards).float()
@ -160,4 +159,4 @@ class CustomDPOTrainer(DPOTrainer):
metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().cpu().mean()
metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().cpu().mean()
return batch_loss, metrics
return losses.mean(), metrics

View File

@ -63,7 +63,7 @@ def run_dpo(
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "accuracy"])
plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/accuracies"])
# Evaluation
if training_args.do_eval:

View File

@ -56,7 +56,7 @@ def run_orpo(
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "accuracy"])
plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/accuracies", "sft_loss"])
# Evaluation
if training_args.do_eval:

View File

@ -55,7 +55,7 @@ def run_rm(
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])
# Evaluation
if training_args.do_eval: