remove filter in preprocess

This commit is contained in:
hiyouga 2023-10-23 23:46:02 +08:00
parent 7de7174ce3
commit 2caf91f824
1 changed files with 18 additions and 10 deletions

View File

@ -62,8 +62,10 @@ def preprocess_dataset(
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for query, response, history, system in construct_example(examples):
input_ids, labels = [], []
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
continue
input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, query, response, history, system
)):
@ -106,6 +108,9 @@ def preprocess_dataset(
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
input_ids, labels = [], []
for query, response, history, system in construct_example(examples):
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
continue
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, query, response, history, system
)):
@ -139,6 +144,9 @@ def preprocess_dataset(
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for query, response, history, system in construct_example(examples):
if not (isinstance(query, str) and query != ""):
continue
input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)
if template.efficient_eos:
@ -158,7 +166,10 @@ def preprocess_dataset(
def preprocess_pairwise_dataset(examples):
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
for query, response, history, system in construct_example(examples):
for query, response, history, system in construct_example(examples):
if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1):
continue
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
_, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
@ -203,19 +214,15 @@ def preprocess_dataset(
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
if stage == "pt":
dataset = dataset.filter(lambda example: example["prompt"])
preprocess_func = preprocess_pretrain_dataset
print_function = print_unsupervised_dataset_example
elif stage == "sft" and not training_args.predict_with_generate:
dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
print_function = print_supervised_dataset_example
elif stage == "rm":
dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
preprocess_func = preprocess_pairwise_dataset
print_function = print_pairwise_dataset_example
else:
dataset = dataset.filter(lambda example: example["prompt"])
preprocess_func = preprocess_unsupervised_dataset
print_function = print_unsupervised_dataset_example
@ -235,9 +242,10 @@ def preprocess_dataset(
**kwargs
)
try:
print_function(next(iter(dataset)))
except StopIteration:
raise ValueError("Empty dataset!")
if training_args.should_log:
try:
print_function(next(iter(dataset)))
except StopIteration:
raise ValueError("Empty dataset!")
return dataset