commit 28d5de7e78
parent d42c0b1d34
Author: hiyouga
Date: 2023-12-09 20:53:18 +08:00

5 changed files with 21 additions and 15 deletions

View File

@@ -4,9 +4,10 @@ If you are using a custom dataset, please provide your dataset definition in the
 "dataset_name": {
   "hf_hub_url": "the name of the dataset repository on the Hugging Face hub. (if specified, ignore below 3 arguments)",
   "script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)",
-  "file_name": "the name of the dataset file in the this directory. (required if above are not specified)",
+  "file_name": "the name of the dataset file in this directory. (required if above are not specified)",
   "file_sha1": "the SHA-1 hash value of the dataset file. (optional, does not affect training)",
   "subset": "the name of the subset. (optional, default: None)",
+  "folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
   "ranking": "whether the dataset is a preference dataset or not. (default: false)",
   "formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
   "columns": {

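To make the schema concrete, here is a sketch of a complete entry that uses the new `folder` field, written as Python that emits the JSON. The dataset name `my_dataset`, the repository `org/repo`, and the column names are hypothetical placeholders, not part of this commit:

```python
import json

# Hypothetical dataset_info.json entry: load only the "en" folder of a
# Hugging Face hub repository. All names below are illustrative placeholders.
entry = {
    "my_dataset": {
        "hf_hub_url": "org/repo",  # hub repository (hypothetical)
        "folder": "en",            # new in this commit: subfolder of the repo to load
        "ranking": False,          # plain SFT data, not preference pairs
        "formatting": "alpaca",
        "columns": {
            "prompt": "instruction",
            "response": "output"
        }
    }
}

print(json.dumps(entry, indent=2))  # paste the output into data/dataset_info.json
```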
View File

@@ -2,11 +2,12 @@
 ```json
 "dataset_name": {
-  "hf_hub_url": "the project address on Hugging Face (if specified, ignore the following 3 arguments)",
+  "hf_hub_url": "the repository address on Hugging Face (if specified, ignore the following 3 arguments)",
   "script_url": "the name of the local folder containing the data loading script (if specified, ignore the following 2 arguments)",
   "file_name": "the name of the dataset file in this directory (required if the above arguments are not specified)",
   "file_sha1": "the SHA-1 hash of the dataset file (optional, leaving it empty does not affect training)",
   "subset": "the name of the dataset subset (optional, default: None)",
+  "folder": "the name of the folder in the Hugging Face repository (optional, default: None)",
   "ranking": "whether the dataset is a preference dataset (optional, default: False)",
   "formatting": "the dataset format (optional, default: alpaca, can be alpaca or sharegpt)",
   "columns": {

View File

@@ -274,10 +274,11 @@
       "prompt": "content"
     }
   },
-  "starcoder": {
+  "starcoder_python": {
     "hf_hub_url": "bigcode/starcoderdata",
     "columns": {
       "prompt": "content"
-    }
+    },
+    "folder": "python"
   }
 }

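As the loader change below shows, the new `folder` key is forwarded to `datasets.load_dataset` as its `data_dir` argument, so the `starcoder_python` entry above behaves roughly like this direct call (a sketch; split and cache arguments simplified):

```python
from datasets import load_dataset

# Roughly what the "starcoder_python" entry resolves to: only files under the
# "python" folder of bigcode/starcoderdata are pulled, instead of the whole repo.
dataset = load_dataset(
    path="bigcode/starcoderdata",
    data_dir="python",  # supplied by the new "folder" field
    split="train",
)
```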
View File

@@ -24,27 +24,27 @@ def get_dataset(
     for dataset_attr in data_args.dataset_list:
         logger.info("Loading dataset {}...".format(dataset_attr))
 
+        data_path, data_name, data_dir, data_files = None, None, None, None
         if dataset_attr.load_from == "hf_hub":
             data_path = dataset_attr.dataset_name
             data_name = dataset_attr.subset
-            data_files = None
+            data_dir = dataset_attr.folder
         elif dataset_attr.load_from == "script":
             data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
             data_name = dataset_attr.subset
-            data_files = None
         elif dataset_attr.load_from == "file":
-            data_path, data_name = None, None
-            data_files: List[str] = []
-            if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is directory
-                for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
-                    data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
+            data_files = []
+            local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+            if os.path.isdir(local_path): # is directory
+                for file_name in os.listdir(local_path):
+                    data_files.append(os.path.join(local_path, file_name))
                     if data_path is None:
                         data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
                     else:
                         assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
-            elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is file
-                data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
-                data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
+            elif os.path.isfile(local_path): # is file
+                data_files.append(local_path)
+                data_path = EXT2TYPE.get(local_path.split(".")[-1], None)
             else:
                 raise ValueError("File not found.")
@@ -56,6 +56,7 @@ def get_dataset(
         dataset = load_dataset(
             path=data_path,
             name=data_name,
+            data_dir=data_dir,
             data_files=data_files,
             split=data_args.split,
             cache_dir=model_args.cache_dir,

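For readers skimming the refactor, here is the same local-file branch as a standalone sketch. It assumes `EXT2TYPE` maps file extensions to `load_dataset` builder names (csv, json, text), which is how the constant is used above; the exact mapping is not shown in this diff:

```python
import os
from typing import List, Optional, Tuple

# Assumed extension-to-builder mapping; the real EXT2TYPE constant is imported
# by the loader but its contents are not shown in this diff.
EXT2TYPE = {"csv": "csv", "json": "json", "jsonl": "json", "txt": "text"}

def resolve_local_files(dataset_dir: str, dataset_name: str) -> Tuple[Optional[str], List[str]]:
    """Mirror the 'file' branch above: collect files and infer the builder type."""
    data_path: Optional[str] = None
    data_files: List[str] = []
    local_path = os.path.join(dataset_dir, dataset_name)
    if os.path.isdir(local_path):  # a directory: take every file inside it
        for file_name in os.listdir(local_path):
            data_files.append(os.path.join(local_path, file_name))
            ext_type = EXT2TYPE.get(file_name.split(".")[-1], None)
            if data_path is None:
                data_path = ext_type  # first file fixes the expected type
            else:
                assert data_path == ext_type, "file types are not identical."
    elif os.path.isfile(local_path):  # a single file
        data_files.append(local_path)
        data_path = EXT2TYPE.get(local_path.split(".")[-1], None)
    else:
        raise ValueError("File not found.")
    return data_path, data_files
```

The refactor itself is behavior-preserving: binding the repeated `os.path.join(...)` expression to `local_path` once removes triplicated path construction, and hoisting the `None` initializations above the branches lets the `hf_hub` branch set only `data_dir` without touching `data_files`.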
View File

@@ -15,6 +15,7 @@ class DatasetAttr:
     dataset_sha1: Optional[str] = None
     system_prompt: Optional[str] = None
     subset: Optional[str] = None
+    folder: Optional[str] = None
     ranking: Optional[bool] = False
     formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
@@ -173,6 +174,7 @@ class DataArguments:
                 dataset_attr.content = dataset_info[name]["columns"].get("content", None)
 
             dataset_attr.subset = dataset_info[name].get("subset", None)
+            dataset_attr.folder = dataset_info[name].get("folder", None)
             dataset_attr.ranking = dataset_info[name].get("ranking", False)
             dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
             dataset_attr.system_prompt = prompt_list[i]
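
Putting the two hparams changes together: a trimmed sketch of how an entry flows into `DatasetAttr` with the new field. The dataclass here mirrors only the fields relevant to this commit (the real class has more), and the parsing lines copy the `.get(...)` pattern above:

```python
from dataclasses import dataclass
from typing import Literal, Optional

@dataclass
class DatasetAttr:
    # Trimmed to the fields touched by this commit; the real class has more.
    load_from: str
    dataset_name: Optional[str] = None
    subset: Optional[str] = None
    folder: Optional[str] = None
    ranking: Optional[bool] = False
    formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"

# The starcoder_python entry from dataset_info.json, parsed into a dict.
dataset_info = {
    "starcoder_python": {
        "hf_hub_url": "bigcode/starcoderdata",
        "folder": "python",
        "columns": {"prompt": "content"},
    }
}

name = "starcoder_python"
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
dataset_attr.subset = dataset_info[name].get("subset", None)
dataset_attr.folder = dataset_info[name].get("folder", None)  # new in this commit
dataset_attr.ranking = dataset_info[name].get("ranking", False)
dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
print(dataset_attr)  # DatasetAttr(load_from='hf_hub', ..., folder='python', ...)
```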