From b355f6cac99592b66890ccc04e77a9993de0447d Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Fri, 3 Nov 2023 00:34:26 +0800
Subject: [PATCH] fix bug in data loader, support dpo eval

---
 src/llmtuner/dsets/loader.py       | 4 ++--
 src/llmtuner/tuner/dpo/trainer.py  | 1 +
 src/llmtuner/tuner/dpo/workflow.py | 6 ++++++
 3 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/llmtuner/dsets/loader.py b/src/llmtuner/dsets/loader.py
index 0a4a17f6..834ef733 100644
--- a/src/llmtuner/dsets/loader.py
+++ b/src/llmtuner/dsets/loader.py
@@ -82,8 +82,8 @@ def get_dataset(
                 assistant_role = msg_list[idx + 1][dataset_attr.role]
             else:
                 if (
-                    msg_list[idx][dataset_attr.query] != user_role
-                    or msg_list[idx+1][dataset_attr.query] != assistant_role
+                    msg_list[idx][dataset_attr.role] != user_role
+                    or msg_list[idx+1][dataset_attr.role] != assistant_role
                 ):
                     raise ValueError("Only accepts conversation in u/a/u/a/u/a order.")
             msg_pairs.append((msg_list[idx][dataset_attr.content], msg_list[idx + 1][dataset_attr.content]))
diff --git a/src/llmtuner/tuner/dpo/trainer.py b/src/llmtuner/tuner/dpo/trainer.py
index bde02327..c2b0b581 100644
--- a/src/llmtuner/tuner/dpo/trainer.py
+++ b/src/llmtuner/tuner/dpo/trainer.py
@@ -30,6 +30,7 @@ class CustomDPOTrainer(DPOTrainer):
         self.is_encoder_decoder = model.config.is_encoder_decoder
         self.ref_model = ref_model
         self.use_dpo_data_collator = True # hack to avoid warning
+        self.generate_during_eval = False # disable at evaluation
         self.label_pad_token_id = IGNORE_INDEX
         self.padding_value = 0
         self.beta = beta
diff --git a/src/llmtuner/tuner/dpo/workflow.py b/src/llmtuner/tuner/dpo/workflow.py
index 545485c6..6e16dd18 100644
--- a/src/llmtuner/tuner/dpo/workflow.py
+++ b/src/llmtuner/tuner/dpo/workflow.py
@@ -58,3 +58,9 @@ def run_dpo(
         trainer.save_model()
         if trainer.is_world_process_zero() and model_args.plot_loss:
             plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+
+    # Evaluation
+    if training_args.do_eval:
+        metrics = trainer.evaluate(metric_key_prefix="eval")
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)