diff --git a/data/README.md b/data/README.md
index 9010fb64..15d108da 100644
--- a/data/README.md
+++ b/data/README.md
@@ -4,9 +4,10 @@ If you are using a custom dataset, please provide your dataset definition in the
 "dataset_name": {
   "hf_hub_url": "the name of the dataset repository on the Hugging Face hub. (if specified, ignore below 3 arguments)",
   "script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)",
-  "file_name": "the name of the dataset file in the this directory. (required if above are not specified)",
+  "file_name": "the name of the dataset file in this directory. (required if above are not specified)",
   "file_sha1": "the SHA-1 hash value of the dataset file. (optional, does not affect training)",
   "subset": "the name of the subset. (optional, default: None)",
+  "folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
   "ranking": "whether the dataset is a preference dataset or not. (default: false)",
   "formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
   "columns": {
diff --git a/data/README_zh.md b/data/README_zh.md
index 740e27db..a6790f70 100644
--- a/data/README_zh.md
+++ b/data/README_zh.md
@@ -2,11 +2,12 @@
 
 ```json
 "数据集名称": {
-  "hf_hub_url": "Hugging Face 上的项目地址(若指定,则忽略下列三个参数)",
+  "hf_hub_url": "Hugging Face 的仓库地址(若指定,则忽略下列三个参数)",
   "script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)",
   "file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
-  "file_sha1": "数据集文件的SHA-1哈希值(可选,留空不影响训练)",
+  "file_sha1": "数据集文件的 SHA-1 哈希值(可选,留空不影响训练)",
   "subset": "数据集子集的名称(可选,默认:None)",
+  "folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)",
   "ranking": "是否为偏好数据集(可选,默认:False)",
   "formatting": "数据集格式(可选,默认:alpaca,可以为 alpaca 或 sharegpt)",
   "columns": {
diff --git a/data/dataset_info.json b/data/dataset_info.json
index 2b3f4eb7..1896d94d 100644
--- a/data/dataset_info.json
+++ b/data/dataset_info.json
@@ -274,10 +274,11 @@
       "prompt": "content"
     }
   },
-  "starcoder": {
+  "starcoder_python": {
     "hf_hub_url": "bigcode/starcoderdata",
     "columns": {
       "prompt": "content"
-    }
+    },
+    "folder": "python"
   }
 }
diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py
index 8e9053ca..d5a7a588 100644
--- a/src/llmtuner/data/loader.py
+++ b/src/llmtuner/data/loader.py
@@ -24,27 +24,27 @@ def get_dataset(
     for dataset_attr in data_args.dataset_list:
         logger.info("Loading dataset {}...".format(dataset_attr))
 
+        data_path, data_name, data_dir, data_files = None, None, None, None
         if dataset_attr.load_from == "hf_hub":
             data_path = dataset_attr.dataset_name
             data_name = dataset_attr.subset
-            data_files = None
+            data_dir = dataset_attr.folder
         elif dataset_attr.load_from == "script":
             data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
             data_name = dataset_attr.subset
-            data_files = None
         elif dataset_attr.load_from == "file":
-            data_path, data_name = None, None
-            data_files: List[str] = []
-            if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is directory
-                for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
-                    data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
+            data_files = []
+            local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+            if os.path.isdir(local_path): # is directory
+                for file_name in os.listdir(local_path):
+                    data_files.append(os.path.join(local_path, file_name))
                     if data_path is None:
                         data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
                     else:
                         assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
-            elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is file
-                data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
-                data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
+            elif os.path.isfile(local_path): # is file
+                data_files.append(local_path)
+                data_path = EXT2TYPE.get(local_path.split(".")[-1], None)
             else:
                 raise ValueError("File not found.")
 
@@ -56,6 +56,7 @@ def get_dataset(
         dataset = load_dataset(
             path=data_path,
             name=data_name,
+            data_dir=data_dir,
             data_files=data_files,
             split=data_args.split,
             cache_dir=model_args.cache_dir,
diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py
index cea89198..da9be11b 100644
--- a/src/llmtuner/hparams/data_args.py
+++ b/src/llmtuner/hparams/data_args.py
@@ -15,6 +15,7 @@ class DatasetAttr:
     dataset_sha1: Optional[str] = None
     system_prompt: Optional[str] = None
     subset: Optional[str] = None
+    folder: Optional[str] = None
     ranking: Optional[bool] = False
     formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
 
@@ -173,6 +174,7 @@ class DataArguments:
             dataset_attr.content = dataset_info[name]["columns"].get("content", None)
 
             dataset_attr.subset = dataset_info[name].get("subset", None)
+            dataset_attr.folder = dataset_info[name].get("folder", None)
            dataset_attr.ranking = dataset_info[name].get("ranking", False)
             dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
             dataset_attr.system_prompt = prompt_list[i]
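
A note on what the new `folder` field does at load time: `get_dataset` now forwards it to `datasets.load_dataset` as `data_dir`, so a Hub dataset entry can point at a single subdirectory of a repository. Below is a minimal sketch of the call that the renamed `starcoder_python` entry resolves to, assuming the `datasets` library is installed; the `split` and `streaming` values are illustrative and not taken from this patch.

```python
# Minimal sketch, not part of the patch: the load_dataset() call that the
# "starcoder_python" entry in dataset_info.json resolves to after this change.
from datasets import load_dataset

dataset = load_dataset(
    path="bigcode/starcoderdata",  # from "hf_hub_url"
    name=None,                     # "subset" is not set for this entry
    data_dir="python",             # "folder" is passed through as data_dir
    data_files=None,               # only populated for local-file datasets
    split="train",                 # illustrative; the loader uses data_args.split
    streaming=True,                # illustrative; avoids downloading the full corpus
)
print(next(iter(dataset)))  # first record from the python/ folder only
```

Restricting the load to the `python` folder appears to be the point of renaming the entry from `starcoder` to `starcoder_python`: `bigcode/starcoderdata` keeps each language in its own top-level directory, and without `data_dir` the entire multi-language corpus would be fetched.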