forked from jiuyuan/CPM-9G-8B
345 lines
13 KiB
Python
345 lines
13 KiB
Python
|
import itertools
|
||
|
import math
|
||
|
import os
|
||
|
import pickle
|
||
|
import queue
|
||
|
import random
|
||
|
import threading
|
||
|
import time
|
||
|
|
||
|
import bmtrain as bmt
|
||
|
|
||
|
try:
|
||
|
import msgspec
|
||
|
|
||
|
json_decode = msgspec.json.decode
|
||
|
json_encode = msgspec.json.encode
|
||
|
except ModuleNotFoundError:
|
||
|
import json
|
||
|
|
||
|
json_decode = json.loads
|
||
|
json_encode = json.dumps
|
||
|
|
||
|
import torch
|
||
|
from torch.utils.data import Dataset
|
||
|
from typing_extensions import TypedDict
|
||
|
|
||
|
from .utils import Range
|
||
|
|
||
|
# Module-wide lock so that messages printed from concurrent loader threads do not interleave.
print_lock = threading.Lock()


def safe_print(*args, **kwargs):
    """Thread-safe ``print`` that flushes by default.

    Accepts the same arguments as :func:`print`; ``flush`` defaults to True instead of False.
    """
    should_flush = kwargs.pop("flush", True)
    with print_lock:
        print(*args, flush=should_flush, **kwargs)
|
||
|
|
||
|
|
||
|
def concurrent_info():
    """Describe where the current process sits in distributed and DataLoader parallelism.

    Returns:
        ``(world_size, rank, nworkers, worker_id)``: the bmtrain world size and rank, plus the
        number of DataLoader worker processes and this worker's id. Outside a DataLoader worker
        (i.e. ``get_worker_info()`` is ``None``) this is treated as a single worker with id 0.
    """
    world_size, rank = bmt.world_size(), bmt.rank()
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        # Worker ids lie in [0, num_workers); a lone main-process "worker" is therefore id 0.
        # (Using 1 here made the split index equal to the number of splits.)
        nworkers, worker_id = 1, 0
    else:
        nworkers, worker_id = worker_info.num_workers, worker_info.id
    return world_size, rank, nworkers, worker_id
|
||
|
|
||
|
|
||
|
class IndexedDataset(Dataset):
    """Random-access reader over ``<path>/data.jsonl`` using the offsets in ``<path>/index``.

    ``index`` holds one integer per line. Entry ``i`` occupies bytes
    ``[bounds[i], bounds[i + 1])`` of ``data.jsonl``, so the file has ``len(self) + 1`` lines.
    Entries are returned as raw ``bytes``; decoding is left to subclasses.
    """

    def __init__(self, path, max_retry=1, retry_sleep=5):
        """
        Args:
            path: directory containing the ``index`` and ``data.jsonl`` files.
            max_retry: number of re-attempts after an ``OSError`` while reading before giving up.
            retry_sleep: seconds to wait between read attempts.
        """
        super().__init__()
        self.path = path
        self.max_retry = max_retry
        self.retry_sleep = retry_sleep
        self.bounds = None
        self.build_index()

    def size(self):
        """Return the last offset in the index, i.e. the end of the final entry in bytes."""
        return self.bounds[-1]

    def build_index(self):
        """Load the entry boundary offsets from ``<path>/index`` (blank lines are ignored)."""
        with open(os.path.join(self.path, "index"), "r") as fin:
            self.bounds = [int(line) for line in fin if line.strip()]

    @staticmethod
    def _read_exact(fin, size):
        """Read up to ``size`` bytes from ``fin``, stopping early only at EOF.

        Unbuffered (raw) reads may legally return fewer bytes than requested, notably on
        network/FUSE file systems, so a single ``read`` call is not enough.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = fin.read(remaining)
            if not chunk:  # EOF
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def safe_read(self, i_or_s, offset, size):
        """Read ``size`` bytes starting at byte ``offset`` of ``data.jsonl``.

        Args:
            i_or_s: the index or slice being read; only used in log messages.
            offset: byte offset to seek to.
            size: number of bytes to read.

        Returns:
            The bytes read, or ``None`` (after logging) if a ``ValueError`` was raised while reading.

        Raises:
            OSError: when reading still fails after ``max_retry`` re-attempts.
        """
        data_path = os.path.join(self.path, "data.jsonl")
        for retry in itertools.count():
            try:
                # Open a fresh handle per read so no descriptor is held open on the file system
                # (avoids pressure on alluxio); buffering=0 avoids read-ahead overhead on seek().
                with open(data_path, "rb", buffering=0) as fin:
                    fin.seek(offset)
                    return self._read_exact(fin, size)
            except OSError as e:
                if retry >= self.max_retry:
                    raise OSError(f"reach maximum #retry: {retry}, the file system is broken.") from e
                safe_print(
                    f"retry loading {self.path}:{i_or_s} in {self.retry_sleep} seconds due to error: '{repr(e)}'"
                )
                time.sleep(self.retry_sleep)
            except ValueError as e:
                # reading error during python io, skip
                safe_print(f"skipping {self.path}:{i_or_s} due to error: '{repr(e)}'")
                return None

    def __repr__(self):
        return (
            f"IndexedDataset(path={self.path}, max_retry={self.max_retry}, retry_sleep={self.retry_sleep}) "
            f"with {len(self)} entries."
        )

    def __len__(self):
        return len(self.bounds) - 1

    def bound_idx(self, key, strict=False):
        """Map ``key`` into the standard range ``[0, len(self)]``, list-style.

        Negative keys count from the end. Out-of-range keys are clamped, unless ``strict`` is
        set, in which case anything outside ``[-len(self), len(self))`` raises ``IndexError``
        (useful for tracing buggy entries).
        """
        n = len(self)
        if strict and not (-n <= key < n):
            raise IndexError(f"Index {key} out of range for '{self.path}'")
        key = min(max(-n, key), n)  # clamp into [-n, n]
        # `key + n` equals `key % n` for key in [-n, 0) and, unlike `%`, is safe when n == 0.
        return key if key >= 0 else key + n

    def __getitem__(self, key):
        """Return raw entry bytes for an int, or a list of entry bytes for a slice.

        Supports list-like indexing and slicing; strided slicing is not supported.
        ok: self[1], self[-1], self[1:3], self[-10:-5], self[-10:-5:1], self[:5]
        not ok: self[-10:-5:2], self[:100:3]

        Returns ``None`` when the underlying read was skipped (see :meth:`safe_read`).

        Raises:
            ValueError: for a slice step other than 1 / None.
            IndexError: for an out-of-range integer index.
            TypeError: for any other key type.
        """
        if isinstance(key, slice):
            if not (key.step == 1 or key.step is None):
                raise ValueError(f"slice step should be 1 or None, not {key.step}")
            start = self.bound_idx(0 if key.start is None else key.start)
            stop = max(self.bound_idx(len(self) if key.stop is None else key.stop), start)
            if stop == start:
                # early returning empty slice
                return list()
            # Read the whole contiguous byte span once, then cut it into entries.
            offset, size = self.bounds[start], self.bounds[stop] - self.bounds[start]
            raw = self.safe_read(key, offset, size)
            if raw is None:
                return None
            return [
                raw[s - offset : e - offset]
                for s, e in zip(self.bounds[start:stop], self.bounds[start + 1 : stop + 1])
            ]
        elif isinstance(key, int):
            key = self.bound_idx(key, strict=True)
            offset, size = self.bounds[key], self.bounds[key + 1] - self.bounds[key]
            return self.safe_read(key, offset, size)
        else:
            raise TypeError(f"indices must be integers or slices, not {type(key)}")
|
||
|
|
||
|
|
||
|
class PrefetchDecodeDataset(IndexedDataset):
    """IndexedDataset with threaded prefetching iterators, safe decoding and resumable progress.

    Entries yielded by :meth:`iterate` / :meth:`sliced_iterate` are recorded in ``used``, so that
    :meth:`state_dict` can checkpoint them. After :meth:`load_state_dict`, the iterators skip them.
    """

    def __init__(self, *args, decode=json_decode, **kargs):
        """
        Args:
            *args, **kargs: forwarded to :class:`IndexedDataset`.
            decode: callable converting the raw bytes of one entry into a Python object.
        """
        super().__init__(*args, **kargs)
        self.decode = decode
        self.lock = threading.Lock()  # guards `used` / `prev_used`, which loader threads read
        self.prev_used = set()  # indices marked as used by a loaded checkpoint
        self.used = set()  # indices consumed locally since the last load / reset

    def state_dict(self, gathered=True):
        """Return ``{"prev_used": set_of_indices}`` covering every consumed entry.

        Args:
            gathered: when True, all-gather the locally used indices of every rank through
                bmtrain collectives. Every rank must then call this method, or the collectives
                hang. When False, only this process's view is returned.
        """
        if not gathered:
            return {"prev_used": self.prev_used.union(self.used)}

        # No rank-local early exit here: a rank returning early while others enter the
        # collectives below would deadlock them.
        used = torch.tensor(list(self.used), dtype=torch.long).cuda()
        size = torch.tensor(used.numel()).cuda()
        max_size = int(bmt.distributed.all_reduce(size, op="max").item())
        # all_gather requires equally sized tensors. Pad with -100 (consumed indices are
        # non-negative) to at least one element, so the collective never sees an empty tensor.
        pad = max(max_size, 1) - used.numel()
        used = torch.cat([used, torch.full((pad,), -100, dtype=torch.long, device=used.device)], dim=-1)
        all_used = set(bmt.distributed.all_gather(used).unique().tolist())
        all_used.discard(-100)  # remove the padding value
        all_used |= self.prev_used  # in-place: the result of a bare `.union()` would be lost
        return {"prev_used": all_used}

    def load_state_dict(self, state):
        """Restore progress produced by :meth:`state_dict`; those indices will be skipped."""
        with self.lock:
            # Copy so later local bookkeeping never mutates the caller's state object.
            self.prev_used = set(state.get("prev_used", set()))
            self.used = set()

    def reset(self):
        """Forget all consumption progress, both loaded and local."""
        with self.lock:
            self.used = set()
            self.prev_used = set()

    def safe_decode(self, i, raw):
        """Decode ``raw`` with ``self.decode``; log and return ``None`` on failure or ``None`` input."""
        if raw is None:
            return None
        try:
            return self.decode(raw)
        except Exception as e:
            safe_print(f"Skip decoding {self.path}:{i} due to error '{e}', raw bytes:\n{raw}")
            return None

    def __getitem__(self, key):
        """Like :meth:`IndexedDataset.__getitem__`, but decoded; undecodable entries become ``None``."""
        raw = super().__getitem__(key)
        if raw is None:
            return None
        # key is either a slice or an integer, as checked in IndexedDataset
        if isinstance(key, slice):
            # Resolve the start exactly as IndexedDataset does, so `ds[:5]` and `ds[-3:]` work
            # and log true entry ids (`range(key.start, key.stop)` fails on None starts/stops).
            start = self.bound_idx(0 if key.start is None else key.start)
            return [self.safe_decode(i, r) for i, r in enumerate(raw, start)]
        return self.safe_decode(key, raw)

    @staticmethod
    def _put(q, item, stop, timeout=0.5):
        """Put ``item`` on ``q``, giving up once ``stop`` is set. Return True if it was enqueued.

        A plain blocking ``put`` on a full queue would hang forever once the consumer is gone.
        """
        while not stop.is_set():
            try:
                q.put(item, timeout=timeout)
                return True
            except queue.Full:
                pass
        return False

    def loader(self, q, lid, keys, stop):
        """Prefetching worker: push ``(index, raw_bytes)`` of unused entries for ``keys`` onto ``q``.

        ``keys`` yields integer indices or unit-step slices. Entries inside one key are shuffled.
        A ``StopIteration(lid)`` marker is pushed when the worker finishes, unless ``stop`` is set.
        """
        try:
            for key in keys:
                if stop.is_set():
                    return
                if isinstance(key, slice):
                    index = range(key.start, min(key.stop, len(self)))
                else:
                    index = [key]
                with self.lock:
                    unused = set(index) - self.used - self.prev_used
                if not unused:
                    # skip used slice / item
                    continue
                # Read raw data with IndexedDataset.__getitem__; defer decoding until consumed.
                raw = super().__getitem__(key)
                if raw is None:
                    continue
                # A slice always yields a list (even of length 1); an integer yields one entry.
                entries = raw if isinstance(key, slice) else [raw]
                items = [(i, s) for i, s in zip(index, entries) if i in unused]
                random.shuffle(items)
                for item in items:
                    if not self._put(q, item, stop):
                        return
        finally:
            # Signal the end of this loader to the consumer (skipped if it already stopped).
            self._put(q, StopIteration(lid), stop)

    def _iterate(self, key_groups, nprefetch=1000):
        """Run one loader thread per key group and yield decoded entries as they arrive.

        When every loader finishes, progress is reset. Abandoning the generator signals the
        loaders to stop.
        """
        q = queue.Queue(maxsize=nprefetch)
        stop = threading.Event()
        alive = set()
        try:
            for lid, keys in enumerate(key_groups):
                loader = threading.Thread(target=self.loader, args=(q, lid, keys, stop), daemon=True)
                loader.start()
                alive.add(lid)
            # Each loader enqueues its StopIteration after all of its items, and the queue is
            # FIFO, so once every marker has been received the queue is fully drained.
            while alive:
                item = q.get()
                if isinstance(item, StopIteration):
                    alive.discard(item.value)
                    continue
                i, raw = item
                data = self.safe_decode(i, raw)
                if data is None:
                    continue
                with self.lock:
                    self.used.add(i)
                yield data
            # automatically reset states with graceful ends.
            self.reset()
        finally:
            # ask daemon loaders to stop
            stop.set()

    def _concurrent_layout(self, nthreads):
        """Return ``(world_size, rank, nworkers, worker_id, nloaders)``.

        Raises:
            ValueError: if there are more concurrent loaders than dataset entries.
        """
        world_size, rank, nworkers, worker_id = concurrent_info()
        nloaders = world_size * nworkers * nthreads
        if len(self) < nloaders:
            raise ValueError(
                f"more concurrent loaders ({nloaders}) than data entries ({len(self)}) in '{self.path}', "
                f"please constrain either "
                f"world_size={world_size}, num_workers={nworkers} or num_threads={nthreads}."
            )
        return world_size, rank, nworkers, worker_id, nloaders

    def iterate(self, nthreads=3, prefetch_sample=100):
        """Yield decoded entries of this rank / worker's share in random order, one index at a time.

        Args:
            nthreads: loader threads for this process.
            prefetch_sample: maximum number of raw entries buffered ahead of the consumer.
        """
        world_size, rank, nworkers, worker_id, _ = self._concurrent_layout(nthreads)
        r = Range(0, len(self), 1)
        # split index among multi-gpu workers
        r = r.subrange(split=rank, nsplits=world_size)
        # split index among multi-process dataloader workers
        r = r.subrange(split=worker_id, nsplits=nworkers)
        # split index among multi-threaded loaders
        id_groups = [r.subrange(split=tid, nsplits=nthreads).random_iterate() for tid in range(nthreads)]
        yield from self._iterate(id_groups, nprefetch=prefetch_sample)

    def sliced_iterate(self, nthreads=1, prefetch_slice=3, slice_size=1000):
        """Like :meth:`iterate`, but reads contiguous slices of ``slice_size`` entries at once.

        Slices are visited in random order, and entries within a slice are shuffled.
        ``slice_size`` is reduced when there are too few slices to give every loader one.
        """
        world_size, rank, nworkers, worker_id, nloaders = self._concurrent_layout(nthreads)
        nslices = int(math.ceil(len(self) / slice_size))
        if nslices < nloaders:
            safe_print(
                f"fail to distribute {nslices} slices from '{self.path}' to {nloaders} concurrent loaders, "
                f"reduce slice_size from {slice_size} to {len(self) // nloaders}."
            )
            slice_size = len(self) // nloaders

        # we only iterate through start ids as they uniquely mark each slice
        r = Range(0, len(self), slice_size)
        # split index among multi-gpu workers
        r = r.subrange(split=rank, nsplits=world_size)
        # split index among multi-process dataloader workers
        r = r.subrange(split=worker_id, nsplits=nworkers)
        # split index among multi-threaded loaders
        slice_groups = [
            (slice(s, s + slice_size) for s in r.subrange(split=tid, nsplits=nthreads).random_iterate())
            for tid in range(nthreads)
        ]
        yield from self._iterate(slice_groups, nprefetch=prefetch_slice * slice_size)
|
||
|
|
||
|
|
||
|
class IndexedDatasetBuilder:
    """Write entries to ``<path>/data.jsonl`` and their byte offsets to ``<path>/index``.

    Use as a context manager: entries are appended with :meth:`put`, and the index (start
    offset of each entry plus the final end offset) is written when the ``with`` block exits.
    """

    def __init__(self, path, overwrite=False):
        """
        Args:
            path: output directory (created on ``__enter__`` if missing).
            overwrite: allow replacing existing ``data.jsonl`` / ``index`` files.

        Raises:
            FileExistsError: if an output file exists and ``overwrite`` is False.
        """
        self.path = path
        self.index_path = os.path.join(self.path, "index")
        self.data_path = os.path.join(self.path, "data.jsonl")
        if not overwrite:
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            for p in (self.data_path, self.index_path):
                if os.path.exists(p):
                    raise FileExistsError(f"'{p}' already exists, pass overwrite=True to replace it")
        self.fout = None
        self.starts = []  # byte offset at which each written entry starts
        self.offset = 0  # current size of the data file in bytes

    def __enter__(self):
        os.makedirs(self.path, exist_ok=True)
        self.fout = open(self.data_path, "wb")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Close (and thereby flush) the data file before publishing the index that describes it,
        # so the handle is released even if writing the index fails.
        self.fout.close()
        bounds = self.starts + [self.offset]
        with open(self.index_path, "w") as fout:
            fout.write("".join(f"{b}\n" for b in bounds))

    def put(self, data: dict):
        """Append ``data`` as one JSON line and record its start offset."""
        s = json_encode(data)
        if isinstance(s, str):
            # msgspec returns bytes, but the stdlib json fallback returns str.
            s = s.encode("utf-8")
        s += b"\n"
        self.fout.write(s)
        # Record the entry only after it was written, so the index never runs past the data.
        self.starts.append(self.offset)
        self.offset += len(s)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Demo: build a small dataset of 100 entries, then print ten randomly chosen raw entries.
    samples = ({"input": f"screw it {n}", "output": f"for god's sake {n}"} for n in range(100))
    with IndexedDatasetBuilder("swear", overwrite=True) as builder:
        for sample in samples:
            builder.put(sample)
    dataset = IndexedDataset("swear")
    for _ in range(10):
        print(dataset[random.randrange(len(dataset))])
|