hiyouga 2023-12-09 20:53:18 +08:00
parent d42c0b1d34
commit 28d5de7e78
5 changed files with 21 additions and 15 deletions

View File

@@ -4,9 +4,10 @@ If you are using a custom dataset, please provide your dataset definition in the
 "dataset_name": {
   "hf_hub_url": "the name of the dataset repository on the Hugging Face hub. (if specified, ignore below 3 arguments)",
   "script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)",
-  "file_name": "the name of the dataset file in the this directory. (required if above are not specified)",
+  "file_name": "the name of the dataset file in this directory. (required if above are not specified)",
   "file_sha1": "the SHA-1 hash value of the dataset file. (optional, does not affect training)",
   "subset": "the name of the subset. (optional, default: None)",
+  "folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
   "ranking": "whether the dataset is a preference dataset or not. (default: false)",
   "formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
   "columns": {

View File

@@ -2,11 +2,12 @@
 ```json
 "dataset_name": {
-  "hf_hub_url": "the address of the project on the Hugging Face hub. (if specified, ignore the 3 arguments below)",
+  "hf_hub_url": "the address of the repository on the Hugging Face hub. (if specified, ignore the 3 arguments below)",
   "script_url": "the name of the local folder containing the dataset loading script. (if specified, ignore the 2 arguments below)",
   "file_name": "the name of the dataset file in this directory. (required if the above arguments are not specified)",
   "file_sha1": "the SHA-1 hash value of the dataset file. (optional, leaving it blank does not affect training)",
   "subset": "the name of the dataset subset. (optional, default: None)",
+  "folder": "the name of the folder in the Hugging Face repository. (optional, default: None)",
   "ranking": "whether the dataset is a preference dataset. (optional, default: False)",
   "formatting": "the dataset format. (optional, default: alpaca, can be alpaca or sharegpt)",
   "columns": {

View File

@@ -274,10 +274,11 @@
       "prompt": "content"
     }
   },
-  "starcoder": {
+  "starcoder_python": {
     "hf_hub_url": "bigcode/starcoderdata",
     "columns": {
       "prompt": "content"
-    }
+    },
+    "folder": "python"
   }
 }
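With the loader change below, this entry should resolve to roughly the following direct `datasets` call; the `split` value is an assumption here, since the real one comes from `data_args.split`:

```python
from datasets import load_dataset

# Approximate equivalent of the starcoder_python entry above:
# hf_hub_url becomes `path`, and the new "folder" field becomes `data_dir`.
dataset = load_dataset(
    "bigcode/starcoderdata",
    data_dir="python",  # load only the python/ sub-folder of the repo
    split="train",      # assumed; actually taken from data_args.split
)
```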

View File

@@ -24,27 +24,27 @@ def get_dataset(
     for dataset_attr in data_args.dataset_list:
         logger.info("Loading dataset {}...".format(dataset_attr))
+        data_path, data_name, data_dir, data_files = None, None, None, None
         if dataset_attr.load_from == "hf_hub":
             data_path = dataset_attr.dataset_name
             data_name = dataset_attr.subset
-            data_files = None
+            data_dir = dataset_attr.folder
         elif dataset_attr.load_from == "script":
             data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
             data_name = dataset_attr.subset
-            data_files = None
         elif dataset_attr.load_from == "file":
-            data_path, data_name = None, None
-            data_files: List[str] = []
-            if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):  # is directory
-                for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
-                    data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
+            data_files = []
+            local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+            if os.path.isdir(local_path):  # is directory
+                for file_name in os.listdir(local_path):
+                    data_files.append(os.path.join(local_path, file_name))
                     if data_path is None:
                         data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
                     else:
                         assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
-            elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):  # is file
-                data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
-                data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
+            elif os.path.isfile(local_path):  # is file
+                data_files.append(local_path)
+                data_path = EXT2TYPE.get(local_path.split(".")[-1], None)
             else:
                 raise ValueError("File not found.")
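For the local-file branch, the builder type is inferred from file extensions via `EXT2TYPE`, and a directory must contain files of a single type. `EXT2TYPE` itself is not shown in this diff, so the self-contained sketch below assumes a plausible mapping:

```python
import os
from typing import Dict, List, Optional, Tuple

# Assumed contents of EXT2TYPE (not shown in this diff): file extension
# -> builder name accepted as the `path` argument of load_dataset.
EXT2TYPE: Dict[str, str] = {"csv": "csv", "json": "json", "jsonl": "json", "txt": "text"}

def resolve_local_files(local_path: str) -> Tuple[Optional[str], List[str]]:
    """Mirrors the 'file' branch: collect data files and infer one builder type."""
    data_path: Optional[str] = None
    data_files: List[str] = []
    if os.path.isdir(local_path):  # directory: every file must share one type
        for file_name in os.listdir(local_path):
            data_files.append(os.path.join(local_path, file_name))
            ext_type = EXT2TYPE.get(file_name.split(".")[-1], None)
            if data_path is None:
                data_path = ext_type
            else:
                assert data_path == ext_type, "file types are not identical."
    elif os.path.isfile(local_path):  # single file
        data_files.append(local_path)
        data_path = EXT2TYPE.get(local_path.split(".")[-1], None)
    else:
        raise ValueError("File not found.")
    return data_path, data_files

# e.g. resolve_local_files("data/my_dataset") -> ("json", ["data/my_dataset/a.json", ...])
```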
@@ -56,6 +56,7 @@ def get_dataset(
         dataset = load_dataset(
             path=data_path,
             name=data_name,
+            data_dir=data_dir,
             data_files=data_files,
             split=data_args.split,
             cache_dir=model_args.cache_dir,
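Since `data_dir` defaults to `None` in `datasets.load_dataset`, threading it through unconditionally is a no-op for every dataset whose definition omits `folder`.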

View File

@@ -15,6 +15,7 @@ class DatasetAttr:
     dataset_sha1: Optional[str] = None
     system_prompt: Optional[str] = None
     subset: Optional[str] = None
+    folder: Optional[str] = None
     ranking: Optional[bool] = False
     formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
@@ -173,6 +174,7 @@ class DataArguments:
                 dataset_attr.content = dataset_info[name]["columns"].get("content", None)
             dataset_attr.subset = dataset_info[name].get("subset", None)
+            dataset_attr.folder = dataset_info[name].get("folder", None)
             dataset_attr.ranking = dataset_info[name].get("ranking", False)
             dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
             dataset_attr.system_prompt = prompt_list[i]
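End to end, the new plumbing is: the `folder` key in `dataset_info.json` populates `DatasetAttr.folder`, which `get_dataset` passes to `load_dataset` as `data_dir`. Below is a minimal sketch of that flow, using a simplified stand-in for `DatasetAttr` (the real class carries many more fields):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class MiniDatasetAttr:
    """Simplified stand-in for DatasetAttr with only the fields used here."""
    load_from: str
    dataset_name: str
    subset: Optional[str] = None
    folder: Optional[str] = None

# One parsed dataset_info.json entry (mirrors starcoder_python above).
info = {"hf_hub_url": "bigcode/starcoderdata", "folder": "python"}

attr = MiniDatasetAttr(load_from="hf_hub", dataset_name=info["hf_hub_url"])
attr.folder = info.get("folder", None)  # same .get() pattern as above

# In get_dataset this ends up as:
#   load_dataset(path=attr.dataset_name, data_dir=attr.folder, ...)
assert attr.folder == "python"
```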