Update loader.py
This commit is contained in:
parent
b55fb611c5
commit
f9a88b89ca
|
@ -108,7 +108,13 @@ def load_single_dataset(
|
||||||
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
||||||
|
|
||||||
if dataset_attr.num_samples is not None and not data_args.streaming:
|
if dataset_attr.num_samples is not None and not data_args.streaming:
|
||||||
indexes = np.random.permutation(len(dataset))[: dataset_attr.num_samples]
|
target_num = dataset_attr.num_samples
|
||||||
|
indexes = np.random.permutation(len(dataset))[:target_num]
|
||||||
|
target_num -= len(indexes)
|
||||||
|
if target_num > 0:
|
||||||
|
expand_indexes = np.random.choice(len(dataset), target_num)
|
||||||
|
indexes = np.concatenate((indexes, expand_indexes), axis=0)
|
||||||
|
|
||||||
dataset = dataset.select(indexes)
|
dataset = dataset.select(indexes)
|
||||||
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
|
logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue