From 449e2aa38e3a6cf301a43c12c121ac24ebf12027 Mon Sep 17 00:00:00 2001 From: zhangzc <2608882093@qq.com> Date: Wed, 27 Mar 2024 14:22:50 +0800 Subject: [PATCH] Supports custom data set sampling quantity --- data/README.md | 5 +++-- data/README_zh.md | 3 ++- src/llmtuner/data/loader.py | 13 +++++++++++++ src/llmtuner/data/parser.py | 4 +++- 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/data/README.md b/data/README.md index fa2c9ee0..c4a1b298 100644 --- a/data/README.md +++ b/data/README.md @@ -27,8 +27,9 @@ If you are using a custom dataset, please provide your dataset definition in the "assistant_tag": "the value of the role_tag represents the assistant. (default: gpt)", "observation_tag": "the value of the role_tag represents the tool results. (default: observation)", "function_tag": "the value of the role_tag represents the function call. (default: function_call)", - "system_tag": "the value of the role_tag represents the system prompt. (default: system, can override system column)" - } + "system_tag": "the value of the role_tag represents the system prompt. (default: system, can override system column)", + }, + "sample_num": "the number of samples from this dataset can be greater than the total amount of the dataset. (default: None)" } ``` diff --git a/data/README_zh.md b/data/README_zh.md index e0004f4a..6396688a 100644 --- a/data/README_zh.md +++ b/data/README_zh.md @@ -28,7 +28,8 @@ "observation_tag": "消息中代表工具返回结果的 role_tag(默认:observation)", "function_tag": "消息中代表工具调用的 role_tag(默认:function_call)", "system_tag": "消息中代表系统提示的 role_tag(默认:system,会覆盖 system 列)" - } + }, + "sample_num": "从该数据集采样的数量,可大于该数据集总量(默认:None)" } ``` diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py index 935695ad..bebe5718 100644 --- a/src/llmtuner/data/loader.py +++ b/src/llmtuner/data/loader.py @@ -1,5 +1,7 @@ import inspect import os +import numpy as np +from numpy.random import RandomState from typing import TYPE_CHECKING, Literal, Union from datasets import load_dataset, load_from_disk @@ -108,6 +110,17 @@ def load_single_dataset( num_samples = min(data_args.max_samples, len(dataset)) dataset = dataset.select(range(num_samples)) + if dataset_attr.sample_num: + dataset_sample_num = dataset_attr.sample_num + logger.info(f"从 {dataset_attr.dataset_name} 采样 {dataset_sample_num} 条训练样本") + random_state = RandomState(42) + idx = random_state.permutation(len(dataset))[:dataset_sample_num] + dataset_sample_num -= len(idx) + if dataset_sample_num > 0: + idx2 = random_state.choice(len(dataset), dataset_sample_num) + idx = np.concatenate([idx, idx2], axis=0) + dataset = dataset.select(idx) + return align_dataset(dataset, dataset_attr, data_args) diff --git a/src/llmtuner/data/parser.py b/src/llmtuner/data/parser.py index 861396a0..9746b5b2 100644 --- a/src/llmtuner/data/parser.py +++ b/src/llmtuner/data/parser.py @@ -44,6 +44,7 @@ class DatasetAttr: observation_tag: Optional[str] = "observation" function_tag: Optional[str] = "function_call" system_tag: Optional[str] = "system" + sample_num: Optional[int] = None def __repr__(self) -> str: return self.dataset_name @@ -90,7 +91,8 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]: dataset_attr.set_attr("folder", dataset_info[name]) dataset_attr.set_attr("ranking", dataset_info[name], default=False) dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca") - + dataset_attr.set_attr("sample_num", dataset_info[name]) + if "columns" in dataset_info[name]: column_names = ["system"] if dataset_attr.formatting == "alpaca":