fix #4221

parent 9419f96609
commit 6baafd4eb3
@@ -10,6 +10,7 @@ from .data_utils import Role

 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
+    from transformers import Seq2SeqTrainingArguments

     from ..hparams import DataArguments
     from .parser import DatasetAttr
@@ -175,7 +176,10 @@ def convert_sharegpt(


 def align_dataset(
-    dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
+    dataset: Union["Dataset", "IterableDataset"],
+    dataset_attr: "DatasetAttr",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
     r"""
     Aligned dataset:
@@ -208,7 +212,7 @@ def align_dataset(
     if not data_args.streaming:
         kwargs = dict(
             num_proc=data_args.preprocessing_num_workers,
-            load_from_cache_file=(not data_args.overwrite_cache),
+            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
             desc="Converting format of dataset",
         )

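The two load_from_cache_file edits are the substance of the fix: while training_args.main_process_first() holds the other ranks back, local rank 0 preprocesses the dataset and writes the cache, so every non-zero local rank should always read that cache even when --overwrite_cache is set. A minimal sketch of the new predicate, using a hypothetical helper name rather than the project's code:

def should_load_from_cache(overwrite_cache: bool, local_process_index: int) -> bool:
    # rank 0 honors --overwrite_cache; every other local rank always reads the cache
    return (not overwrite_cache) or (local_process_index != 0)


# rank 0 with --overwrite_cache rebuilds the cache; the remaining ranks never do
assert should_load_from_cache(overwrite_cache=True, local_process_index=0) is False
assert should_load_from_cache(overwrite_cache=True, local_process_index=1) is True
assert should_load_from_cache(overwrite_cache=False, local_process_index=2) is True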
@@ -18,8 +18,7 @@ from .template import get_template_and_fix_tokenizer

 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
-    from transformers import ProcessorMixin, Seq2SeqTrainingArguments
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments

     from ..hparams import DataArguments, ModelArguments
     from .parser import DatasetAttr
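This and the later import hunks make the same housekeeping change: the type-only PreTrainedTokenizer import moves from transformers.tokenization_utils onto the existing top-level "from transformers import ..." line under TYPE_CHECKING. A small sketch of that pattern, with an illustrative function that is not part of the repository:

from typing import TYPE_CHECKING

if TYPE_CHECKING:  # evaluated only by type checkers, never at runtime
    from transformers import PreTrainedTokenizer


def count_special_tokens(tokenizer: "PreTrainedTokenizer") -> int:
    # the quoted annotation keeps the module importable without transformers installed
    return len(tokenizer.all_special_tokens)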
@@ -32,6 +31,7 @@ def load_single_dataset(
     dataset_attr: "DatasetAttr",
     model_args: "ModelArguments",
     data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
     logger.info("Loading dataset {}...".format(dataset_attr))
     data_path, data_name, data_dir, data_files = None, None, None, None
@@ -123,7 +123,7 @@ def load_single_dataset(
         max_samples = min(data_args.max_samples, len(dataset))
         dataset = dataset.select(range(max_samples))

-    return align_dataset(dataset, dataset_attr, data_args)
+    return align_dataset(dataset, dataset_attr, data_args, training_args)


 def get_dataset(
@@ -157,7 +157,8 @@ def get_dataset(
         if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
             raise ValueError("The dataset is not applicable in the current training stage.")

-        all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
+        all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args, training_args))

     dataset = merge_dataset(all_datasets, data_args, training_args)

     with training_args.main_process_first(desc="pre-process dataset"):
@@ -169,7 +170,7 @@ def get_dataset(
         if not data_args.streaming:
             kwargs = dict(
                 num_proc=data_args.preprocessing_num_workers,
-                load_from_cache_file=(not data_args.overwrite_cache),
+                load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
                 desc="Running tokenizer on dataset",
             )

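Here the same cache predicate is applied to the tokenization pass, which already runs inside "with training_args.main_process_first(desc=\"pre-process dataset\"):". A hedged sketch of how the two pieces fit together; the dataset, worker count, and mapping function below are placeholders, not the project's own:

from datasets import Dataset
from transformers import Seq2SeqTrainingArguments


def tokenize_with_shared_cache(
    dataset: Dataset, training_args: Seq2SeqTrainingArguments, overwrite_cache: bool
) -> Dataset:
    kwargs = dict(
        num_proc=4,  # placeholder for data_args.preprocessing_num_workers
        # rank 0 may rebuild the cache; every other local rank reuses it
        load_from_cache_file=(not overwrite_cache) or (training_args.local_process_index != 0),
        desc="Running tokenizer on dataset",
    )
    # non-zero local ranks wait here until rank 0 has finished and written the cache
    with training_args.main_process_first(desc="pre-process dataset"):
        return dataset.map(lambda example: {"n_chars": len(example["text"])}, **kwargs)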
@@ -13,8 +13,7 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu


 if TYPE_CHECKING:
-    from transformers import ProcessorMixin, Seq2SeqTrainingArguments
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments

     from ..hparams import DataArguments
     from .template import Template
@@ -6,8 +6,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values


 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
     from ..template import Template
@@ -6,8 +6,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values


 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
     from ..template import Template
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, List


 if TYPE_CHECKING:
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer

     from ...hparams import DataArguments

@@ -7,8 +7,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, gre


 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
     from ..template import Template
@@ -6,8 +6,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values


 if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
     from ..template import Template
@@ -9,7 +9,7 @@ from ...extras.packages import is_rouge_available


 if TYPE_CHECKING:
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer


 if is_jieba_available():