from dataclasses import dataclass, field
from typing import Optional, List

from transformers import Seq2SeqTrainingArguments


# run_seq2seq parameters.
@dataclass
class TrainingArguments(Seq2SeqTrainingArguments):
    print_num_parameters: Optional[bool] = field(
        default=False, metadata={"help": "If set, prints the number of parameters of the model."}
    )
    do_test: Optional[bool] = field(default=False, metadata={"help": "If set, evaluates the test performance."})
    split_validation_test: Optional[bool] = field(
        default=False,
        metadata={
            "help": "If set, for datasets which do not have a test set, we use the validation set as their "
            "test set and make a validation set by either splitting the validation set in half (for "
            "datasets smaller than 10K samples) or using 1K examples from the training set as the "
            "validation set (for larger datasets)."
        },
    )
    compute_time: Optional[bool] = field(default=False, metadata={"help": "If set, measures the time."})
    compute_memory: Optional[bool] = field(default=False, metadata={"help": "If set, measures the memory."})
    is_seq2seq: Optional[bool] = field(default=True, metadata={"help": "Whether the pipeline is a seq2seq one."})


@dataclass
class DataTrainingArguments:
"""
|
|
Arguments pertaining to what data we are going to input our model for training and eval.
|
|
"""
    task_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    eval_dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the evaluation dataset to use (via the datasets library)."}
    )
    eval_dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the evaluation dataset to use (via the datasets library)."}
    )
    test_dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the test dataset to use (via the datasets library)."}
    )
    test_dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the test dataset to use (via the datasets library)."}
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets."}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_source_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    val_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
            "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
            "during ``evaluate`` and ``predict``."
        },
    )
    test_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total sequence length for test target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
            "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
            "during ``evaluate`` and ``predict``."
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Whether to pad all samples to model maximum sentence length. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
            "efficient on GPU but very bad for TPU."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_val_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
            "value if set."
        },
    )
    max_test_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
            "value if set."
        },
    )
    num_beams: Optional[int] = field(default=None, metadata={"help": "Number of beams to use for evaluation."})
    ignore_pad_token_for_loss: bool = field(
        default=True,
        metadata={
            "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
        },
    )
    task_adapters: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Defines a dictionary mapping task adapters to tasks."},
    )
    task_embeddings: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Defines a dictionary mapping tasks to task embeddings."},
    )
    data_seed: Optional[int] = field(default=42, metadata={"help": "Seed used to shuffle the data."})
    model_parallel: Optional[bool] = field(default=False, metadata={"help": "Whether to apply model parallelism."})

    def __post_init__(self):
        # A dataset name is required; this dataclass exposes no local-file arguments.
        if self.task_name is None:
            raise ValueError("Need a dataset name: please set `task_name`.")
        # Validation/test target lengths default to the training target length.
        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length
        if self.test_max_target_length is None:
            self.test_max_target_length = self.max_target_length
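

# A minimal usage sketch (an assumption, not part of the original module): these
# dataclasses are meant to be parsed from the command line with
# transformers.HfArgumentParser, e.g. by a run_seq2seq entry point.
if __name__ == "__main__":
    from transformers import HfArgumentParser

    parser = HfArgumentParser((DataTrainingArguments, TrainingArguments))
    # Example invocation (flag values are illustrative; `output_dir` is required
    # by the base Seq2SeqTrainingArguments):
    #   python this_module.py --task_name rte --output_dir /tmp/out --max_source_length 128
    data_args, training_args = parser.parse_args_into_dataclasses()
    print(data_args.task_name, training_args.output_dir)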