Merge pull request #3829 from seanzhang-zhichen/add_dataset_sample_num
Add dataset sample num
This commit is contained in:
commit
483eb47e5d
|
@ -12,6 +12,7 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
|
|||
"ranking": "whether the dataset is a preference dataset or not. (default: False)",
|
||||
"subset": "the name of the subset. (optional, default: None)",
|
||||
"folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
|
||||
"num_samples": "the number of samples in the dataset used for training. (optional, default: None)",
|
||||
"columns (optional)": {
|
||||
"prompt": "the column name in the dataset containing the prompts. (default: instruction)",
|
||||
"query": "the column name in the dataset containing the queries. (default: input)",
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
"ranking": "是否为偏好数据集(可选,默认:False)",
|
||||
"subset": "数据集子集的名称(可选,默认:None)",
|
||||
"folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)",
|
||||
"num_samples": "该数据集中用于训练的样本数量。(可选,默认:None)",
|
||||
"columns(可选)": {
|
||||
"prompt": "数据集代表提示词的表头名称(默认:instruction)",
|
||||
"query": "数据集代表请求的表头名称(默认:input)",
|
||||
|
|
|
@ -3,6 +3,7 @@ import os
|
|||
import sys
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from ..extras.constants import FILEEXT2TYPE
|
||||
|
@ -106,9 +107,21 @@ def load_single_dataset(
|
|||
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
||||
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
||||
|
||||
if dataset_attr.num_samples is not None and not data_args.streaming:
|
||||
target_num = dataset_attr.num_samples
|
||||
indexes = np.random.permutation(len(dataset))[:target_num]
|
||||
target_num -= len(indexes)
|
||||
if target_num > 0:
|
||||
expand_indexes = np.random.choice(len(dataset), target_num)
|
||||
indexes = np.concatenate((indexes, expand_indexes), axis=0)
|
||||
|
||||
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
|
||||
dataset = dataset.select(indexes)
|
||||
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
|
||||
|
||||
if data_args.max_samples is not None: # truncate dataset
|
||||
num_samples = min(data_args.max_samples, len(dataset))
|
||||
dataset = dataset.select(range(num_samples))
|
||||
indexes = np.random.permutation(len(dataset))[: data_args.max_samples]
|
||||
dataset = dataset.select(indexes)
|
||||
|
||||
return align_dataset(dataset, dataset_attr, data_args)
|
||||
|
||||
|
|
|
@ -20,11 +20,12 @@ class DatasetAttr:
|
|||
""" basic configs """
|
||||
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
|
||||
dataset_name: str
|
||||
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
||||
ranking: bool = False
|
||||
""" extra configs """
|
||||
subset: Optional[str] = None
|
||||
folder: Optional[str] = None
|
||||
ranking: bool = False
|
||||
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
||||
num_samples: Optional[int] = None
|
||||
""" common columns """
|
||||
system: Optional[str] = None
|
||||
tools: Optional[str] = None
|
||||
|
@ -102,10 +103,11 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
|||
else:
|
||||
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
|
||||
|
||||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
||||
dataset_attr.set_attr("subset", dataset_info[name])
|
||||
dataset_attr.set_attr("folder", dataset_info[name])
|
||||
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
||||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||
dataset_attr.set_attr("num_samples", dataset_info[name])
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
|
||||
|
|
Loading…
Reference in New Issue