# CPM-9G-8B/9G-Train/cpm/dataset/indexed_dataset.py
# (345 lines, 13 KiB, Python)

import itertools
import math
import os
import pickle
import queue
import random
import threading
import time
import bmtrain as bmt
try:
import msgspec
json_decode = msgspec.json.decode
json_encode = msgspec.json.encode
except ModuleNotFoundError:
import json
json_decode = json.loads
json_encode = json.dumps
import torch
from torch.utils.data import Dataset
from typing_extensions import TypedDict
from .utils import Range
# Module-wide lock acquired by safe_print() around each print call, so that
# output coming from different threads is emitted one call at a time.
print_lock = threading.Lock()
def safe_print(*args, **kargs):
    """Print under the module-level ``print_lock``.

    Accepts the same arguments as the builtin ``print``. Unlike ``print``,
    output is flushed by default; pass ``flush=False`` to disable that.
    """
    # pull `flush` out of the keyword arguments so it is passed exactly once
    should_flush = kargs.pop("flush", True)
    with print_lock:
        print(*args, **kargs, flush=should_flush)
def concurrent_info():
    """Describe where the caller sits among all concurrent data consumers.

    Returns:
        A tuple ``(world_size, rank, nworkers, worker_id)``: the bmtrain world
        size and rank, followed by the number of torch DataLoader workers and
        the id of the current one.
    """
    world_size, rank = bmt.world_size(), bmt.rank()
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        # Not inside a DataLoader worker process: behave as the only worker.
        # Worker ids are 0-based (torch numbers them 0..num_workers-1), so the
        # single worker is id 0; id 1 would be out of range for 1 worker.
        nworkers, worker_id = 1, 0
    else:
        nworkers, worker_id = worker_info.num_workers, worker_info.id
    return world_size, rank, nworkers, worker_id
class IndexedDataset(Dataset):
    """Random-access reader over a directory holding ``index`` and ``data.jsonl``.

    ``index`` is a text file with one integer byte offset per line. Entry ``i``
    occupies bytes ``[bounds[i], bounds[i + 1])`` of ``data.jsonl``, hence the
    index holds ``len(self) + 1`` offsets. Entries are returned as raw bytes.
    """

    def __init__(self, path, max_retry=1, retry_sleep=5):
        """
        Args:
            path: directory containing the ``index`` and ``data.jsonl`` files.
            max_retry: how many times a read is retried after an ``OSError``.
            retry_sleep: seconds to sleep before each retry.
        """
        super().__init__()
        self.path = path
        self.max_retry = max_retry
        self.retry_sleep = retry_sleep
        self.bounds = None
        self.build_index()

    def size(self):
        """Return the last offset of the index, i.e. the indexed size of the data in bytes."""
        return self.bounds[-1]

    def build_index(self):
        """Load the byte offsets from ``<path>/index`` into ``self.bounds``."""
        with open(os.path.join(self.path, "index"), "r") as fin:
            # tolerate blank lines (e.g. a trailing newline) instead of failing in int()
            self.bounds = [int(line) for line in fin if line.strip()]

    def safe_read(self, i_or_s, offset, size):
        """Read ``size`` bytes of ``data.jsonl`` starting at byte ``offset``.

        Args:
            i_or_s: the index or slice being read; only used in log messages.
            offset: byte position to start reading from.
            size: number of bytes to read.

        Returns:
            The bytes read, or ``None`` if a ``ValueError`` occurred while reading.

        Raises:
            OSError: if reading still fails after ``max_retry`` retries.
        """
        for retry in itertools.count():
            try:
                # destroy the file identifier to avoid pressure on alluxio
                # buffering=0 to avoid overhead during file.seek() and open()
                with open(os.path.join(self.path, "data.jsonl"), "rb", buffering=0) as fin:
                    fin.seek(offset)
                    # An unbuffered read() issues a single system call and may
                    # return fewer bytes than requested, so keep reading until
                    # the request is satisfied or EOF is reached.
                    chunks = []
                    remaining = size
                    while remaining > 0:
                        chunk = fin.read(remaining)
                        if not chunk:
                            break
                        chunks.append(chunk)
                        remaining -= len(chunk)
                return b"".join(chunks)
            except OSError as e:
                if retry >= self.max_retry:
                    # chain the original error so the root cause stays visible
                    raise OSError(f"reach maximum #retry: {retry}, the file system is broken.") from e
                safe_print(
                    f"retry loading {self.path}:{i_or_s} in {self.retry_sleep} seconds due to error: '{repr(e)}'"
                )
                time.sleep(self.retry_sleep)
            except ValueError as e:
                # reading error during python io, skip
                safe_print(f"skipping {self.path}:{i_or_s} due to error: '{repr(e)}'")
                return None

    def __repr__(self):
        return (
            f"IndexedDataset(path={self.path}, max_retry={self.max_retry}, retry_sleep={self.retry_sleep}) "
            f"with {len(self)} entries."
        )

    def __len__(self):
        """Number of entries: one less than the number of stored offsets."""
        return len(self.bounds) - 1

    def bound_idx(self, key, strict=False):
        """Map ``key`` to a non-negative position, list-style.

        Args:
            key: integer index, possibly negative.
            strict: if true, raise instead of clamping an out-of-range index.

        Returns:
            A position in ``[0, len(self)]``; ``len(self)`` can only be
            returned when ``strict`` is false (it is a valid slice stop).

        Raises:
            IndexError: if ``strict`` and ``key`` is outside ``[-len(self), len(self))``.
        """
        # bound index within the standard range: [0, len(self))
        # useful for tracing buggy entries
        if strict and not (-len(self) <= key < len(self)):
            raise IndexError(f"Index {key} out of range for '{self.path}'")
        key = min(max(-len(self), key), len(self))  # bound key within [-len(self), len(self)]
        # remap negative id to positive ones; `>=` keeps key 0 away from the
        # modulo, which would divide by zero on an empty dataset
        key = key if key >= 0 else key % len(self)
        return key

    def __getitem__(self, key):
        """Return the raw bytes of one entry, or a list of them for a slice.

        Returns ``None`` when the underlying read was skipped by ``safe_read``.

        Raises:
            ValueError: for a slice whose step is neither 1 nor None.
            IndexError: for an integer index out of range.
            TypeError: if ``key`` is neither an int nor a slice.
        """
        # supports list-like slicing and indexing. strided slicing is not currently supported.
        # ok: self[1], self[-1], self[1:3], self[-10:-5], self[-10:-5:1], self[:5]
        # not ok: self[-10:-5:2], self[:100:3]
        if isinstance(key, slice):
            if not (key.step == 1 or key.step is None):
                raise ValueError(f"slice step should be 1 or None, not {key.step}")
            start = self.bound_idx(0 if key.start is None else key.start)
            stop = max(self.bound_idx(len(self) if key.stop is None else key.stop), start)
            if stop == start:
                # early returning empty slice
                return list()
            # read the whole span with one request, then cut it into entries
            offset, size = self.bounds[start], self.bounds[stop] - self.bounds[start]
            raw = self.safe_read(key, offset, size)
            if raw is None:
                return None
            else:
                return [
                    raw[s - offset : e - offset]
                    for s, e in zip(self.bounds[start:stop], self.bounds[start + 1 : stop + 1])
                ]
        elif isinstance(key, int):
            key = self.bound_idx(key, strict=True)
            offset, size = self.bounds[key], self.bounds[key + 1] - self.bounds[key]
            raw = self.safe_read(key, offset, size)
            return raw
        else:
            raise TypeError(f"indices must be integers or slices, not {type(key)}")
class PrefetchDecodeDataset(IndexedDataset):
    """IndexedDataset that decodes entries and offers prefetched, resumable iteration.

    On top of the raw-bytes ``IndexedDataset`` this class adds:
      * decoding of each entry with ``decode`` (failures are logged and skipped),
      * iterators that prefetch entries with background threads,
      * ``state_dict`` / ``load_state_dict`` tracking of already consumed indices.
    """

    # Add prefetched sampled iterator and state_dict tracking upon the simple IndexedDataset
    # Add safe decoding in iterator
    def __init__(self, *args, decode=json_decode, **kargs):
        """
        Args:
            *args, **kargs: forwarded to ``IndexedDataset.__init__``.
            decode: callable turning the raw bytes of one entry into a sample.
        """
        super().__init__(*args, **kargs)
        self.decode = decode
        self.lock = threading.Lock()
        self.prev_used = set()  # store previously used index in the checkpoint
        self.used = set()  # track locally used index

    def state_dict(self, gathered=True):
        """Return ``{"prev_used": <set of consumed indices>}``.

        Args:
            gathered: if true, merge the locally used indices of all ranks
                through bmtrain collectives (requires CUDA); otherwise only
                report what this instance knows.
        """
        # NOTE(review): this early return skips the collectives below; it looks
        # like every rank must take the same branch for them to match up - confirm.
        if not self.prev_used and not self.used:
            return {"prev_used": set()}
        if gathered:
            # explicit integer dtype so an empty `used` does not become a float tensor
            used = torch.tensor(list(self.used), dtype=torch.long).cuda()
            size = torch.tensor(used.numel()).cuda()
            max_size = bmt.distributed.all_reduce(size, op="max")
            # allgather requires tensors having the same size
            padding = torch.full((max_size - size,), -100, dtype=used.dtype, device=used.device)
            used = torch.cat([used, padding], dim=-1)
            all_used = bmt.distributed.all_gather(used).unique()
            all_used = set(all_used.tolist())
            all_used.discard(-100)  # remove the padding value
            # set.union() returns a new set; update in place so that the
            # indices restored from a checkpoint are really part of the result
            all_used |= self.prev_used
            return {"prev_used": all_used}
        else:
            return {"prev_used": self.prev_used.union(self.used)}

    def load_state_dict(self, state):
        """Restore the consumed indices from a ``state_dict()`` result."""
        # NOTE(review): the restored indices go to `self.used`, although the
        # attribute comment suggests `self.prev_used` holds checkpointed ones;
        # kept as is because loaders exclude both sets - confirm the intent.
        with self.lock:
            # copy, so that later `used.add()` calls never mutate the caller's state
            self.used = set(state.get("prev_used", ()))

    def reset(self):
        """Forget all consumed indices."""
        with self.lock:
            self.used = set()
            self.prev_used = set()

    def safe_decode(self, i, raw):
        """Decode ``raw``; return ``None`` (after logging) if it is ``None`` or undecodable.

        Args:
            i: index of the entry, only used in the log message.
            raw: raw bytes of the entry, or ``None``.
        """
        if raw is None:
            return None
        try:
            return self.decode(raw)
        except Exception as e:
            safe_print(f"Skip decoding {self.path}:{i} due to error '{e}', raw bytes:\n{raw}")
            return None

    def __getitem__(self, key):
        """Like ``IndexedDataset.__getitem__`` but returns decoded entries.

        Entries that fail to decode are returned as ``None``.
        """
        raw = super().__getitem__(key)
        if raw is None:
            return None
        # key should be either a slice or an integer as checked in IndexedDataset
        if isinstance(key, slice):
            # resolve the start the same way the parent did: key.start may be
            # None or negative, which range(key.start, key.stop) cannot handle
            start = self.bound_idx(0 if key.start is None else key.start)
            return [self.safe_decode(i, r) for i, r in enumerate(raw, start)]
        else:
            return self.safe_decode(key, raw)

    def loader(self, q, lid, keys, stop):
        """Thread body: read the entries named by ``keys`` and put them in ``q``.

        Args:
            q: queue receiving ``(index, raw_bytes)`` tuples.
            lid: id of this loader, sent back in the final ``StopIteration``.
            keys: iterable of integer indices or slices with explicit bounds.
            stop: event set by the consumer to ask the loader to finish early.
        """
        # concurrent prefetching worker
        try:
            for key in keys:
                if stop.is_set():
                    break
                # key is either a slice or an integer index
                is_slice = isinstance(key, slice)
                index = range(key.start, key.stop) if is_slice else [key]
                with self.lock:
                    unused = set(index) - self.used - self.prev_used
                if not unused:
                    # skip used slice / item
                    continue
                # read raw data with IndexedDataset.__getitem__, suspend decoding util we really need it
                raw = super().__getitem__(key)
                if raw is None:
                    continue
                # A slice always yields a list, even with a single element, so
                # decide on the key type rather than on the number of indices.
                raws = raw if is_slice else [raw]
                # filter used data
                items = [(i, s) for i, s in zip(index, raws) if i in unused]
                random.shuffle(items)
                for item in items:
                    q.put(item)
        finally:
            # signaling the end of iteration to the main thread
            q.put(StopIteration(lid))

    def _iterate(self, key_groups, nprefetch=1000):
        """Yield decoded samples prefetched by one loader thread per key group.

        Args:
            key_groups: iterable of key iterables, one per loader thread.
            nprefetch: maximum number of raw entries buffered in the queue.
        """
        # helper function for concurrent prefetching
        q = queue.Queue(maxsize=nprefetch)
        stop = threading.Event()
        alive = set()
        try:
            for lid, keys in enumerate(key_groups):
                loader = threading.Thread(target=self.loader, args=(q, lid, keys, stop), daemon=True)
                loader.start()
                alive.add(lid)
            # Every loader puts its StopIteration marker last, so once all
            # markers were received nothing else can be left in the queue.
            while alive:
                try:
                    # block until an item arrives instead of sleeping between polls
                    item = q.get(timeout=0.3)
                except queue.Empty:
                    continue
                if isinstance(item, StopIteration):
                    alive.remove(item.value)
                    continue
                i, raw = item
                data = self.safe_decode(i, raw)
                if data is None:
                    continue
                # loaders read `used` under the same lock
                with self.lock:
                    self.used.add(i)
                yield data
            # automatically reset states with graceful ends.
            self.reset()
        finally:
            # ask daemon loaders to stop
            stop.set()

    def _concurrency(self, nthreads):
        """Return ``concurrent_info()`` after checking there is at least one entry per loader."""
        world_size, rank, nworkers, worker_id = concurrent_info()
        nloaders = world_size * nworkers * nthreads
        if len(self) < nloaders:
            raise ValueError(
                f"more concurrent loaders ({nloaders}) than data entries ({len(self)}) in '{self.path}', "
                f"please constrain either "
                f"world_size={world_size}, num_workers={nworkers} or num_threads={nthreads}."
            )
        return world_size, rank, nworkers, worker_id

    def iterate(self, nthreads=3, prefetch_sample=100):
        """Yield the decoded samples assigned to this rank/worker, reading entry by entry.

        Args:
            nthreads: number of loader threads.
            prefetch_sample: maximum number of prefetched entries.

        Raises:
            ValueError: if there are more concurrent loaders than entries.
        """
        world_size, rank, nworkers, worker_id = self._concurrency(nthreads)
        r = Range(0, len(self), 1)
        # split index among multi-gpu workers
        r = r.subrange(split=rank, nsplits=world_size)
        # split index among multi-process dataloader workers
        r = r.subrange(split=worker_id, nsplits=nworkers)
        # split index among multi-threaded loaders
        id_groups = [r.subrange(split=tid, nsplits=nthreads).random_iterate() for tid in range(nthreads)]
        for data in self._iterate(id_groups, nprefetch=prefetch_sample):
            yield data

    def sliced_iterate(self, nthreads=1, prefetch_slice=3, slice_size=1000):
        """Yield the decoded samples assigned to this rank/worker, reading slice by slice.

        Args:
            nthreads: number of loader threads.
            prefetch_slice: number of slices worth of entries to prefetch.
            slice_size: number of consecutive entries fetched per read; reduced
                when there would be fewer slices than concurrent loaders.

        Raises:
            ValueError: if there are more concurrent loaders than entries.
        """
        world_size, rank, nworkers, worker_id = self._concurrency(nthreads)
        nloaders = world_size * nworkers * nthreads
        nslices = int(math.ceil(len(self) / slice_size))
        if nslices < nloaders:
            safe_print(
                f"fail to distribute {nslices} slices from '{self.path}' to {nloaders} concurrent loaders, "
                f"reduce slice_size from {slice_size} to {len(self) // nloaders}."
            )
            slice_size = len(self) // nloaders
        # we only iterate through start ids as they uniquely mark each slice
        r = Range(0, len(self), slice_size)
        # split index among multi-gpu workers
        r = r.subrange(split=rank, nsplits=world_size)
        # split index among multi-process dataloader workers
        r = r.subrange(split=worker_id, nsplits=nworkers)
        # split index among multi-threaded loaders
        slice_groups = [
            (slice(s, s + slice_size) for s in r.subrange(tid, nthreads).random_iterate()) for tid in range(nthreads)
        ]
        for data in self._iterate(slice_groups, nprefetch=prefetch_slice * slice_size):
            yield data
class IndexedDatasetBuilder:
    """Context manager writing ``data.jsonl`` and its ``index`` of byte offsets under ``path``.

    Each ``put`` appends one encoded line to ``data.jsonl``. On exit, the start
    offset of every line plus the final end offset is written to ``index``,
    one integer per line.
    """

    def __init__(self, path, overwrite=False):
        """
        Args:
            path: output directory (created on ``__enter__`` if missing).
            overwrite: if false, refuse to run when the output files already exist.
        """
        self.path = path
        self.index_path = os.path.join(self.path, "index")
        self.data_path = os.path.join(self.path, "data.jsonl")
        if not overwrite:
            # NOTE: asserts are stripped when python runs with -O
            assert not os.path.exists(self.data_path)
            assert not os.path.exists(self.index_path)
        self.fout = None
        self.starts = []
        self.offset = 0

    def __enter__(self):
        os.makedirs(self.path, exist_ok=True)
        self.fout = open(self.data_path, "wb")
        # "wb" truncates the data file, so offsets recorded by an earlier use
        # of this builder no longer describe it
        self.starts = []
        self.offset = 0
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # the final offset closes the last entry
        self.starts.append(self.offset)
        try:
            with open(self.index_path, "w") as fout:
                fout.writelines(f"{s}\n" for s in self.starts)
        finally:
            # release the data file even if writing the index failed
            self.fout.close()

    def put(self, data: dict):
        """Append ``data`` as one encoded line and record its start offset."""
        s = json_encode(data)
        if isinstance(s, str):
            # the stdlib json fallback returns str while msgspec returns bytes
            s = s.encode("utf-8")
        s += b"\n"
        self.starts.append(self.offset)
        self.offset += len(s)
        self.fout.write(s)
if __name__ == "__main__":
    # Smoke test: build a small dataset in ./swear, then print ten random entries.
    with IndexedDatasetBuilder("swear", overwrite=True) as builder:
        samples = [{"input": f"screw it {i}", "output": f"for god's sake {i}"} for i in range(100)]
        for sample in samples:
            builder.put(sample)
    dataset = IndexedDataset("swear")
    for _ in range(10):
        picked = random.randint(0, len(dataset) - 1)
        print(dataset[picked])