diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index d527d7d2..069ea199 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import inspect
 import os
 import sys
-from typing import TYPE_CHECKING, Literal, Optional, Union, Dict
+from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
 
 import numpy as np
-from datasets import load_dataset, load_from_disk
+from datasets import DatasetDict, load_dataset, load_from_disk
+from transformers.utils.versions import require_version
 
 from ..extras.constants import FILEEXT2TYPE
 from ..extras.logging import get_logger
@@ -27,7 +27,7 @@ from .aligner import align_dataset
 from .data_utils import merge_dataset, split_dataset
 from .parser import get_dataset_list
 from .preprocess import get_preprocess_and_print_func
-from .template import get_template_and_fix_tokenizer, Template
+from .template import get_template_and_fix_tokenizer
 
 
 if TYPE_CHECKING:
@@ -35,13 +35,15 @@ if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
 
     from ..hparams import DataArguments, ModelArguments
+    from .data_utils import DatasetModule
     from .parser import DatasetAttr
+    from .template import Template
 
 
 logger = get_logger(__name__)
 
 
-def load_single_dataset(
+def _load_single_dataset(
     dataset_attr: "DatasetAttr",
     model_args: "ModelArguments",
     data_args: "DataArguments",
@@ -81,31 +83,24 @@ def load_single_dataset(
         raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
 
     if dataset_attr.load_from == "ms_hub":
-        try:
-            from modelscope import MsDataset
-            from modelscope.utils.config_ds import MS_DATASETS_CACHE
+        require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
+        from modelscope import MsDataset
+        from modelscope.utils.config_ds import MS_DATASETS_CACHE
 
-            cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
-            dataset = MsDataset.load(
-                dataset_name=data_path,
-                subset_name=data_name,
-                data_dir=data_dir,
-                data_files=data_files,
-                split=dataset_attr.split,
-                cache_dir=cache_dir,
-                token=model_args.ms_hub_token,
-                use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
-            )
-            if isinstance(dataset, MsDataset):
-                dataset = dataset.to_hf_dataset()
-        except ImportError:
-            raise ImportError("Please install modelscope via `pip install modelscope -U`")
+        cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
+        dataset = MsDataset.load(
+            dataset_name=data_path,
+            subset_name=data_name,
+            data_dir=data_dir,
+            data_files=data_files,
+            split=dataset_attr.split,
+            cache_dir=cache_dir,
+            token=model_args.ms_hub_token,
+            use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+        )
+        if isinstance(dataset, MsDataset):
+            dataset = dataset.to_hf_dataset()
     else:
-        if "trust_remote_code" in inspect.signature(load_dataset).parameters:  # for datasets==2.16.0
-            kwargs = {"trust_remote_code": True}
-        else:
-            kwargs = {}
-
         dataset = load_dataset(
             path=data_path,
             name=data_name,
@@ -115,7 +110,7 @@ def load_single_dataset(
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
             streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
-            **kwargs,
+            trust_remote_code=True,
         )
 
     if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
@@ -140,90 +135,64 @@ def load_single_dataset(
     return align_dataset(dataset, dataset_attr, data_args, training_args)
 
 
-def load_and_preprocess(
+def _get_merged_dataset(
+    dataset_names: Optional[Sequence[str]],
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
-    tokenizer: "PreTrainedTokenizer",
+) -> Optional[Union["Dataset", "IterableDataset"]]:
+    if dataset_names is None:
+        return None
+
+    datasets = []
+    for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
+        if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
+            raise ValueError("The dataset is not applicable in the current training stage.")
+
+        datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))
+
+    return merge_dataset(datasets, data_args, seed=training_args.seed)
+
+
+def _get_preprocessed_dataset(
+    dataset: Optional[Union["Dataset", "IterableDataset"]],
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     template: "Template",
+    tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"] = None,
-    is_eval: bool = False
-) -> Union["Dataset", "IterableDataset"]:
-    if not is_eval and data_args.tokenized_path is not None:
-        if has_tokenized_data(data_args.tokenized_path):
-            logger.warning("Loading dataset from disk will ignore other data arguments.")
-            dataset = load_from_disk(data_args.tokenized_path)
-            logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
-            if data_args.streaming:
-                dataset = dataset.to_iterable_dataset()
-            return dataset
+    is_eval: bool = False,
+) -> Optional[Union["Dataset", "IterableDataset"]]:
+    if dataset is None:
+        return None
 
-        if data_args.streaming:
-            raise ValueError("Turn off `streaming` when saving dataset to disk.")
-
-    if is_eval and data_args.eval_tokenized_path is not None:
-        if has_tokenized_data(data_args.eval_tokenized_path):
-            logger.warning("Loading dataset from disk will ignore other data arguments.")
-            dataset = load_from_disk(data_args.eval_tokenized_path)
-            logger.info("Loaded tokenized dataset from {}.".format(data_args.eval_tokenized_path))
-            if data_args.streaming:
-                dataset = dataset.to_iterable_dataset()
-            return dataset
-
-        if data_args.streaming:
-            raise ValueError("Turn off `streaming` when saving dataset to disk.")
-
-    with training_args.main_process_first(desc="load dataset"):
-        all_datasets = []
-        for dataset_attr in get_dataset_list(data_args, data_args.eval_dataset if is_eval else data_args.dataset):
-            if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
-                raise ValueError("The dataset is not applicable in the current training stage.")
-
-            all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args, training_args))
-
-        dataset = merge_dataset(all_datasets, data_args, training_args)
-
-    with training_args.main_process_first(desc="pre-process dataset"):
-        preprocess_func, print_function = get_preprocess_and_print_func(
-            data_args, training_args, stage, template, tokenizer, processor
+    preprocess_func, print_function = get_preprocess_and_print_func(
+        data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
+    )
+    column_names = list(next(iter(dataset)).keys())
+    kwargs = {}
+    if not data_args.streaming:
+        kwargs = dict(
+            num_proc=data_args.preprocessing_num_workers,
+            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
+            desc="Running tokenizer on dataset",
         )
-        column_names = list(next(iter(dataset)).keys())
-        kwargs = {}
-        if not data_args.streaming:
-            kwargs = dict(
-                num_proc=data_args.preprocessing_num_workers,
-                load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
-                desc="Running tokenizer on dataset",
-            )
 
-        dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
+    dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
 
-        if not is_eval and data_args.tokenized_path is not None:
-            if training_args.should_save:
-                dataset.save_to_disk(data_args.tokenized_path)
-                logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
-                logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
+    if training_args.should_log:
+        try:
+            print("eval example:" if is_eval else "training example:")
+            print_function(next(iter(dataset)))
+        except StopIteration:
+            if stage == "pt":
+                raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
+            else:
+                raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
 
-            sys.exit(0)
-
-        if is_eval and data_args.eval_tokenized_path is not None:
-            if training_args.should_save:
-                dataset.save_to_disk(data_args.eval_tokenized_path)
-                logger.info("Tokenized dataset saved at {}.".format(data_args.eval_tokenized_path))
-                logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.eval_tokenized_path))
-
-            sys.exit(0)
-
-        if training_args.should_log:
-            try:
-                print_function(next(iter(dataset)))
-            except StopIteration:
-                if stage == "pt":
-                    raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
-                else:
-                    raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
-
-        return dataset
+    return dataset
 
 
 def get_dataset(
@@ -232,16 +201,76 @@
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     tokenizer: "PreTrainedTokenizer",
-    processor: Optional["ProcessorMixin"] = None
-) -> Dict[str, "Dataset"]:
+    processor: Optional["ProcessorMixin"] = None,
+) -> "DatasetModule":
     template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")
 
-    train_dataset = load_and_preprocess(model_args, data_args, training_args, stage, tokenizer, template, processor)
+    # Load tokenized dataset
+    if data_args.tokenized_path is not None:
+        if has_tokenized_data(data_args.tokenized_path):
+            logger.warning("Loading dataset from disk will ignore other data arguments.")
+            dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
+            logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
 
-    if data_args.eval_dataset or data_args.eval_tokenized_path:
-        eval_dataset = load_and_preprocess(model_args, data_args, training_args, stage, tokenizer, template, processor, True)
-        return {"train_dataset": train_dataset, "eval_dataset": eval_dataset}
-    else:
-        return split_dataset(train_dataset, data_args, training_args)
+            dataset_module: Dict[str, "Dataset"] = {}
+            if "train" in dataset_dict:
+                dataset_module["train_dataset"] = dataset_dict["train"]
+            if "validation" in dataset_dict:
+                dataset_module["eval_dataset"] = dataset_dict["validation"]
+
+            if data_args.streaming:
+                dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
+
+            return dataset_module
+
+        if data_args.streaming:
+            raise ValueError("Turn off `streaming` when saving dataset to disk.")
+
+    # Load and preprocess dataset
+    with training_args.main_process_first(desc="load dataset"):
+        dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
+        eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage)
+
+    with training_args.main_process_first(desc="pre-process dataset"):
+        dataset = _get_preprocessed_dataset(
+            dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
+        )
+        eval_dataset = _get_preprocessed_dataset(
+            eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+        )
+
+        if data_args.val_size > 1e-6:
+            dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
+        else:
+            dataset_dict = {}
+            if dataset is not None:
+                if data_args.streaming:
+                    dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
+
+                dataset_dict["train"] = dataset
+
+            if eval_dataset is not None:
+                if data_args.streaming:
+                    eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
+
+                dataset_dict["validation"] = eval_dataset
+
+            dataset_dict = DatasetDict(dataset_dict)
+
+        if data_args.tokenized_path is not None:
+            if training_args.should_save:
+                dataset_dict.save_to_disk(data_args.tokenized_path)
+                logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
+                logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
+
+            sys.exit(0)
+
+        dataset_module = {}
+        if "train" in dataset_dict:
+            dataset_module["train_dataset"] = dataset_dict["train"]
+        if "validation" in dataset_dict:
+            dataset_module["eval_dataset"] = dataset_dict["validation"]
+
+        return dataset_module
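For reviewers, a minimal caller-side sketch of how the refactored get_dataset() return value can be consumed. This is an assumption for illustration only, not code from this diff: the build_trainer helper and its arguments are hypothetical, and it assumes the returned DatasetModule mapping carries optional "train_dataset"/"eval_dataset" entries as built above.

# Illustrative sketch only -- not part of the diff above.
from transformers import Seq2SeqTrainer

from llamafactory.data.loader import get_dataset


def build_trainer(model, tokenizer, model_args, data_args, training_args):
    # get_dataset() now hides tokenized-path reuse, multi-dataset merging,
    # preprocessing, and train/validation splitting behind a single call.
    dataset_module = get_dataset(
        model_args, data_args, training_args, stage="sft", tokenizer=tokenizer
    )
    return Seq2SeqTrainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        **dataset_module,  # unpacks train_dataset and/or eval_dataset when present
    )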