From f9a88b89ca8b8f9a0c5def03b154f9d67f558edf Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Thu, 30 May 2024 00:17:21 +0800 Subject: [PATCH] Update loader.py --- src/llamafactory/data/loader.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 322eefa0..fa5b12c5 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -108,7 +108,13 @@ def load_single_dataset( dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter if dataset_attr.num_samples is not None and not data_args.streaming: - indexes = np.random.permutation(len(dataset))[: dataset_attr.num_samples] + target_num = dataset_attr.num_samples + indexes = np.random.permutation(len(dataset))[:target_num] + target_num -= len(indexes) + if target_num > 0: + expand_indexes = np.random.choice(len(dataset), target_num) + indexes = np.concatenate((indexes, expand_indexes), axis=0) + dataset = dataset.select(indexes) logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))