Update loader.py
parent 3d39d74003
commit a5b809516e
@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import inspect
 import os
 import sys
-from typing import TYPE_CHECKING, Literal, Optional, Union, Dict
+from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
 
 import numpy as np
-from datasets import load_dataset, load_from_disk
+from datasets import DatasetDict, load_dataset, load_from_disk
+from transformers.utils.versions import require_version
 
 from ..extras.constants import FILEEXT2TYPE
 from ..extras.logging import get_logger
@@ -27,7 +27,7 @@ from .aligner import align_dataset
 from .data_utils import merge_dataset, split_dataset
 from .parser import get_dataset_list
 from .preprocess import get_preprocess_and_print_func
-from .template import get_template_and_fix_tokenizer, Template
+from .template import get_template_and_fix_tokenizer
 
 
 if TYPE_CHECKING:
@@ -35,13 +35,15 @@ if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
 
     from ..hparams import DataArguments, ModelArguments
+    from .data_utils import DatasetModule
     from .parser import DatasetAttr
+    from .template import Template
 
 
 logger = get_logger(__name__)
 
 
-def load_single_dataset(
+def _load_single_dataset(
     dataset_attr: "DatasetAttr",
     model_args: "ModelArguments",
     data_args: "DataArguments",
@@ -81,31 +83,24 @@ def load_single_dataset(
         raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
 
     if dataset_attr.load_from == "ms_hub":
-        try:
-            from modelscope import MsDataset
-            from modelscope.utils.config_ds import MS_DATASETS_CACHE
-
-            cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
-            dataset = MsDataset.load(
-                dataset_name=data_path,
-                subset_name=data_name,
-                data_dir=data_dir,
-                data_files=data_files,
-                split=dataset_attr.split,
-                cache_dir=cache_dir,
-                token=model_args.ms_hub_token,
-                use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
-            )
-            if isinstance(dataset, MsDataset):
-                dataset = dataset.to_hf_dataset()
-        except ImportError:
-            raise ImportError("Please install modelscope via `pip install modelscope -U`")
+        require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
+        from modelscope import MsDataset
+        from modelscope.utils.config_ds import MS_DATASETS_CACHE
+
+        cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
+        dataset = MsDataset.load(
+            dataset_name=data_path,
+            subset_name=data_name,
+            data_dir=data_dir,
+            data_files=data_files,
+            split=dataset_attr.split,
+            cache_dir=cache_dir,
+            token=model_args.ms_hub_token,
+            use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+        )
+        if isinstance(dataset, MsDataset):
+            dataset = dataset.to_hf_dataset()
     else:
-        if "trust_remote_code" in inspect.signature(load_dataset).parameters:  # for datasets==2.16.0
-            kwargs = {"trust_remote_code": True}
-        else:
-            kwargs = {}
-
         dataset = load_dataset(
             path=data_path,
             name=data_name,
@@ -115,7 +110,7 @@ def load_single_dataset(
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
             streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
-            **kwargs,
+            trust_remote_code=True,
         )
 
     if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
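Note on the `**kwargs` -> `trust_remote_code=True` change above: the removed branch probed `load_dataset` for `trust_remote_code` support (the old comment points at datasets==2.16.0), while the new code passes the flag unconditionally and so effectively assumes a recent `datasets` release. A minimal standalone sketch of that probe, not part of this commit:

import inspect

from datasets import load_dataset

# Older datasets releases reject unknown keyword arguments, so the old code
# only forwarded trust_remote_code when the installed version advertises it.
if "trust_remote_code" in inspect.signature(load_dataset).parameters:
    kwargs = {"trust_remote_code": True}
else:
    kwargs = {}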
@@ -140,90 +135,64 @@ def load_single_dataset(
     return align_dataset(dataset, dataset_attr, data_args, training_args)
 
 
-def load_and_preprocess(
-    model_args: "ModelArguments",
-    data_args: "DataArguments",
-    training_args: "Seq2SeqTrainingArguments",
-    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
-    tokenizer: "PreTrainedTokenizer",
-    template: "Template",
-    processor: Optional["ProcessorMixin"] = None,
-    is_eval: bool = False
-) -> Union["Dataset", "IterableDataset"]:
-    if not is_eval and data_args.tokenized_path is not None:
-        if has_tokenized_data(data_args.tokenized_path):
-            logger.warning("Loading dataset from disk will ignore other data arguments.")
-            dataset = load_from_disk(data_args.tokenized_path)
-            logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
-            if data_args.streaming:
-                dataset = dataset.to_iterable_dataset()
-            return dataset
-
-        if data_args.streaming:
-            raise ValueError("Turn off `streaming` when saving dataset to disk.")
-
-    if is_eval and data_args.eval_tokenized_path is not None:
-        if has_tokenized_data(data_args.eval_tokenized_path):
-            logger.warning("Loading dataset from disk will ignore other data arguments.")
-            dataset = load_from_disk(data_args.eval_tokenized_path)
-            logger.info("Loaded tokenized dataset from {}.".format(data_args.eval_tokenized_path))
-            if data_args.streaming:
-                dataset = dataset.to_iterable_dataset()
-            return dataset
-
-        if data_args.streaming:
-            raise ValueError("Turn off `streaming` when saving dataset to disk.")
-
-    with training_args.main_process_first(desc="load dataset"):
-        all_datasets = []
-        for dataset_attr in get_dataset_list(data_args, data_args.eval_dataset if is_eval else data_args.dataset):
-            if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
-                raise ValueError("The dataset is not applicable in the current training stage.")
-
-            all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args, training_args))
-
-        dataset = merge_dataset(all_datasets, data_args, training_args)
-
-    with training_args.main_process_first(desc="pre-process dataset"):
-        preprocess_func, print_function = get_preprocess_and_print_func(
-            data_args, training_args, stage, template, tokenizer, processor
-        )
-        column_names = list(next(iter(dataset)).keys())
-        kwargs = {}
-        if not data_args.streaming:
-            kwargs = dict(
-                num_proc=data_args.preprocessing_num_workers,
-                load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
-                desc="Running tokenizer on dataset",
-            )
-
-        dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
-
-        if not is_eval and data_args.tokenized_path is not None:
-            if training_args.should_save:
-                dataset.save_to_disk(data_args.tokenized_path)
-            logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
-            logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
-
-            sys.exit(0)
-
-        if is_eval and data_args.eval_tokenized_path is not None:
-            if training_args.should_save:
-                dataset.save_to_disk(data_args.eval_tokenized_path)
-            logger.info("Tokenized dataset saved at {}.".format(data_args.eval_tokenized_path))
-            logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.eval_tokenized_path))
-
-            sys.exit(0)
-
-        if training_args.should_log:
-            try:
-                print_function(next(iter(dataset)))
-            except StopIteration:
-                if stage == "pt":
-                    raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
-                else:
-                    raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
-
-        return dataset
+def _get_merged_dataset(
+    dataset_names: Optional[Sequence[str]],
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+) -> Optional[Union["Dataset", "IterableDataset"]]:
+    if dataset_names is None:
+        return None
+
+    datasets = []
+    for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
+        if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
+            raise ValueError("The dataset is not applicable in the current training stage.")
+
+        datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))
+
+    return merge_dataset(datasets, data_args, seed=training_args.seed)
+
+
+def _get_preprocessed_dataset(
+    dataset: Optional[Union["Dataset", "IterableDataset"]],
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+    template: "Template",
+    tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"] = None,
+    is_eval: bool = False,
+) -> Optional[Union["Dataset", "IterableDataset"]]:
+    if dataset is None:
+        return None
+
+    preprocess_func, print_function = get_preprocess_and_print_func(
+        data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
+    )
+    column_names = list(next(iter(dataset)).keys())
+    kwargs = {}
+    if not data_args.streaming:
+        kwargs = dict(
+            num_proc=data_args.preprocessing_num_workers,
+            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
+            desc="Running tokenizer on dataset",
+        )
+
+    dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
+
+    if training_args.should_log:
+        try:
+            print("eval example:" if is_eval else "training example:")
+            print_function(next(iter(dataset)))
+        except StopIteration:
+            if stage == "pt":
+                raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
+            else:
+                raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
+
+    return dataset
 
 
 def get_dataset(
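For reference, a self-contained sketch (not from this commit) of the `map()` keyword arguments that the new `_get_preprocessed_dataset` builds only in the non-streaming branch; `num_proc`, `load_from_cache_file`, and `desc` apply to a regular `Dataset.map`, which is why the streaming path leaves `kwargs` empty. The toy dataset and tokenize function below are illustrative stand-ins:

from datasets import Dataset

toy = Dataset.from_dict({"prompt": ["hello", "world"]})

def tokenize(batch):
    # Stand-in for the real preprocess_func: map each prompt to its length.
    return {"length": [len(x) for x in batch["prompt"]]}

kwargs = dict(num_proc=1, load_from_cache_file=False, desc="Running tokenizer on dataset")
toy = toy.map(tokenize, batched=True, remove_columns=["prompt"], **kwargs)
print(toy[0])  # {'length': 5}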
@@ -232,16 +201,76 @@ def get_dataset(
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     tokenizer: "PreTrainedTokenizer",
-    processor: Optional["ProcessorMixin"] = None
-) -> Dict[str, "Dataset"]:
+    processor: Optional["ProcessorMixin"] = None,
+) -> "DatasetModule":
     template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")
 
-    train_dataset = load_and_preprocess(model_args, data_args, training_args, stage, tokenizer, template, processor)
-
-    if data_args.eval_dataset or data_args.eval_tokenized_path:
-        eval_dataset = load_and_preprocess(model_args, data_args, training_args, stage, tokenizer, template, processor, True)
-        return {"train_dataset": train_dataset, "eval_dataset": eval_dataset}
-    else:
-        return split_dataset(train_dataset, data_args, training_args)
+    # Load tokenized dataset
+    if data_args.tokenized_path is not None:
+        if has_tokenized_data(data_args.tokenized_path):
+            logger.warning("Loading dataset from disk will ignore other data arguments.")
+            dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
+            logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
+
+            dataset_module: Dict[str, "Dataset"] = {}
+            if "train" in dataset_dict:
+                dataset_module["train_dataset"] = dataset_dict["train"]
+            if "validation" in dataset_dict:
+                dataset_module["eval_dataset"] = dataset_dict["validation"]
+
+            if data_args.streaming:
+                dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
+
+            return dataset_module
+
+        if data_args.streaming:
+            raise ValueError("Turn off `streaming` when saving dataset to disk.")
+
+    # Load and preprocess dataset
+    with training_args.main_process_first(desc="load dataset"):
+        dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
+        eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage)
+
+    with training_args.main_process_first(desc="pre-process dataset"):
+        dataset = _get_preprocessed_dataset(
+            dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
+        )
+        eval_dataset = _get_preprocessed_dataset(
+            eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+        )
+
+        if data_args.val_size > 1e-6:
+            dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
+        else:
+            dataset_dict = {}
+            if dataset is not None:
+                if data_args.streaming:
+                    dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
+
+                dataset_dict["train"] = dataset
+
+            if eval_dataset is not None:
+                if data_args.streaming:
+                    eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
+
+                dataset_dict["validation"] = eval_dataset
+
+            dataset_dict = DatasetDict(dataset_dict)
+
+        if data_args.tokenized_path is not None:
+            if training_args.should_save:
+                dataset_dict.save_to_disk(data_args.tokenized_path)
+            logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
+            logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
+
+            sys.exit(0)
+
+        dataset_module = {}
+        if "train" in dataset_dict:
+            dataset_module["train_dataset"] = dataset_dict["train"]
+        if "validation" in dataset_dict:
+            dataset_module["eval_dataset"] = dataset_dict["validation"]
+
+        return dataset_module
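With this change, callers receive a `DatasetModule`-style dict keyed by "train_dataset" and "eval_dataset" instead of the old fixed return shapes. A hedged usage sketch; `model_args`, `data_args`, `training_args`, and `tokenizer` are assumed to come from the surrounding LLaMA-Factory setup and are not defined in this commit:

# Build the dataset module for supervised fine-tuning.
dataset_module = get_dataset(model_args, data_args, training_args, stage="sft", tokenizer=tokenizer)

train_data = dataset_module["train_dataset"]     # present whenever a training split was built
eval_data = dataset_module.get("eval_dataset")   # None when no validation split exists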