improve data process logger

This commit is contained in:
hiyouga 2024-05-18 22:02:42 +08:00
parent ca48f90f1e
commit a851056229
3 changed files with 7 additions and 2 deletions

View File

@ -149,7 +149,7 @@ def convert_sharegpt(
chosen[dataset_attr.role_tag] not in accept_tags[-1]
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning("Invalid role tag in {}.".format(messages))
logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
broken_data = True
prompt = aligned_messages

View File

@ -77,6 +77,7 @@ def preprocess_supervised_dataset(
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
if processor is not None:
@ -129,6 +130,7 @@ def preprocess_packed_supervised_dataset(
input_ids, labels = [], []
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
messages = examples["prompt"][i] + examples["response"][i]
@ -178,6 +180,7 @@ def preprocess_unsupervised_dataset(
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
if processor is not None:
@ -224,6 +227,7 @@ def preprocess_pairwise_dataset(
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
if processor is not None:
@ -285,6 +289,7 @@ def preprocess_kto_dataset(
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
if processor is not None:

View File

@ -62,7 +62,7 @@ def run_kto(
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "train/rewards/chosen"])
# Evaluation
if training_args.do_eval: