Support a custom sampling quantity for each dataset

This commit is contained in:
zhangzc 2024-03-27 14:22:50 +08:00
parent 3bcd41b639
commit 449e2aa38e
4 changed files with 21 additions and 4 deletions

View File

@ -27,8 +27,9 @@ If you are using a custom dataset, please provide your dataset definition in the
"assistant_tag": "the value of the role_tag represents the assistant. (default: gpt)", "assistant_tag": "the value of the role_tag represents the assistant. (default: gpt)",
"observation_tag": "the value of the role_tag represents the tool results. (default: observation)", "observation_tag": "the value of the role_tag represents the tool results. (default: observation)",
"function_tag": "the value of the role_tag represents the function call. (default: function_call)", "function_tag": "the value of the role_tag represents the function call. (default: function_call)",
"system_tag": "the value of the role_tag represents the system prompt. (default: system, can override system column)" "system_tag": "the value of the role_tag represents the system prompt. (default: system, can override system column)",
} },
"sample_num": "the number of samples to draw from this dataset; it can be greater than the total size of the dataset, in which case the extra samples are drawn at random with repetition. (default: None)"
} }
``` ```

View File

@ -28,7 +28,8 @@
"observation_tag": "消息中代表工具返回结果的 role_tag(默认:observation)", "observation_tag": "消息中代表工具返回结果的 role_tag(默认:observation)",
"function_tag": "消息中代表工具调用的 role_tag(默认:function_call)", "function_tag": "消息中代表工具调用的 role_tag(默认:function_call)",
"system_tag": "消息中代表系统提示的 role_tag(默认:system,会覆盖 system 列)" "system_tag": "消息中代表系统提示的 role_tag(默认:system,会覆盖 system 列)"
} },
"sample_num": "从该数据集采样的数量,可大于该数据集总量(默认:None)"
} }
``` ```

View File

@ -1,5 +1,7 @@
import inspect import inspect
import os import os
import numpy as np
from numpy.random import RandomState
from typing import TYPE_CHECKING, Literal, Union from typing import TYPE_CHECKING, Literal, Union
from datasets import load_dataset, load_from_disk from datasets import load_dataset, load_from_disk
@ -108,6 +110,17 @@ def load_single_dataset(
num_samples = min(data_args.max_samples, len(dataset)) num_samples = min(data_args.max_samples, len(dataset))
dataset = dataset.select(range(num_samples)) dataset = dataset.select(range(num_samples))
if dataset_attr.sample_num:
dataset_sample_num = dataset_attr.sample_num
logger.info(f"{dataset_attr.dataset_name} 采样 {dataset_sample_num} 条训练样本")
random_state = RandomState(42)
idx = random_state.permutation(len(dataset))[:dataset_sample_num]
dataset_sample_num -= len(idx)
if dataset_sample_num > 0:
idx2 = random_state.choice(len(dataset), dataset_sample_num)
idx = np.concatenate([idx, idx2], axis=0)
dataset = dataset.select(idx)
return align_dataset(dataset, dataset_attr, data_args) return align_dataset(dataset, dataset_attr, data_args)

View File

@ -44,6 +44,7 @@ class DatasetAttr:
observation_tag: Optional[str] = "observation" observation_tag: Optional[str] = "observation"
function_tag: Optional[str] = "function_call" function_tag: Optional[str] = "function_call"
system_tag: Optional[str] = "system" system_tag: Optional[str] = "system"
sample_num: Optional[int] = None
def __repr__(self) -> str: def __repr__(self) -> str:
return self.dataset_name return self.dataset_name
@ -90,7 +91,8 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_attr.set_attr("folder", dataset_info[name]) dataset_attr.set_attr("folder", dataset_info[name])
dataset_attr.set_attr("ranking", dataset_info[name], default=False) dataset_attr.set_attr("ranking", dataset_info[name], default=False)
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca") dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
dataset_attr.set_attr("sample_num", dataset_info[name])
if "columns" in dataset_info[name]: if "columns" in dataset_info[name]:
column_names = ["system"] column_names = ["system"]
if dataset_attr.formatting == "alpaca": if dataset_attr.formatting == "alpaca":