diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py
index c3af364c..18f01db1 100644
--- a/src/llmtuner/dsets/preprocess.py
+++ b/src/llmtuner/dsets/preprocess.py
@@ -1,8 +1,12 @@
+import os
 import tiktoken
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
 from itertools import chain
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
+
+from datasets import load_from_disk
 
 from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.extras.logging import get_logger
 from llmtuner.extras.template import get_template_and_fix_tokenizer
 
 if TYPE_CHECKING:
@@ -12,6 +16,9 @@ if TYPE_CHECKING:
     from llmtuner.hparams import DataArguments
 
 
+logger = get_logger(__name__)
+
+
 def preprocess_dataset(
     dataset: Union["Dataset", "IterableDataset"],
     tokenizer: "PreTrainedTokenizer",
@@ -19,7 +26,6 @@ def preprocess_dataset(
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo"]
 ) -> Union["Dataset", "IterableDataset"]:
-    column_names = list(next(iter(dataset)).keys())
     template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
 
     if data_args.train_on_prompt and template.efficient_eos:
@@ -226,7 +232,12 @@ def preprocess_dataset(
         preprocess_func = preprocess_unsupervised_dataset
         print_function = print_unsupervised_dataset_example
 
+    if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
+        logger.warning("Loading dataset from disk will ignore other data arguments.")
+        return load_from_disk(data_args.cache_path)
+
     with training_args.main_process_first(desc="dataset map pre-processing"):
+        column_names = list(next(iter(dataset)).keys())
         kwargs = {}
         if not data_args.streaming:
             kwargs = dict(
@@ -242,10 +253,15 @@ def preprocess_dataset(
             **kwargs
         )
 
+        if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
+            if training_args.should_save:
+                dataset.save_to_disk(data_args.cache_path)
+            raise SystemExit("Dataset saved, rerun this script with the same `--cache_path`.")
+
         if training_args.should_log:
             try:
                 print_function(next(iter(dataset)))
             except StopIteration:
-                raise ValueError("Empty dataset!")
+                raise RuntimeError("Empty dataset!")
 
         return dataset
diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py
index ff246c00..184cc3ca 100644
--- a/src/llmtuner/hparams/data_args.py
+++ b/src/llmtuner/hparams/data_args.py
@@ -98,6 +98,10 @@ class DataArguments:
         default=False,
         metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
     )
+    cache_path: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to save or load the preprocessed datasets."}
+    )
 
     def __post_init__(self):
         if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
@@ -106,6 +110,9 @@
         if self.streaming and self.max_samples is not None:
             raise ValueError("`max_samples` is incompatible with `streaming`.")
 
+        if self.streaming and self.cache_path:
+            raise ValueError("`cache_path` is incompatible with `streaming`.")
+
     def init_for_training(self, seed: int): # support mixing multiple datasets
         self.seed = seed
         dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
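
Illustrative sketch (not part of the patch): the cache round-trip that the new `cache_path` argument enables, built only on the `datasets` library's save_to_disk/load_from_disk APIs. The toy dataset and the /tmp path below are placeholders for the tokenized output of preprocess_dataset and the value passed via `--cache_path`; they are not taken from the repository.

import os

from datasets import Dataset, load_from_disk

cache_path = "/tmp/llmtuner_tokenized"  # placeholder for the --cache_path value

if os.path.exists(cache_path):
    # Second run: reuse the preprocessed dataset and skip tokenization entirely.
    dataset = load_from_disk(cache_path)
else:
    # First run: preprocess (stubbed with a toy dataset here), save to disk,
    # then exit, mirroring the SystemExit behaviour added in preprocess_dataset.
    dataset = Dataset.from_dict({"input_ids": [[1, 2, 3]], "labels": [[1, 2, 3]]})
    dataset.save_to_disk(cache_path)
    raise SystemExit("Dataset saved, rerun this script with the same `--cache_path`.")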