diff --git a/src/llmtuner/dsets/loader.py b/src/llmtuner/dsets/loader.py index 8d0bfe59..a51b9024 100644 --- a/src/llmtuner/dsets/loader.py +++ b/src/llmtuner/dsets/loader.py @@ -2,7 +2,7 @@ import os import hashlib from typing import TYPE_CHECKING, List, Optional -from datasets import concatenate_datasets, interleave_datasets, load_dataset +from datasets import Value, concatenate_datasets, interleave_datasets, load_dataset from llmtuner.extras.logging import get_logger @@ -93,7 +93,11 @@ def get_dataset( dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name) if dataset_attr.source_prefix: # add prefix - dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix}) + features = None + if data_args.streaming: + features = dataset.features + features["prefix"] = Value(dtype="string", id=None) + dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix}, features=features) all_datasets.append(dataset) diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py index 93c854e0..10c76f2b 100644 --- a/src/llmtuner/dsets/preprocess.py +++ b/src/llmtuner/dsets/preprocess.py @@ -18,7 +18,7 @@ def preprocess_dataset( training_args: "Seq2SeqTrainingArguments", stage: Literal["pt", "sft", "rm", "ppo"] ) -> "Dataset": - column_names = list(dataset.column_names or []) + column_names = list(dataset.column_names) template = get_template(data_args.template) def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]: @@ -143,15 +143,19 @@ def preprocess_dataset( if stage == "pt": dataset = dataset.filter(lambda example: example["prompt"]) preprocess_function = preprocess_pretrain_dataset + print_function = print_unsupervised_dataset_example elif stage == "sft" and not training_args.predict_with_generate: dataset = dataset.filter(lambda example: example["prompt"] and example["response"]) preprocess_function = preprocess_supervised_dataset + print_function = print_supervised_dataset_example elif stage == "rm": dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1) preprocess_function = preprocess_pairwise_dataset + print_function = print_pairwise_dataset_example else: dataset = dataset.filter(lambda example: example["prompt"]) preprocess_function = preprocess_unsupervised_dataset + print_function = print_unsupervised_dataset_example with training_args.main_process_first(desc="dataset map pre-processing"): kwargs = {} @@ -172,13 +176,5 @@ def preprocess_dataset( if data_args.streaming: dataset = dataset.shuffle(buffer_size=data_args.buffer_size) - if stage == "pt": - print_unsupervised_dataset_example(next(iter(dataset))) - elif stage == "sft": - print_supervised_dataset_example(next(iter(dataset))) - elif stage == "rm": - print_pairwise_dataset_example(next(iter(dataset))) - elif stage == "ppo": - print_unsupervised_dataset_example(next(iter(dataset))) - + print_function(next(iter(dataset))) return dataset