From a851056229f37391023627180b5712ed64ae3528 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Sat, 18 May 2024 22:02:42 +0800
Subject: [PATCH] improve data process logger

---
 src/llamafactory/data/aligner.py       | 2 +-
 src/llamafactory/data/preprocess.py    | 5 +++++
 src/llamafactory/train/kto/workflow.py | 2 +-
 3 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/llamafactory/data/aligner.py b/src/llamafactory/data/aligner.py
index 6a74a843..2a382c60 100644
--- a/src/llamafactory/data/aligner.py
+++ b/src/llamafactory/data/aligner.py
@@ -149,7 +149,7 @@ def convert_sharegpt(
                 chosen[dataset_attr.role_tag] not in accept_tags[-1]
                 or rejected[dataset_attr.role_tag] not in accept_tags[-1]
             ):
-                logger.warning("Invalid role tag in {}.".format(messages))
+                logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
                 broken_data = True
 
             prompt = aligned_messages
diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py
index a6fb0ddc..557678e6 100644
--- a/src/llamafactory/data/preprocess.py
+++ b/src/llamafactory/data/preprocess.py
@@ -77,6 +77,7 @@ def preprocess_supervised_dataset(
 
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
+            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue
 
         if processor is not None:
@@ -129,6 +130,7 @@ def preprocess_packed_supervised_dataset(
     input_ids, labels = [], []
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
+            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue
 
         messages = examples["prompt"][i] + examples["response"][i]
@@ -178,6 +180,7 @@ def preprocess_unsupervised_dataset(
 
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1:
+            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue
 
         if processor is not None:
@@ -224,6 +227,7 @@ def preprocess_pairwise_dataset(
 
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
+            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue
 
         if processor is not None:
@@ -285,6 +289,7 @@ def preprocess_kto_dataset(
 
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
+            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue
 
         if processor is not None:
diff --git a/src/llamafactory/train/kto/workflow.py b/src/llamafactory/train/kto/workflow.py
index 615fdb62..26dc770c 100644
--- a/src/llamafactory/train/kto/workflow.py
+++ b/src/llamafactory/train/kto/workflow.py
@@ -62,7 +62,7 @@ def run_kto(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "train/rewards/chosen"])
 
     # Evaluation
     if training_args.do_eval:
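
Note (not part of the patch): every preprocess.py hunk applies the same pattern -- log the malformed row before `continue`-ing past it, so silently dropped examples become visible in the data-processing logs. A minimal standalone sketch of that pattern, using a hypothetical `examples` batch in the prompt/response layout these functions expect (a valid row has an odd number of prompt turns and, for supervised data, exactly one response):

    import logging

    logger = logging.getLogger(__name__)

    # Hypothetical batch: row 0 is valid, row 1 has an empty prompt and
    # empty response, so it should be dropped with a warning.
    examples = {
        "prompt": [[{"role": "user", "content": "hi"}], []],
        "response": [[{"role": "assistant", "content": "hello"}], []],
    }

    for i in range(len(examples["prompt"])):
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
            # Same idea as the patch: surface the dropped row instead of skipping silently.
            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
            continue
        # ... tokenize and emit the valid row here ...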