diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index fa5b12c5..d4a19e27 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -115,6 +115,7 @@ def load_single_dataset( expand_indexes = np.random.choice(len(dataset), target_num) indexes = np.concatenate((indexes, expand_indexes), axis=0) + assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched." dataset = dataset.select(indexes) logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))