fix #4221
commit 6baafd4eb3 (parent 9419f96609)
@@ -10,6 +10,7 @@ from .data_utils import Role

 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
+    from transformers import Seq2SeqTrainingArguments

     from ..hparams import DataArguments
     from .parser import DatasetAttr
@@ -175,7 +176,10 @@ def convert_sharegpt(


 def align_dataset(
-    dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
+    dataset: Union["Dataset", "IterableDataset"],
+    dataset_attr: "DatasetAttr",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
     r"""
     Aligned dataset:
@@ -208,7 +212,7 @@ def align_dataset(
     if not data_args.streaming:
         kwargs = dict(
             num_proc=data_args.preprocessing_num_workers,
-            load_from_cache_file=(not data_args.overwrite_cache),
+            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
             desc="Converting format of dataset",
         )
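The load_from_cache_file change above is the heart of this fix: only the local main process may rebuild the `datasets` map cache, and every other rank is forced to read the file it writes. A condensed sketch of the pattern, assuming a hypothetical convert_fn and folding in the main_process_first() wrapper that get_dataset applies further down:

    # Sketch only: combines the cache gating from this diff with the
    # main_process_first() barrier used by get_dataset().
    def map_with_shared_cache(dataset, convert_fn, data_args, training_args):
        kwargs = {}
        if not data_args.streaming:
            kwargs = dict(
                num_proc=data_args.preprocessing_num_workers,
                # Rank 0 on each node may invalidate a stale cache when
                # overwrite_cache=True; every other rank must reuse the
                # cache file rank 0 writes, hence the index check.
                load_from_cache_file=(not data_args.overwrite_cache)
                or (training_args.local_process_index != 0),
                desc="Converting format of dataset",
            )
        # Non-main ranks enter the block only after the local main
        # process has finished, so the cache exists before they read it.
        with training_args.main_process_first(desc="pre-process dataset"):
            return dataset.map(convert_fn, batched=True, **kwargs)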
@@ -18,8 +18,7 @@ from .template import get_template_and_fix_tokenizer

 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
-    from transformers import ProcessorMixin, Seq2SeqTrainingArguments
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments

     from ..hparams import DataArguments, ModelArguments
     from .parser import DatasetAttr
@@ -32,6 +31,7 @@ def load_single_dataset(
     dataset_attr: "DatasetAttr",
     model_args: "ModelArguments",
     data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
     logger.info("Loading dataset {}...".format(dataset_attr))
     data_path, data_name, data_dir, data_files = None, None, None, None
@@ -123,7 +123,7 @@ def load_single_dataset(
         max_samples = min(data_args.max_samples, len(dataset))
         dataset = dataset.select(range(max_samples))

-    return align_dataset(dataset, dataset_attr, data_args)
+    return align_dataset(dataset, dataset_attr, data_args, training_args)


 def get_dataset(
@@ -157,7 +157,8 @@ def get_dataset(
             if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
                 raise ValueError("The dataset is not applicable in the current training stage.")

-            all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
+            all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args, training_args))

         dataset = merge_dataset(all_datasets, data_args, training_args)

     with training_args.main_process_first(desc="pre-process dataset"):
@@ -169,7 +170,7 @@ def get_dataset(
         if not data_args.streaming:
             kwargs = dict(
                 num_proc=data_args.preprocessing_num_workers,
-                load_from_cache_file=(not data_args.overwrite_cache),
+                load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
                 desc="Running tokenizer on dataset",
             )
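Net effect of the loader hunks: training_args now travels get_dataset -> load_single_dataset -> align_dataset, so any external caller of load_single_dataset has to pass a Seq2SeqTrainingArguments instance as the new fourth argument.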
@@ -13,8 +13,7 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu

 if TYPE_CHECKING:
-    from transformers import ProcessorMixin, Seq2SeqTrainingArguments
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments

     from ..hparams import DataArguments
     from .template import Template
@@ -6,8 +6,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values

 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
     from ..template import Template
@@ -6,8 +6,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values

 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
     from ..template import Template
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, List

 if TYPE_CHECKING:
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer

     from ...hparams import DataArguments
@@ -7,8 +7,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, gre

 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
     from ..template import Template
@@ -6,8 +6,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values

 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
     from ..template import Template
@@ -9,7 +9,7 @@ from ...extras.packages import is_rouge_available

 if TYPE_CHECKING:
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer


 if is_jieba_available():
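The remaining hunks are mechanical cleanups: each if TYPE_CHECKING: block now takes PreTrainedTokenizer from the top-level transformers namespace instead of transformers.tokenization_utils, folded into one sorted import. A minimal sketch of the resulting pattern (count_tokens is a hypothetical helper, not part of this commit):

    from typing import TYPE_CHECKING

    # Type-only imports: seen by type checkers, skipped at runtime.
    if TYPE_CHECKING:
        from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments


    def count_tokens(tokenizer: "PreTrainedTokenizer", text: str) -> int:
        # Quoted forward reference, matching the style these files use
        # for names that only exist under TYPE_CHECKING.
        return len(tokenizer.tokenize(text))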