# OpenDeltaMirror/examples/examples_prompt/trainers/trainer_args.py

from dataclasses import dataclass, field
from typing import Optional, List
from transformers import Seq2SeqTrainingArguments


# run_seq2seq parameters.
@dataclass
class TrainingArguments(Seq2SeqTrainingArguments):
    print_num_parameters: Optional[bool] = field(
        default=False, metadata={"help": "If set, print the number of parameters of the model."}
    )
    do_test: Optional[bool] = field(
        default=False, metadata={"help": "If set, evaluates the test performance."}
    )
    split_validation_test: Optional[bool] = field(
        default=False,
        metadata={
            "help": "If set, for datasets which do not have a test set, the validation set is used as the "
            "test set, and a new validation set is built either by splitting the validation set in half "
            "(for datasets with fewer than 10K samples) or by holding out 1K training examples as the "
            "validation set (for larger datasets)."
        },
    )
    compute_time: Optional[bool] = field(default=False, metadata={"help": "If set, measures the computation time."})
    compute_memory: Optional[bool] = field(default=False, metadata={"help": "If set, measures the memory usage."})
    is_seq2seq: Optional[bool] = field(default=True, metadata={"help": "Whether the pipeline is a sequence-to-sequence one."})
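
# Note: when parsed with transformers.HfArgumentParser (see the sketch at the end of
# this file), each field above becomes a CLI flag, e.g. (hypothetical invocation,
# assuming a run_seq2seq.py entry script):
#   python run_seq2seq.py --output_dir outputs --do_test true --split_validation_test true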


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input to our model for training and eval.
    """

    task_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    eval_dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the evaluation dataset to use (via the datasets library)."}
    )
    eval_dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the evaluation dataset to use (via the datasets library)."}
    )
    test_dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the test dataset to use (via the datasets library)."}
    )
    test_dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the test dataset to use (via the datasets library)."}
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets."}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for preprocessing."},
    )
    max_source_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    val_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
            "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
            "during ``evaluate`` and ``predict``."
        },
    )
    test_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total sequence length for test target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
            "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
            "during ``evaluate`` and ``predict``."
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Whether to pad all samples to the model's maximum sentence length. "
            "If False, samples are padded dynamically when batching to the maximum length in the batch. This is "
            "more efficient on GPU but very bad for TPU."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_val_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
            "value if set."
        },
    )
    max_test_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
            "value if set."
        },
    )
    num_beams: Optional[int] = field(default=None, metadata={"help": "Number of beams to use for evaluation."})
    ignore_pad_token_for_loss: bool = field(
        default=True,
        metadata={
            "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
        },
    )
    task_adapters: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Defines a mapping from task adapters to tasks."}
    )
    task_embeddings: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Defines a mapping from tasks to task embeddings."}
    )
    data_seed: Optional[int] = field(default=42, metadata={"help": "Seed used to shuffle the data."})
    model_parallel: Optional[bool] = field(default=False, metadata={"help": "Whether to apply model parallelism."})

    def __post_init__(self):
        if self.task_name is None:
            raise ValueError("Need a dataset name: please set `task_name`.")
        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length
        if self.test_max_target_length is None:
            self.test_max_target_length = self.max_target_length
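

# A minimal usage sketch, an assumption rather than code from the upstream repo:
# the example scripts presumably consume these dataclasses via
# transformers.HfArgumentParser, roughly as follows.
if __name__ == "__main__":
    from transformers import HfArgumentParser

    # Parse CLI flags into the two dataclasses defined above; unknown flags
    # raise an error by default.
    parser = HfArgumentParser((TrainingArguments, DataTrainingArguments))
    training_args, data_args = parser.parse_args_into_dataclasses()
    print(training_args.output_dir, data_args.task_name)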