diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 7f7e62cd..f483099d 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -31,12 +31,11 @@ class DataArguments: ) dataset: Optional[str] = field( default=None, - metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}, + metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."}, ) eval_dataset: Optional[str] = field( default=None, - metadata={"help": "The name of provided dataset(s) to use for eval during training. " - "Use commas to separate multiple datasets."}, + metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."}, ) dataset_dir: str = field( default="data", @@ -110,12 +109,33 @@ class DataArguments: default=None, metadata={"help": "Path to save or load the tokenized datasets."}, ) - eval_tokenized_path: Optional[str] = field( - default=None, - metadata={"help": "Path to save or load the tokenized eval datasets."}, - ) def __post_init__(self): + def split_arg(arg): + if isinstance(arg, str): + return [item.strip() for item in arg.split(",")] + return arg + + self.dataset = split_arg(self.dataset) + self.eval_dataset = split_arg(self.eval_dataset) + + if self.dataset is None and self.val_size > 1e-6: + raise ValueError("Cannot specify `val_size` if `dataset` is None.") + + if self.eval_dataset is not None and self.val_size > 1e-6: + raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.") + + if self.interleave_probs is not None: + if self.mix_strategy == "concat": + raise ValueError("`interleave_probs` is only valid for interleaved mixing.") + + self.interleave_probs = list(map(float, split_arg(self.interleave_probs))) + if self.dataset is not None and len(self.dataset) != len(self.interleave_probs): + raise ValueError("The length of dataset and interleave probs should be identical.") + + if self.eval_dataset is not None and len(self.eval_dataset) != len(self.interleave_probs): + raise ValueError("The length of eval dataset and interleave probs should be identical.") + if self.streaming and self.val_size > 1e-6 and self.val_size < 1: raise ValueError("Streaming mode should have an integer val size.")