Update data_args.py
This commit is contained in:
parent
df52fb05b1
commit
cba673f491
|
@ -31,12 +31,11 @@ class DataArguments:
|
|||
)
|
||||
dataset: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
|
||||
metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
|
||||
)
|
||||
eval_dataset: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of provided dataset(s) to use for eval during training. "
|
||||
"Use commas to separate multiple datasets."},
|
||||
metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."},
|
||||
)
|
||||
dataset_dir: str = field(
|
||||
default="data",
|
||||
|
@ -110,12 +109,33 @@ class DataArguments:
|
|||
default=None,
|
||||
metadata={"help": "Path to save or load the tokenized datasets."},
|
||||
)
|
||||
eval_tokenized_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to save or load the tokenized eval datasets."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
def split_arg(arg):
|
||||
if isinstance(arg, str):
|
||||
return [item.strip() for item in arg.split(",")]
|
||||
return arg
|
||||
|
||||
self.dataset = split_arg(self.dataset)
|
||||
self.eval_dataset = split_arg(self.eval_dataset)
|
||||
|
||||
if self.dataset is None and self.val_size > 1e-6:
|
||||
raise ValueError("Cannot specify `val_size` if `dataset` is None.")
|
||||
|
||||
if self.eval_dataset is not None and self.val_size > 1e-6:
|
||||
raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
|
||||
|
||||
if self.interleave_probs is not None:
|
||||
if self.mix_strategy == "concat":
|
||||
raise ValueError("`interleave_probs` is only valid for interleaved mixing.")
|
||||
|
||||
self.interleave_probs = list(map(float, split_arg(self.interleave_probs)))
|
||||
if self.dataset is not None and len(self.dataset) != len(self.interleave_probs):
|
||||
raise ValueError("The length of dataset and interleave probs should be identical.")
|
||||
|
||||
if self.eval_dataset is not None and len(self.eval_dataset) != len(self.interleave_probs):
|
||||
raise ValueError("The length of eval dataset and interleave probs should be identical.")
|
||||
|
||||
if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
|
||||
raise ValueError("Streaming mode should have an integer val size.")
|
||||
|
||||
|
|
Loading…
Reference in New Issue