Supports custom data set sampling quantity
This commit is contained in:
parent
3bcd41b639
commit
449e2aa38e
|
@ -27,8 +27,9 @@ If you are using a custom dataset, please provide your dataset definition in the
|
|||
"assistant_tag": "the value of the role_tag represents the assistant. (default: gpt)",
|
||||
"observation_tag": "the value of the role_tag represents the tool results. (default: observation)",
|
||||
"function_tag": "the value of the role_tag represents the function call. (default: function_call)",
|
||||
"system_tag": "the value of the role_tag represents the system prompt. (default: system, can override system column)"
|
||||
}
|
||||
"system_tag": "the value of the role_tag represents the system prompt. (default: system, can override system column)",
|
||||
},
|
||||
"sample_num": "the number of samples from this dataset can be greater than the total amount of the dataset. (default: None)"
|
||||
}
|
||||
```
|
||||
|
||||
|
|
|
@ -28,7 +28,8 @@
|
|||
"observation_tag": "消息中代表工具返回结果的 role_tag(默认:observation)",
|
||||
"function_tag": "消息中代表工具调用的 role_tag(默认:function_call)",
|
||||
"system_tag": "消息中代表系统提示的 role_tag(默认:system,会覆盖 system 列)"
|
||||
}
|
||||
},
|
||||
"sample_num": "从该数据集采样的数量,可大于该数据集总量(默认:None)"
|
||||
}
|
||||
```
|
||||
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
import inspect
|
||||
import os
|
||||
import numpy as np
|
||||
from numpy.random import RandomState
|
||||
from typing import TYPE_CHECKING, Literal, Union
|
||||
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
@ -108,6 +110,17 @@ def load_single_dataset(
|
|||
num_samples = min(data_args.max_samples, len(dataset))
|
||||
dataset = dataset.select(range(num_samples))
|
||||
|
||||
if dataset_attr.sample_num:
|
||||
dataset_sample_num = dataset_attr.sample_num
|
||||
logger.info(f"从 {dataset_attr.dataset_name} 采样 {dataset_sample_num} 条训练样本")
|
||||
random_state = RandomState(42)
|
||||
idx = random_state.permutation(len(dataset))[:dataset_sample_num]
|
||||
dataset_sample_num -= len(idx)
|
||||
if dataset_sample_num > 0:
|
||||
idx2 = random_state.choice(len(dataset), dataset_sample_num)
|
||||
idx = np.concatenate([idx, idx2], axis=0)
|
||||
dataset = dataset.select(idx)
|
||||
|
||||
return align_dataset(dataset, dataset_attr, data_args)
|
||||
|
||||
|
||||
|
|
|
@ -44,6 +44,7 @@ class DatasetAttr:
|
|||
observation_tag: Optional[str] = "observation"
|
||||
function_tag: Optional[str] = "function_call"
|
||||
system_tag: Optional[str] = "system"
|
||||
sample_num: Optional[int] = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.dataset_name
|
||||
|
@ -90,7 +91,8 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
|||
dataset_attr.set_attr("folder", dataset_info[name])
|
||||
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
||||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||
|
||||
dataset_attr.set_attr("sample_num", dataset_info[name])
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
column_names = ["system"]
|
||||
if dataset_attr.formatting == "alpaca":
|
||||
|
|
Loading…
Reference in New Issue