From c09ad8bab38bc2f151da3a924eba225111af2481 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Fri, 7 Jun 2024 03:42:08 +0800 Subject: [PATCH] Update supervised.py --- src/llamafactory/data/processors/supervised.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index 502b591c..a340a1ab 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -179,15 +179,16 @@ def preprocess_packed_supervised_dataset( packed_input_ids += batch_input_ids[index] packed_labels += batch_labels[index] - if len(packed_input_ids) <= data_args.cutoff_len: + if len(packed_input_ids) < data_args.cutoff_len: pad_length = data_args.cutoff_len - len(packed_input_ids) packed_input_ids += [tokenizer.pad_token_id] * pad_length packed_labels += [IGNORE_INDEX] * pad_length - else: - raise ValueError("The length of packed example exceeds the cutoff length.") + + if len(packed_input_ids) != data_args.cutoff_len: + raise ValueError("The length of packed example should be identical to the cutoff length.") model_inputs["input_ids"].append(packed_input_ids) - model_inputs["attention_mask"].append([1] * len(packed_input_ids)) + model_inputs["attention_mask"].append([1] * data_args.cutoff_len) model_inputs["labels"].append(packed_labels) return model_inputs