From 3d39d74003c4ca36f9c9b77f622d366383b0af7e Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Sun, 14 Jul 2024 23:04:34 +0800
Subject: [PATCH] Update parser.py

---
 src/llamafactory/data/parser.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py
index c810ec8b..c443b9d9 100644
--- a/src/llamafactory/data/parser.py
+++ b/src/llamafactory/data/parser.py
@@ -38,9 +38,9 @@ class DatasetAttr:
     ranking: bool = False
     # extra configs
     subset: Optional[str] = None
+    split: str = "train"
     folder: Optional[str] = None
     num_samples: Optional[int] = None
-    split: Optional[str] = "train"
     # common columns
     system: Optional[str] = None
     tools: Optional[str] = None
@@ -72,7 +72,7 @@ class DatasetAttr:
         setattr(self, key, obj.get(key, default))
 
 
-def get_dataset_list(data_args: "DataArguments", dataset: "str" = None) -> List["DatasetAttr"]:
+def get_dataset_list(data_args: "DataArguments", dataset: Optional[str]) -> List["DatasetAttr"]:
     if dataset is not None:
         dataset_names = [ds.strip() for ds in dataset.split(",")]
     else:
@@ -121,10 +121,9 @@ def get_dataset_list(data_args: "DataArguments", dataset: "str" = None) -> List[
         dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
         dataset_attr.set_attr("ranking", dataset_info[name], default=False)
         dataset_attr.set_attr("subset", dataset_info[name])
+        dataset_attr.set_attr("split", dataset_info[name], default="train")
         dataset_attr.set_attr("folder", dataset_info[name])
         dataset_attr.set_attr("num_samples", dataset_info[name])
-        if "split" in dataset_info[name]:
-            dataset_attr.set_attr("split", dataset_info[name])
 
         if "columns" in dataset_info[name]:
             column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]