support dataset cache

hiyouga 2023-10-26 21:48:45 +08:00
parent 838ed9aa87
commit 3fe7df628d
2 changed files with 26 additions and 3 deletions

View File

@@ -1,8 +1,12 @@
+import os
 import tiktoken
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
 from itertools import chain
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
+from datasets import load_from_disk
 from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.extras.logging import get_logger
 from llmtuner.extras.template import get_template_and_fix_tokenizer

 if TYPE_CHECKING:
@@ -12,6 +16,9 @@ if TYPE_CHECKING:
     from llmtuner.hparams import DataArguments

+logger = get_logger(__name__)

 def preprocess_dataset(
     dataset: Union["Dataset", "IterableDataset"],
     tokenizer: "PreTrainedTokenizer",
@@ -19,7 +26,6 @@ def preprocess_dataset(
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo"]
 ) -> Union["Dataset", "IterableDataset"]:
-    column_names = list(next(iter(dataset)).keys())
     template = get_template_and_fix_tokenizer(data_args.template, tokenizer)

     if data_args.train_on_prompt and template.efficient_eos:
@@ -226,7 +232,12 @@ def preprocess_dataset(
         preprocess_func = preprocess_unsupervised_dataset
         print_function = print_unsupervised_dataset_example

+    if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
+        logger.warning("Loading dataset from disk will ignore other data arguments.")
+        return load_from_disk(data_args.cache_path)
+
     with training_args.main_process_first(desc="dataset map pre-processing"):
+        column_names = list(next(iter(dataset)).keys())
         kwargs = {}
         if not data_args.streaming:
             kwargs = dict(
@@ -242,10 +253,15 @@
             **kwargs
         )

+        if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
+            if training_args.should_save:
+                dataset.save_to_disk(data_args.cache_path)
+            raise SystemExit("Dataset saved, rerun this script with the same `--cache_path`.")
+
         if training_args.should_log:
             try:
                 print_function(next(iter(dataset)))
             except StopIteration:
-                raise ValueError("Empty dataset!")
+                raise RuntimeError("Empty dataset!")

         return dataset
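
The net effect of these hunks: when `cache_path` points at an existing directory, preprocessing is skipped entirely and the saved dataset is loaded; otherwise the freshly mapped dataset is written to disk (by the saving process only) and the run exits so it can be restarted against the cache. Below is a minimal standalone sketch of that load-or-save round trip, using only the `datasets` API seen above; the toy `preprocess` function and the `cache_path` value are hypothetical, not part of the commit.

```python
import os
from datasets import Dataset, load_from_disk

cache_path = "cached_dataset"  # hypothetical stand-in for data_args.cache_path


def preprocess(dataset: Dataset) -> Dataset:
    # toy stand-in for preprocess_func; the real one tokenizes and formats examples
    return dataset.map(lambda example: {"n_chars": len(example["text"])})


raw = Dataset.from_dict({"text": ["hello", "world"]})

if cache_path is not None and os.path.exists(cache_path):
    # cache hit: reuse the preprocessed dataset, ignoring the other data arguments
    dataset = load_from_disk(cache_path)
else:
    dataset = preprocess(raw)
    # cache miss: persist the result; the commit raises SystemExit here and asks
    # the user to rerun with the same cache path instead of continuing
    dataset.save_to_disk(cache_path)

print(dataset)
```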

View File

@@ -98,6 +98,10 @@ class DataArguments:
         default=False,
         metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
     )
+    cache_path: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to save or load the preprocessed datasets."}
+    )

     def __post_init__(self):
         if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
@@ -106,6 +110,9 @@ class DataArguments:
         if self.streaming and self.max_samples is not None:
             raise ValueError("`max_samples` is incompatible with `streaming`.")

+        if self.streaming and self.cache_path:
+            raise ValueError("`cache_path` is incompatible with `streaming`.")
+
     def init_for_training(self, seed: int): # support mixing multiple datasets
         self.seed = seed
         dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
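
On the argument side, the new field behaves like any other `DataArguments` option, and the `__post_init__` check rejects it when streaming is enabled (a streamed `IterableDataset` has no `save_to_disk`, so it cannot be materialized into a cache). A reduced, hypothetical stand-in for the dataclass, not the project's full `DataArguments`, showing how the validation fires:

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class CacheArgs:  # hypothetical, trimmed-down stand-in for DataArguments
    streaming: Optional[bool] = False
    cache_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to save or load the preprocessed datasets."}
    )

    def __post_init__(self):
        # mirrors the new check: caching needs a dataset that can be written to disk,
        # which a streaming (iterable) dataset cannot provide
        if self.streaming and self.cache_path:
            raise ValueError("`cache_path` is incompatible with `streaming`.")


CacheArgs(cache_path="cached_dataset")  # fine: cache enabled, no streaming
CacheArgs(streaming=True)               # fine: streaming without a cache
# CacheArgs(streaming=True, cache_path="cached_dataset")  -> ValueError
```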