Update data_args.py

This commit is contained in:
hoshi-hiyouga 2024-07-15 00:56:03 +08:00 committed by GitHub
parent df52fb05b1
commit cba673f491
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 27 additions and 7 deletions

View File

@ -31,12 +31,11 @@ class DataArguments:
)
dataset: Optional[str] = field(
default=None,
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
)
eval_dataset: Optional[str] = field(
default=None,
metadata={"help": "The name of provided dataset(s) to use for eval during training. "
"Use commas to separate multiple datasets."},
metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."},
)
dataset_dir: str = field(
default="data",
@ -110,12 +109,33 @@ class DataArguments:
default=None,
metadata={"help": "Path to save or load the tokenized datasets."},
)
eval_tokenized_path: Optional[str] = field(
default=None,
metadata={"help": "Path to save or load the tokenized eval datasets."},
)
def __post_init__(self):
def split_arg(arg):
if isinstance(arg, str):
return [item.strip() for item in arg.split(",")]
return arg
self.dataset = split_arg(self.dataset)
self.eval_dataset = split_arg(self.eval_dataset)
if self.dataset is None and self.val_size > 1e-6:
raise ValueError("Cannot specify `val_size` if `dataset` is None.")
if self.eval_dataset is not None and self.val_size > 1e-6:
raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
if self.interleave_probs is not None:
if self.mix_strategy == "concat":
raise ValueError("`interleave_probs` is only valid for interleaved mixing.")
self.interleave_probs = list(map(float, split_arg(self.interleave_probs)))
if self.dataset is not None and len(self.dataset) != len(self.interleave_probs):
raise ValueError("The length of dataset and interleave probs should be identical.")
if self.eval_dataset is not None and len(self.eval_dataset) != len(self.interleave_probs):
raise ValueError("The length of eval dataset and interleave probs should be identical.")
if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
raise ValueError("Streaming mode should have an integer val size.")