This commit is contained in:
hiyouga 2023-12-23 14:42:20 +08:00
parent 0ad86a4f62
commit 0bbf7118df
2 changed files with 9 additions and 9 deletions

View File

@ -22,12 +22,15 @@ def get_dataset(
max_samples = data_args.max_samples
all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
logger.warning("Loading dataset from disk will ignore other data arguments.")
dataset = load_from_disk(data_args.cache_path)
if data_args.streaming:
dataset = dataset.to_iterable_dataset()
return dataset
if data_args.cache_path is not None:
if os.path.exists(data_args.cache_path):
logger.warning("Loading dataset from disk will ignore other data arguments.")
dataset = load_from_disk(data_args.cache_path)
if data_args.streaming:
dataset = dataset.to_iterable_dataset()
return dataset
elif data_args.streaming:
raise ValueError("Turn off dataset streaming to save cache files.")
for dataset_attr in data_args.dataset_list:
logger.info("Loading dataset {}...".format(dataset_attr))

View File

@ -127,9 +127,6 @@ class DataArguments:
if self.streaming and self.max_samples is not None:
raise ValueError("`max_samples` is incompatible with `streaming`.")
if self.streaming and self.cache_path:
raise ValueError("`cache_path` is incompatible with `streaming`.")
def init_for_training(self, seed: int): # support mixing multiple datasets
self.seed = seed
dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []