fix #296
This commit is contained in:
parent
e6a3894b99
commit
e3f80774c4
|
@ -2,7 +2,7 @@ import os
|
||||||
import hashlib
|
import hashlib
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
from datasets import concatenate_datasets, interleave_datasets, load_dataset
|
from datasets import Value, concatenate_datasets, interleave_datasets, load_dataset
|
||||||
|
|
||||||
from llmtuner.extras.logging import get_logger
|
from llmtuner.extras.logging import get_logger
|
||||||
|
|
||||||
|
@ -93,7 +93,11 @@ def get_dataset(
|
||||||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
|
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
|
||||||
|
|
||||||
if dataset_attr.source_prefix: # add prefix
|
if dataset_attr.source_prefix: # add prefix
|
||||||
dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix})
|
features = None
|
||||||
|
if data_args.streaming:
|
||||||
|
features = dataset.features
|
||||||
|
features["prefix"] = Value(dtype="string", id=None)
|
||||||
|
dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix}, features=features)
|
||||||
|
|
||||||
all_datasets.append(dataset)
|
all_datasets.append(dataset)
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ def preprocess_dataset(
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
stage: Literal["pt", "sft", "rm", "ppo"]
|
stage: Literal["pt", "sft", "rm", "ppo"]
|
||||||
) -> "Dataset":
|
) -> "Dataset":
|
||||||
column_names = list(dataset.column_names or [])
|
column_names = list(dataset.column_names)
|
||||||
template = get_template(data_args.template)
|
template = get_template(data_args.template)
|
||||||
|
|
||||||
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
|
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
|
||||||
|
@ -143,15 +143,19 @@ def preprocess_dataset(
|
||||||
if stage == "pt":
|
if stage == "pt":
|
||||||
dataset = dataset.filter(lambda example: example["prompt"])
|
dataset = dataset.filter(lambda example: example["prompt"])
|
||||||
preprocess_function = preprocess_pretrain_dataset
|
preprocess_function = preprocess_pretrain_dataset
|
||||||
|
print_function = print_unsupervised_dataset_example
|
||||||
elif stage == "sft" and not training_args.predict_with_generate:
|
elif stage == "sft" and not training_args.predict_with_generate:
|
||||||
dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
|
dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
|
||||||
preprocess_function = preprocess_supervised_dataset
|
preprocess_function = preprocess_supervised_dataset
|
||||||
|
print_function = print_supervised_dataset_example
|
||||||
elif stage == "rm":
|
elif stage == "rm":
|
||||||
dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
|
dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
|
||||||
preprocess_function = preprocess_pairwise_dataset
|
preprocess_function = preprocess_pairwise_dataset
|
||||||
|
print_function = print_pairwise_dataset_example
|
||||||
else:
|
else:
|
||||||
dataset = dataset.filter(lambda example: example["prompt"])
|
dataset = dataset.filter(lambda example: example["prompt"])
|
||||||
preprocess_function = preprocess_unsupervised_dataset
|
preprocess_function = preprocess_unsupervised_dataset
|
||||||
|
print_function = print_unsupervised_dataset_example
|
||||||
|
|
||||||
with training_args.main_process_first(desc="dataset map pre-processing"):
|
with training_args.main_process_first(desc="dataset map pre-processing"):
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
@ -172,13 +176,5 @@ def preprocess_dataset(
|
||||||
if data_args.streaming:
|
if data_args.streaming:
|
||||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
|
dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
|
||||||
|
|
||||||
if stage == "pt":
|
print_function(next(iter(dataset)))
|
||||||
print_unsupervised_dataset_example(next(iter(dataset)))
|
|
||||||
elif stage == "sft":
|
|
||||||
print_supervised_dataset_example(next(iter(dataset)))
|
|
||||||
elif stage == "rm":
|
|
||||||
print_pairwise_dataset_example(next(iter(dataset)))
|
|
||||||
elif stage == "ppo":
|
|
||||||
print_unsupervised_dataset_example(next(iter(dataset)))
|
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
Loading…
Reference in New Issue