fix #296
This commit is contained in:
parent
e6a3894b99
commit
e3f80774c4
|
@ -2,7 +2,7 @@ import os
|
|||
import hashlib
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from datasets import concatenate_datasets, interleave_datasets, load_dataset
|
||||
from datasets import Value, concatenate_datasets, interleave_datasets, load_dataset
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
|
@ -93,7 +93,11 @@ def get_dataset(
|
|||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
|
||||
|
||||
if dataset_attr.source_prefix: # add prefix
|
||||
dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix})
|
||||
features = None
|
||||
if data_args.streaming:
|
||||
features = dataset.features
|
||||
features["prefix"] = Value(dtype="string", id=None)
|
||||
dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix}, features=features)
|
||||
|
||||
all_datasets.append(dataset)
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ def preprocess_dataset(
|
|||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo"]
|
||||
) -> "Dataset":
|
||||
column_names = list(dataset.column_names or [])
|
||||
column_names = list(dataset.column_names)
|
||||
template = get_template(data_args.template)
|
||||
|
||||
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
|
||||
|
@ -143,15 +143,19 @@ def preprocess_dataset(
|
|||
if stage == "pt":
|
||||
dataset = dataset.filter(lambda example: example["prompt"])
|
||||
preprocess_function = preprocess_pretrain_dataset
|
||||
print_function = print_unsupervised_dataset_example
|
||||
elif stage == "sft" and not training_args.predict_with_generate:
|
||||
dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
|
||||
preprocess_function = preprocess_supervised_dataset
|
||||
print_function = print_supervised_dataset_example
|
||||
elif stage == "rm":
|
||||
dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
|
||||
preprocess_function = preprocess_pairwise_dataset
|
||||
print_function = print_pairwise_dataset_example
|
||||
else:
|
||||
dataset = dataset.filter(lambda example: example["prompt"])
|
||||
preprocess_function = preprocess_unsupervised_dataset
|
||||
print_function = print_unsupervised_dataset_example
|
||||
|
||||
with training_args.main_process_first(desc="dataset map pre-processing"):
|
||||
kwargs = {}
|
||||
|
@ -172,13 +176,5 @@ def preprocess_dataset(
|
|||
if data_args.streaming:
|
||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
|
||||
|
||||
if stage == "pt":
|
||||
print_unsupervised_dataset_example(next(iter(dataset)))
|
||||
elif stage == "sft":
|
||||
print_supervised_dataset_example(next(iter(dataset)))
|
||||
elif stage == "rm":
|
||||
print_pairwise_dataset_example(next(iter(dataset)))
|
||||
elif stage == "ppo":
|
||||
print_unsupervised_dataset_example(next(iter(dataset)))
|
||||
|
||||
print_function(next(iter(dataset)))
|
||||
return dataset
|
||||
|
|
Loading…
Reference in New Issue