Update loader.py

This commit is contained in:
hoshi-hiyouga 2024-05-30 00:17:21 +08:00 committed by GitHub
parent b55fb611c5
commit f9a88b89ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 7 additions and 1 deletions

View File

@ -108,7 +108,13 @@ def load_single_dataset(
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if dataset_attr.num_samples is not None and not data_args.streaming: if dataset_attr.num_samples is not None and not data_args.streaming:
indexes = np.random.permutation(len(dataset))[: dataset_attr.num_samples] target_num = dataset_attr.num_samples
indexes = np.random.permutation(len(dataset))[:target_num]
target_num -= len(indexes)
if target_num > 0:
expand_indexes = np.random.choice(len(dataset), target_num)
indexes = np.concatenate((indexes, expand_indexes), axis=0)
dataset = dataset.select(indexes) dataset = dataset.select(indexes)
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr)) logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))