Support a custom number of samples per dataset
This commit is contained in:
parent
3bcd41b639
commit
449e2aa38e
|
@ -27,8 +27,9 @@ If you are using a custom dataset, please provide your dataset definition in the
|
||||||
"assistant_tag": "the value of the role_tag represents the assistant. (default: gpt)",
|
"assistant_tag": "the value of the role_tag represents the assistant. (default: gpt)",
|
||||||
"observation_tag": "the value of the role_tag represents the tool results. (default: observation)",
|
"observation_tag": "the value of the role_tag represents the tool results. (default: observation)",
|
||||||
"function_tag": "the value of the role_tag represents the function call. (default: function_call)",
|
"function_tag": "the value of the role_tag represents the function call. (default: function_call)",
|
||||||
"system_tag": "the value of the role_tag represents the system prompt. (default: system, can override system column)"
|
"system_tag": "the value of the role_tag represents the system prompt. (default: system, can override system column)"
|
||||||
}
|
},
|
||||||
|
"sample_num": "the number of examples to sample from this dataset; may exceed the dataset size, in which case examples are repeated. (default: None)"
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,8 @@
|
||||||
"observation_tag": "消息中代表工具返回结果的 role_tag(默认:observation)",
|
"observation_tag": "消息中代表工具返回结果的 role_tag(默认:observation)",
|
||||||
"function_tag": "消息中代表工具调用的 role_tag(默认:function_call)",
|
"function_tag": "消息中代表工具调用的 role_tag(默认:function_call)",
|
||||||
"system_tag": "消息中代表系统提示的 role_tag(默认:system,会覆盖 system 列)"
|
"system_tag": "消息中代表系统提示的 role_tag(默认:system,会覆盖 system 列)"
|
||||||
}
|
},
|
||||||
|
"sample_num": "从该数据集采样的数量,可大于该数据集总量(默认:None)"
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from numpy.random import RandomState
|
||||||
from typing import TYPE_CHECKING, Literal, Union
|
from typing import TYPE_CHECKING, Literal, Union
|
||||||
|
|
||||||
from datasets import load_dataset, load_from_disk
|
from datasets import load_dataset, load_from_disk
|
||||||
|
@ -108,6 +110,17 @@ def load_single_dataset(
|
||||||
num_samples = min(data_args.max_samples, len(dataset))
|
num_samples = min(data_args.max_samples, len(dataset))
|
||||||
dataset = dataset.select(range(num_samples))
|
dataset = dataset.select(range(num_samples))
|
||||||
|
|
||||||
|
if dataset_attr.sample_num:
|
||||||
|
dataset_sample_num = dataset_attr.sample_num
|
||||||
|
logger.info(f"从 {dataset_attr.dataset_name} 采样 {dataset_sample_num} 条训练样本")
|
||||||
|
random_state = RandomState(42)
|
||||||
|
idx = random_state.permutation(len(dataset))[:dataset_sample_num]
|
||||||
|
dataset_sample_num -= len(idx)
|
||||||
|
if dataset_sample_num > 0:
|
||||||
|
idx2 = random_state.choice(len(dataset), dataset_sample_num)
|
||||||
|
idx = np.concatenate([idx, idx2], axis=0)
|
||||||
|
dataset = dataset.select(idx)
|
||||||
|
|
||||||
return align_dataset(dataset, dataset_attr, data_args)
|
return align_dataset(dataset, dataset_attr, data_args)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -44,6 +44,7 @@ class DatasetAttr:
|
||||||
observation_tag: Optional[str] = "observation"
|
observation_tag: Optional[str] = "observation"
|
||||||
function_tag: Optional[str] = "function_call"
|
function_tag: Optional[str] = "function_call"
|
||||||
system_tag: Optional[str] = "system"
|
system_tag: Optional[str] = "system"
|
||||||
|
sample_num: Optional[int] = None
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return self.dataset_name
|
return self.dataset_name
|
||||||
|
@ -90,7 +91,8 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||||
dataset_attr.set_attr("folder", dataset_info[name])
|
dataset_attr.set_attr("folder", dataset_info[name])
|
||||||
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
||||||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||||
|
dataset_attr.set_attr("sample_num", dataset_info[name])
|
||||||
|
|
||||||
if "columns" in dataset_info[name]:
|
if "columns" in dataset_info[name]:
|
||||||
column_names = ["system"]
|
column_names = ["system"]
|
||||||
if dataset_attr.formatting == "alpaca":
|
if dataset_attr.formatting == "alpaca":
|
||||||
|
|
Loading…
Reference in New Issue