Update supervised.py
This commit is contained in:
parent
788e8232fc
commit
c09ad8bab3
|
@ -179,15 +179,16 @@ def preprocess_packed_supervised_dataset(
|
||||||
packed_input_ids += batch_input_ids[index]
|
packed_input_ids += batch_input_ids[index]
|
||||||
packed_labels += batch_labels[index]
|
packed_labels += batch_labels[index]
|
||||||
|
|
||||||
if len(packed_input_ids) <= data_args.cutoff_len:
|
if len(packed_input_ids) < data_args.cutoff_len:
|
||||||
pad_length = data_args.cutoff_len - len(packed_input_ids)
|
pad_length = data_args.cutoff_len - len(packed_input_ids)
|
||||||
packed_input_ids += [tokenizer.pad_token_id] * pad_length
|
packed_input_ids += [tokenizer.pad_token_id] * pad_length
|
||||||
packed_labels += [IGNORE_INDEX] * pad_length
|
packed_labels += [IGNORE_INDEX] * pad_length
|
||||||
else:
|
|
||||||
raise ValueError("The length of packed example exceeds the cutoff length.")
|
if len(packed_input_ids) != data_args.cutoff_len:
|
||||||
|
raise ValueError("The length of packed example should be identical to the cutoff length.")
|
||||||
|
|
||||||
model_inputs["input_ids"].append(packed_input_ids)
|
model_inputs["input_ids"].append(packed_input_ids)
|
||||||
model_inputs["attention_mask"].append([1] * len(packed_input_ids))
|
model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
|
||||||
model_inputs["labels"].append(packed_labels)
|
model_inputs["labels"].append(packed_labels)
|
||||||
|
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
Loading…
Reference in New Issue