From 8632bff81110b202919e27b33294898f16638c9d Mon Sep 17 00:00:00 2001 From: hiyouga Date: Thu, 14 Sep 2023 18:37:34 +0800 Subject: [PATCH] fix #896 --- src/llmtuner/dsets/preprocess.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py index 393366e6..c42b2047 100644 --- a/src/llmtuner/dsets/preprocess.py +++ b/src/llmtuner/dsets/preprocess.py @@ -140,9 +140,9 @@ def preprocess_dataset( print("input_ids:\n{}".format(example["input_ids"])) print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print("label_ids:\n{}".format(example["labels"])) - print("labels:\n{}".format(tokenizer.decode([ - token_id if token_id != IGNORE_INDEX else tokenizer.pad_token_id for token_id in example["labels"] - ], skip_special_tokens=False))) + print("labels:\n{}".format( + tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False) + )) def print_pairwise_dataset_example(example): print("prompt_ids:\n{}".format(example["prompt_ids"]))