diff --git a/src/llamafactory/data/aligner.py b/src/llamafactory/data/aligner.py
index 434956af..3e9d5c46 100644
--- a/src/llamafactory/data/aligner.py
+++ b/src/llamafactory/data/aligner.py
@@ -10,6 +10,7 @@ from .data_utils import Role
 
 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
+    from transformers import Seq2SeqTrainingArguments
 
     from ..hparams import DataArguments
     from .parser import DatasetAttr
@@ -175,7 +176,10 @@ def convert_sharegpt(
 
 
 def align_dataset(
-    dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
+    dataset: Union["Dataset", "IterableDataset"],
+    dataset_attr: "DatasetAttr",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
     r"""
     Aligned dataset:
@@ -208,7 +212,7 @@ def align_dataset(
 
     if not data_args.streaming:
         kwargs = dict(
             num_proc=data_args.preprocessing_num_workers,
-            load_from_cache_file=(not data_args.overwrite_cache),
+            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
             desc="Converting format of dataset",
         )
diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index 2c236c76..ba426f81 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -18,8 +18,7 @@ from .template import get_template_and_fix_tokenizer
 
 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
-    from transformers import ProcessorMixin, Seq2SeqTrainingArguments
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
 
     from ..hparams import DataArguments, ModelArguments
     from .parser import DatasetAttr
@@ -32,6 +31,7 @@ def load_single_dataset(
     dataset_attr: "DatasetAttr",
     model_args: "ModelArguments",
     data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
     logger.info("Loading dataset {}...".format(dataset_attr))
     data_path, data_name, data_dir, data_files = None, None, None, None
@@ -123,7 +123,7 @@ def load_single_dataset(
         max_samples = min(data_args.max_samples, len(dataset))
         dataset = dataset.select(range(max_samples))
 
-    return align_dataset(dataset, dataset_attr, data_args)
+    return align_dataset(dataset, dataset_attr, data_args, training_args)
 
 
 def get_dataset(
@@ -157,7 +157,8 @@ def get_dataset(
             if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
                 raise ValueError("The dataset is not applicable in the current training stage.")
 
-            all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
+            all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args, training_args))
+
         dataset = merge_dataset(all_datasets, data_args, training_args)
 
     with training_args.main_process_first(desc="pre-process dataset"):
@@ -169,7 +170,7 @@ def get_dataset(
 
         if not data_args.streaming:
             kwargs = dict(
                 num_proc=data_args.preprocessing_num_workers,
-                load_from_cache_file=(not data_args.overwrite_cache),
+                load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
                 desc="Running tokenizer on dataset",
             )
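
The aligner.py and loader.py hunks above work as a pair with training_args.main_process_first(...): the local main process runs the map() calls first and writes the Arrow cache, and once it releases the barrier the remaining local ranks execute the same calls and should always read that cache rather than recompute it, which is what forcing load_from_cache_file=True for local_process_index != 0 achieves. A minimal sketch of the intended interplay, assuming a toy map function and a default Seq2SeqTrainingArguments (neither is code from this patch):

    from datasets import Dataset
    from transformers import Seq2SeqTrainingArguments

    def to_upper(batch):  # hypothetical stand-in for the real preprocessing
        return {"text": [text.upper() for text in batch["text"]]}

    training_args = Seq2SeqTrainingArguments(output_dir="out")
    overwrite_cache = True  # mirrors data_args.overwrite_cache

    dataset = Dataset.from_dict({"text": ["a", "b", "c"]})

    # In a distributed run, rank 0 enters the block first and runs map();
    # for a cache-backed dataset the result is written to an Arrow cache
    # file. The other local ranks wait at the barrier, then rerun the same
    # map() with load_from_cache_file forced to True and pick up rank 0's
    # cached result instead of recomputing (or re-overwriting) it.
    with training_args.main_process_first(desc="pre-process dataset"):
        dataset = dataset.map(
            to_upper,
            batched=True,
            load_from_cache_file=(not overwrite_cache)
            or (training_args.local_process_index != 0),
        )
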
diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py
index 97789c39..875f55d6 100644
--- a/src/llamafactory/data/preprocess.py
+++ b/src/llamafactory/data/preprocess.py
@@ -13,8 +13,7 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
 
 
 if TYPE_CHECKING:
-    from transformers import ProcessorMixin, Seq2SeqTrainingArguments
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
 
     from ..hparams import DataArguments
     from .template import Template
diff --git a/src/llamafactory/data/processors/feedback.py b/src/llamafactory/data/processors/feedback.py
index 98d83658..5fba452c 100644
--- a/src/llamafactory/data/processors/feedback.py
+++ b/src/llamafactory/data/processors/feedback.py
@@ -6,8 +6,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
 
 
 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin
 
     from ...hparams import DataArguments
     from ..template import Template
diff --git a/src/llamafactory/data/processors/pairwise.py b/src/llamafactory/data/processors/pairwise.py
index fe984efa..db52c6a7 100644
--- a/src/llamafactory/data/processors/pairwise.py
+++ b/src/llamafactory/data/processors/pairwise.py
@@ -6,8 +6,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
 
 
 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin
 
     from ...hparams import DataArguments
     from ..template import Template
diff --git a/src/llamafactory/data/processors/pretrain.py b/src/llamafactory/data/processors/pretrain.py
index 832c987e..a10ccabd 100644
--- a/src/llamafactory/data/processors/pretrain.py
+++ b/src/llamafactory/data/processors/pretrain.py
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, List
 
 
 if TYPE_CHECKING:
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer
 
     from ...hparams import DataArguments
 
diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py
index 19d60280..f59f5371 100644
--- a/src/llamafactory/data/processors/supervised.py
+++ b/src/llamafactory/data/processors/supervised.py
@@ -7,8 +7,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack
 
 
 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin
 
     from ...hparams import DataArguments
     from ..template import Template
diff --git a/src/llamafactory/data/processors/unsupervised.py b/src/llamafactory/data/processors/unsupervised.py
index f711eeac..38497a15 100644
--- a/src/llamafactory/data/processors/unsupervised.py
+++ b/src/llamafactory/data/processors/unsupervised.py
@@ -6,8 +6,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
 
 
 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin
 
     from ...hparams import DataArguments
     from ..template import Template
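
The hunks above are all the same mechanical cleanup: PreTrainedTokenizer is re-exported from the top-level transformers namespace, so the deep transformers.tokenization_utils path adds nothing, and the consolidated import can share one line with ProcessorMixin (and Seq2SeqTrainingArguments where needed). Because these imports sit under if TYPE_CHECKING:, they are evaluated only by static type checkers. A small sketch of the pattern, with a hypothetical helper for illustration:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # resolved by the type checker only; no runtime import cost
        from transformers import PreTrainedTokenizer, ProcessorMixin

    def num_special_tokens(tokenizer: "PreTrainedTokenizer") -> int:
        # hypothetical helper; the quoted annotation keeps this valid
        # even though PreTrainedTokenizer is never imported at runtime
        return len(tokenizer.all_special_tokens)
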
diff --git a/src/llamafactory/train/sft/metric.py b/src/llamafactory/train/sft/metric.py
index 6ed356c1..923238d6 100644
--- a/src/llamafactory/train/sft/metric.py
+++ b/src/llamafactory/train/sft/metric.py
@@ -9,7 +9,7 @@ from ...extras.packages import is_rouge_available
 
 
 if TYPE_CHECKING:
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer
 
 
 if is_jieba_available():
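
Taken together, training_args now flows from get_dataset through load_single_dataset into align_dataset, so both map() call sites can consult local_process_index. Enumerating the new flag shows the intended behavior: only the local main process with overwrite_cache=True recomputes the cache, and every other combination reads it (a quick check, runnable as-is):

    for overwrite_cache in (False, True):
        for local_process_index in (0, 1):
            load_from_cache_file = (not overwrite_cache) or (local_process_index != 0)
            print(
                f"overwrite_cache={overwrite_cache} rank={local_process_index} "
                f"-> load_from_cache_file={load_from_cache_file}"
            )
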