support dataset cache

This commit is contained in:
parent 838ed9aa87
commit 3fe7df628d
@@ -1,8 +1,12 @@
+import os
 import tiktoken
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
 from itertools import chain
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
+
+from datasets import load_from_disk
 
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.template import get_template_and_fix_tokenizer
 
 if TYPE_CHECKING:
@@ -12,6 +16,9 @@ if TYPE_CHECKING:
     from llmtuner.hparams import DataArguments


 logger = get_logger(__name__)


 def preprocess_dataset(
     dataset: Union["Dataset", "IterableDataset"],
     tokenizer: "PreTrainedTokenizer",
@@ -19,7 +26,6 @@ def preprocess_dataset(
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo"]
 ) -> Union["Dataset", "IterableDataset"]:
-    column_names = list(next(iter(dataset)).keys())
     template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
 
     if data_args.train_on_prompt and template.efficient_eos:
@@ -226,7 +232,12 @@ def preprocess_dataset(
         preprocess_func = preprocess_unsupervised_dataset
         print_function = print_unsupervised_dataset_example
 
+    if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
+        logger.warning("Loading dataset from disk will ignore other data arguments.")
+        return load_from_disk(data_args.cache_path)
+
     with training_args.main_process_first(desc="dataset map pre-processing"):
+        column_names = list(next(iter(dataset)).keys())
         kwargs = {}
         if not data_args.streaming:
             kwargs = dict(
@@ -242,10 +253,15 @@ def preprocess_dataset(
             **kwargs
         )
 
+        if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
+            if training_args.should_save:
+                dataset.save_to_disk(data_args.cache_path)
+            raise SystemExit("Dataset saved, rerun this script with the same `--cache_path`.")
+
         if training_args.should_log:
             try:
                 print_function(next(iter(dataset)))
             except StopIteration:
-                raise ValueError("Empty dataset!")
+                raise RuntimeError("Empty dataset!")
 
         return dataset
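The change above gives preprocess_dataset a two-pass cache flow: on the first run (no cache on disk) the dataset is tokenized, written out with save_to_disk, and the script exits so it can be relaunched; on the second run the early return loads the serialized dataset with load_from_disk and skips preprocessing entirely. The snippet below is a minimal standalone sketch of that pattern using the datasets library; the build_or_load_dataset helper and the toy data are illustrative, not part of the commit.

import os

from datasets import Dataset, load_from_disk


def build_or_load_dataset(cache_path=None):
    # Second pass: a cache directory already exists, so load it and skip
    # preprocessing (any other data arguments would be ignored on this path).
    if cache_path is not None and os.path.exists(cache_path):
        return load_from_disk(cache_path)

    # First pass: build and preprocess the dataset from scratch.
    dataset = Dataset.from_dict({"text": ["hello world", "dataset cache"]})
    dataset = dataset.map(lambda example: {"n_chars": len(example["text"])})

    if cache_path is not None:
        # Persist the preprocessed dataset and stop, mirroring the
        # SystemExit raised in preprocess_dataset above.
        dataset.save_to_disk(cache_path)
        raise SystemExit("Dataset saved, rerun with the same cache path.")
    return dataset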
@@ -98,6 +98,10 @@ class DataArguments:
         default=False,
         metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
     )
+    cache_path: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to save or load the preprocessed datasets."}
+    )
 
     def __post_init__(self):
         if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
@@ -106,6 +110,9 @@ class DataArguments:
         if self.streaming and self.max_samples is not None:
             raise ValueError("`max_samples` is incompatible with `streaming`.")
 
+        if self.streaming and self.cache_path:
+            raise ValueError("`cache_path` is incompatible with `streaming`.")
+
     def init_for_training(self, seed: int): # support mixing multiple datasets
         self.seed = seed
         dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
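On the arguments side, cache_path is an ordinary optional dataclass field, so it can be supplied on the command line like any other DataArguments entry, and __post_init__ rejects combining it with streaming (an IterableDataset cannot be saved with save_to_disk). Below is a stripped-down sketch of how such a field is parsed with transformers' HfArgumentParser; the CacheArgs dataclass and the example path are stand-ins for illustration, not the real DataArguments.

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class CacheArgs:
    # Mirrors the new DataArguments.cache_path field added above.
    cache_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to save or load the preprocessed datasets."}
    )
    streaming: bool = field(
        default=False,
        metadata={"help": "Enable dataset streaming."}
    )

    def __post_init__(self):
        # Same guard as the commit: a disk cache cannot be built from a
        # streaming (IterableDataset) pipeline.
        if self.streaming and self.cache_path:
            raise ValueError("`cache_path` is incompatible with `streaming`.")


if __name__ == "__main__":
    parser = HfArgumentParser(CacheArgs)
    (args,) = parser.parse_args_into_dataclasses(["--cache_path", "/tmp/sft_cache"])
    print(args.cache_path)  # /tmp/sft_cache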