diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py
index 374d03c6..db7702cd 100644
--- a/src/llmtuner/hparams/data_args.py
+++ b/src/llmtuner/hparams/data_args.py
@@ -11,6 +11,7 @@ class DatasetAttr:
     dataset_name: Optional[str] = None
     dataset_sha1: Optional[str] = None
     system_prompt: Optional[str] = None
+    stage: Optional[str] = None
 
     def __repr__(self) -> str:
         return self.dataset_name
@@ -113,14 +114,21 @@ class DataArguments:
                 raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
 
             if "hf_hub_url" in dataset_info[name]:
-                dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
+                dataset_attr = DatasetAttr(
+                    "hf_hub",
+                    dataset_name=dataset_info[name]["hf_hub_url"],
+                    stage=dataset_info[name].get("stage", None))
             elif "script_url" in dataset_info[name]:
-                dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
+                dataset_attr = DatasetAttr(
+                    "script",
+                    dataset_name=dataset_info[name]["script_url"],
+                    stage=dataset_info[name].get("stage", None))
             else:
                 dataset_attr = DatasetAttr(
                     "file",
                     dataset_name=dataset_info[name]["file_name"],
-                    dataset_sha1=dataset_info[name].get("file_sha1", None)
+                    dataset_sha1=dataset_info[name].get("file_sha1", None),
+                    stage=dataset_info[name].get("stage", None)
                 )
 
             if "columns" in dataset_info[name]: