hiyouga 2024-06-13 02:48:21 +08:00
parent 9419f96609
commit 6baafd4eb3
9 changed files with 19 additions and 19 deletions

View File

@@ -10,6 +10,7 @@ from .data_utils import Role
 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
+    from transformers import Seq2SeqTrainingArguments

     from ..hparams import DataArguments
     from .parser import DatasetAttr
@@ -175,7 +176,10 @@ def convert_sharegpt(
 def align_dataset(
-    dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
+    dataset: Union["Dataset", "IterableDataset"],
+    dataset_attr: "DatasetAttr",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
     r"""
     Aligned dataset:
@@ -208,7 +212,7 @@ def align_dataset(
     if not data_args.streaming:
         kwargs = dict(
             num_proc=data_args.preprocessing_num_workers,
-            load_from_cache_file=(not data_args.overwrite_cache),
+            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
             desc="Converting format of dataset",
         )

View File

@@ -18,8 +18,7 @@ from .template import get_template_and_fix_tokenizer
 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
-    from transformers import ProcessorMixin, Seq2SeqTrainingArguments
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments

     from ..hparams import DataArguments, ModelArguments
     from .parser import DatasetAttr
@@ -32,6 +31,7 @@ def load_single_dataset(
     dataset_attr: "DatasetAttr",
     model_args: "ModelArguments",
     data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
     logger.info("Loading dataset {}...".format(dataset_attr))
     data_path, data_name, data_dir, data_files = None, None, None, None
@@ -123,7 +123,7 @@ def load_single_dataset(
         max_samples = min(data_args.max_samples, len(dataset))
         dataset = dataset.select(range(max_samples))

-    return align_dataset(dataset, dataset_attr, data_args)
+    return align_dataset(dataset, dataset_attr, data_args, training_args)


 def get_dataset(
@@ -157,7 +157,8 @@ def get_dataset(
             if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
                 raise ValueError("The dataset is not applicable in the current training stage.")

-            all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
+            all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args, training_args))

         dataset = merge_dataset(all_datasets, data_args, training_args)

     with training_args.main_process_first(desc="pre-process dataset"):
@@ -169,7 +170,7 @@ def get_dataset(
         if not data_args.streaming:
             kwargs = dict(
                 num_proc=data_args.preprocessing_num_workers,
-                load_from_cache_file=(not data_args.overwrite_cache),
+                load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
                 desc="Running tokenizer on dataset",
             )
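The two load_from_cache_file changes above pair with the main_process_first context already used in get_dataset: local rank 0 enters the block first, runs the datasets map() preprocessing and writes the Arrow cache, and every other local rank is then forced to read that cache instead of re-running the map. A minimal standalone sketch of the pattern, assuming training_args is a Seq2SeqTrainingArguments; raw_dataset, convert_fn, and overwrite_cache are hypothetical stand-ins, not names from this commit:

# Sketch only: how load_from_cache_file interacts with main_process_first
# on a multi-GPU node.
def preprocess(raw_dataset, training_args, overwrite_cache=False, convert_fn=lambda batch: batch):
    with training_args.main_process_first(desc="pre-process dataset"):
        # Local rank 0 runs first and may rebuild the cache when overwrite_cache
        # is set; the other local ranks only reach map() after rank 0 leaves the
        # block, and the second term forces them to load the cache it just wrote.
        return raw_dataset.map(
            convert_fn,
            batched=True,
            load_from_cache_file=(not overwrite_cache) or (training_args.local_process_index != 0),
            desc="Running tokenizer on dataset",
        )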

View File

@@ -13,8 +13,7 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu
 if TYPE_CHECKING:
-    from transformers import ProcessorMixin, Seq2SeqTrainingArguments
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments

     from ..hparams import DataArguments
     from .template import Template

View File

@@ -6,8 +6,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
     from ..template import Template

View File

@@ -6,8 +6,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
     from ..template import Template

View File

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, List
 if TYPE_CHECKING:
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer

     from ...hparams import DataArguments

View File

@@ -7,8 +7,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, gre
 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
     from ..template import Template

View File

@@ -6,8 +6,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
     from ..template import Template

View File

@@ -9,7 +9,7 @@ from ...extras.packages import is_rouge_available
 if TYPE_CHECKING:
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer


 if is_jieba_available():
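The remaining hunks are the same one-line cleanup: PreTrainedTokenizer is re-exported from the top-level transformers package, so the separate transformers.tokenization_utils import can be folded into the existing import line. All of these imports sit under TYPE_CHECKING, so they are only used for annotations; a minimal sketch of the resulting pattern (encode_example is illustrative, not from the commit):

from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
    # Evaluated only by static type checkers; nothing is imported at runtime.
    from transformers import PreTrainedTokenizer


def encode_example(text: str, tokenizer: "PreTrainedTokenizer") -> List[int]:
    # The annotation is a string, so the symbol need not exist at runtime.
    return tokenizer.encode(text, add_special_tokens=True)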