This commit is contained in:
hiyouga 2023-08-01 18:43:53 +08:00
parent e6a3894b99
commit e3f80774c4
2 changed files with 12 additions and 12 deletions

View File

@ -2,7 +2,7 @@ import os
import hashlib
from typing import TYPE_CHECKING, List, Optional
from datasets import concatenate_datasets, interleave_datasets, load_dataset
from datasets import Value, concatenate_datasets, interleave_datasets, load_dataset
from llmtuner.extras.logging import get_logger
@ -93,7 +93,11 @@ def get_dataset(
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
if dataset_attr.source_prefix: # add prefix
dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix})
features = None
if data_args.streaming:
features = dataset.features
features["prefix"] = Value(dtype="string", id=None)
dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix}, features=features)
all_datasets.append(dataset)

View File

@ -18,7 +18,7 @@ def preprocess_dataset(
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"]
) -> "Dataset":
column_names = list(dataset.column_names or [])
column_names = list(dataset.column_names)
template = get_template(data_args.template)
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
@ -143,15 +143,19 @@ def preprocess_dataset(
if stage == "pt":
dataset = dataset.filter(lambda example: example["prompt"])
preprocess_function = preprocess_pretrain_dataset
print_function = print_unsupervised_dataset_example
elif stage == "sft" and not training_args.predict_with_generate:
dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
preprocess_function = preprocess_supervised_dataset
print_function = print_supervised_dataset_example
elif stage == "rm":
dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
preprocess_function = preprocess_pairwise_dataset
print_function = print_pairwise_dataset_example
else:
dataset = dataset.filter(lambda example: example["prompt"])
preprocess_function = preprocess_unsupervised_dataset
print_function = print_unsupervised_dataset_example
with training_args.main_process_first(desc="dataset map pre-processing"):
kwargs = {}
@ -172,13 +176,5 @@ def preprocess_dataset(
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
if stage == "pt":
print_unsupervised_dataset_example(next(iter(dataset)))
elif stage == "sft":
print_supervised_dataset_example(next(iter(dataset)))
elif stage == "rm":
print_pairwise_dataset_example(next(iter(dataset)))
elif stage == "ppo":
print_unsupervised_dataset_example(next(iter(dataset)))
print_function(next(iter(dataset)))
return dataset