diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py index 99b71cf0..ec97bfc1 100644 --- a/src/llamafactory/data/parser.py +++ b/src/llamafactory/data/parser.py @@ -20,11 +20,12 @@ class DatasetAttr: """ basic configs """ load_from: Literal["hf_hub", "ms_hub", "script", "file"] dataset_name: str + formatting: Literal["alpaca", "sharegpt"] = "alpaca" + ranking: bool = False """ extra configs """ subset: Optional[str] = None folder: Optional[str] = None - ranking: bool = False - formatting: Literal["alpaca", "sharegpt"] = "alpaca" + num_samples: Optional[int] = None """ common columns """ system: Optional[str] = None tools: Optional[str] = None @@ -48,7 +49,6 @@ class DatasetAttr: observation_tag: Optional[str] = "observation" function_tag: Optional[str] = "function_call" system_tag: Optional[str] = "system" - sample_num: Optional[int] = None def __repr__(self) -> str: return self.dataset_name @@ -103,12 +103,12 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]: else: dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"]) + dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca") + dataset_attr.set_attr("ranking", dataset_info[name], default=False) dataset_attr.set_attr("subset", dataset_info[name]) dataset_attr.set_attr("folder", dataset_info[name]) - dataset_attr.set_attr("ranking", dataset_info[name], default=False) - dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca") - dataset_attr.set_attr("sample_num", dataset_info[name]) - + dataset_attr.set_attr("num_samples", dataset_info[name]) + if "columns" in dataset_info[name]: column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"] if dataset_attr.formatting == "alpaca":