Update supervised.py

This commit is contained in:
hoshi-hiyouga 2024-06-07 03:42:08 +08:00 committed by GitHub
parent 788e8232fc
commit c09ad8bab3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 5 additions and 4 deletions

View File

@ -179,15 +179,16 @@ def preprocess_packed_supervised_dataset(
packed_input_ids += batch_input_ids[index] packed_input_ids += batch_input_ids[index]
packed_labels += batch_labels[index] packed_labels += batch_labels[index]
if len(packed_input_ids) <= data_args.cutoff_len: if len(packed_input_ids) < data_args.cutoff_len:
pad_length = data_args.cutoff_len - len(packed_input_ids) pad_length = data_args.cutoff_len - len(packed_input_ids)
packed_input_ids += [tokenizer.pad_token_id] * pad_length packed_input_ids += [tokenizer.pad_token_id] * pad_length
packed_labels += [IGNORE_INDEX] * pad_length packed_labels += [IGNORE_INDEX] * pad_length
else:
raise ValueError("The length of packed example exceeds the cutoff length.") if len(packed_input_ids) != data_args.cutoff_len:
raise ValueError("The length of packed example should be identical to the cutoff length.")
model_inputs["input_ids"].append(packed_input_ids) model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append([1] * len(packed_input_ids)) model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
model_inputs["labels"].append(packed_labels) model_inputs["labels"].append(packed_labels)
return model_inputs return model_inputs