support dataset cache
parent 838ed9aa87
commit 3fe7df628d

@@ -1,8 +1,12 @@
+import os
 import tiktoken
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
+from itertools import chain
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
 
+from datasets import load_from_disk
+
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.template import get_template_and_fix_tokenizer
 
 if TYPE_CHECKING:
@@ -12,6 +16,9 @@ if TYPE_CHECKING:
+    from llmtuner.hparams import DataArguments
+
+
 logger = get_logger(__name__)
 
 
 def preprocess_dataset(
     dataset: Union["Dataset", "IterableDataset"],
     tokenizer: "PreTrainedTokenizer",
@@ -19,7 +26,6 @@ def preprocess_dataset(
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo"]
 ) -> Union["Dataset", "IterableDataset"]:
-    column_names = list(next(iter(dataset)).keys())
     template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
 
     if data_args.train_on_prompt and template.efficient_eos:
@@ -226,7 +232,12 @@ def preprocess_dataset(
         preprocess_func = preprocess_unsupervised_dataset
         print_function = print_unsupervised_dataset_example
 
+    if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
+        logger.warning("Loading dataset from disk will ignore other data arguments.")
+        return load_from_disk(data_args.cache_path)
+
     with training_args.main_process_first(desc="dataset map pre-processing"):
+        column_names = list(next(iter(dataset)).keys())
         kwargs = {}
         if not data_args.streaming:
             kwargs = dict(
@@ -242,10 +253,15 @@ def preprocess_dataset(
                 **kwargs
             )
 
+        if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
+            if training_args.should_save:
+                dataset.save_to_disk(data_args.cache_path)
+            raise SystemExit("Dataset saved, rerun this script with the same `--cache_path`.")
+
         if training_args.should_log:
             try:
                 print_function(next(iter(dataset)))
             except StopIteration:
-                raise ValueError("Empty dataset!")
+                raise RuntimeError("Empty dataset!")
 
         return dataset
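The caching logic added above amounts to a save/load round-trip with the `datasets` library: on the first run the preprocessed dataset is written with `save_to_disk` and the script exits, and a rerun with the same `cache_path` restores it with `load_from_disk` and skips the map step. A minimal standalone sketch of that pattern (toy data, a dummy map step, and a demo path, not the project's own code):

import os

from datasets import Dataset, load_from_disk

cache_path = "/tmp/preprocess_cache_demo"  # stand-in for the new `cache_path` argument

if os.path.exists(cache_path):
    # Rerun: reuse the cached result and skip preprocessing entirely.
    dataset = load_from_disk(cache_path)
else:
    # First run: build and preprocess the dataset, then persist it.
    # (The commit additionally raises SystemExit here so the user reruns with the cache.)
    dataset = Dataset.from_dict({"text": ["hello world", "foo bar"]})
    dataset = dataset.map(lambda example: {"n_chars": len(example["text"])})
    dataset.save_to_disk(cache_path)

print(dataset[0])
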
@@ -98,6 +98,10 @@ class DataArguments:
         default=False,
         metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
     )
+    cache_path: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to save or load the preprocessed datasets."}
+    )
 
     def __post_init__(self):
         if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
@@ -106,6 +110,9 @@ class DataArguments:
         if self.streaming and self.max_samples is not None:
             raise ValueError("`max_samples` is incompatible with `streaming`.")
 
+        if self.streaming and self.cache_path:
+            raise ValueError("`cache_path` is incompatible with `streaming`.")
+
     def init_for_training(self, seed: int): # support mixing multiple datasets
         self.seed = seed
         dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
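For usage: `DataArguments` is a dataclass of the kind transformers' `HfArgumentParser` maps onto command-line flags, so assuming the usual HfArgumentParser-style entry point, the new field surfaces as `--cache_path`, and the `__post_init__` check above rejects it when `--streaming` is also set. A self-contained sketch with a stripped-down, hypothetical stand-in for `DataArguments` (not the project's own parser wiring):

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DemoDataArguments:  # stripped-down stand-in for llmtuner.hparams.DataArguments
    streaming: bool = field(default=False)
    cache_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to save or load the preprocessed datasets."}
    )

    def __post_init__(self):
        if self.streaming and self.cache_path:
            raise ValueError("`cache_path` is incompatible with `streaming`.")


if __name__ == "__main__":
    # e.g. python demo.py --cache_path ./cache/sft_dataset
    parser = HfArgumentParser(DemoDataArguments)
    (data_args,) = parser.parse_args_into_dataclasses()
    print(data_args.cache_path)

On the first run the given `--cache_path` does not exist yet, so the data is preprocessed and saved; rerunning with the same value loads the cache instead of preprocessing again.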