forked from p83651209/CPM-9G-2B
461 lines
18 KiB
461 lines
18 KiB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright @2023 AI, ZHIHU Inc.
# @author: ouzebin
# @date: 2023/09/27
"""
使用 IndexedDataset 前需按指定格式构建或者转换已有数据集
- <dataset name>
    - data.jsonl  # jsonl 格式的数据,每一行一条样本
    - index       # 记录每一行 json 数据的起始 byte-offset,最后一行额外记录文件末尾的 offset
从头构建:直接使用 IndexedDatasetBuilder 这个 context manager
>>> with IndexedDatasetBuilder("swear", overwrite=True) as builder:
...     for data in [{"input": f"screw it {i}", "output": f"for god's sake {i}"} for i in range(100)]:
...         builder.put(data)
从 fm9g distributed_dataset 转换:使用 ``
$ python -m -i <原数据集文件夹> -o <新数据集文件夹>
已有 jsonl 数据:使用 `` 构建 index 文件。需提前先把 jsonl 文件命名为 `data.jsonl`。
$ python -m -p <数据集文件夹路径>
"""
import itertools
import math
import os
import queue
import random
import threading
import time
import bmtrain as bmt
import h5py
import numpy
import numpy as np
import torch
import msgspec
json_decode = msgspec.json.decode
json_encode = msgspec.json.encode
except ModuleNotFoundError:
import json
json_decode = json.loads
json_encode = json.dumps
import torch
from import Dataset
from fm9g.utils.bitset import BitSet
from fm9g.utils.bitset import bitset_diff
# Held by safe_print() around each print() call so that messages emitted from
# concurrent threads (e.g. the prefetching loader threads) do not interleave.
print_lock = threading.Lock()
def random_range(start, stop=None, step=None):
    """Yield every element of ``range(start, stop, step)`` exactly once, in random order.

    Takes the same arguments as the builtin ``range``. Unlike
    ``random.shuffle(list)`` / ``random.sample(list, len(list))`` the sequence is
    never materialised, so there is no long initialisation period for large ranges.

    The permutation is produced by a full-period linear congruential generator
    over the smallest power of two that is >= the number of elements; generated
    values outside the range are skipped.

    Raises:
        ValueError: if ``step`` is zero (same as ``range``).
    """
    if stop is None:
        start, stop = 0, start
    if step is None:
        step = 1
    if step == 0:
        raise ValueError("random_range() arg 3 must not be zero")

    # Use a mapping to convert a standard range into the desired range.
    def mapping(i):
        return (i * step) + start

    # Compute the number of numbers in this range; <= 0 means the range is empty
    # (e.g. stop < start with a positive step).
    maximum = int(math.ceil((stop - start) / step))
    if maximum <= 0:
        # Empty range: stop here, otherwise math.log2(maximum) below would raise.
        return

    # Seed range with a random integer.
    value = random.randint(0, maximum)

    # Construct an offset, multiplier, and modulus for a linear
    # congruential generator. These generators are cyclic and
    # non-repeating when they maintain the properties:
    # 1) "modulus" and "offset" are relatively prime.
    # 2) ["multiplier" - 1] is divisible by all prime factors of "modulus".
    # 3) ["multiplier" - 1] is divisible by 4 if "modulus" is divisible by 4.
    # Pick a random odd-valued offset.
    offset = random.randint(0, maximum) * 2 + 1
    # Pick a multiplier 1 greater than a multiple of 4.
    multiplier = 4 * (maximum // 4) + 1
    # Pick a modulus just big enough to generate all numbers (power of 2).
    modulus = int(2 ** math.ceil(math.log2(maximum)))

    # Track how many random numbers have been returned.
    found = 0
    while found < maximum:
        # If this is a valid value, yield it in generator fashion.
        if value < maximum:
            found += 1
            yield mapping(value)
        # Calculate the next value in the sequence.
        value = (value * multiplier + offset) % modulus
class Range(object):
    """Lazy, re-iterable stand-in for ``range(start, stop, step)``.

    Describes which entry indices a worker should visit: ``subrange`` takes an
    interleaved (strided) share of the indices, and ``random_iterate`` visits
    them in random order without building a list.
    """

    def __init__(self, start, stop, step):
        self.start, self.stop, self.step = start, stop, step

    def __repr__(self):
        return f"Range({self.start}, {self.stop}, {self.step})"

    def iterate(self):
        """Yield the indices in order, exactly like ``range``."""
        for idx in range(self.start, self.stop, self.step):
            yield idx

    def list(self):
        """Return the indices as a Python list."""
        return [*range(self.start, self.stop, self.step)]

    def subrange(self, split, nsplits):
        """Return shard ``split`` out of ``nsplits`` interleaved shards.

        Shard ``k`` keeps every ``nsplits``-th element starting from element ``k``:
        e.g. ``Range(0, 10, 2)`` (0, 2, 4, 6, 8) with ``nsplits=2`` splits into
        ``Range(0, 10, 4)`` (0, 4, 8) and ``Range(2, 10, 4)`` (2, 6).
        """
        new_start = self.start + split * self.step
        new_step = nsplits * self.step
        return Range(new_start, self.stop, new_step)

    def random_iterate(self):
        """Yield the same indices as ``iterate`` but in random order."""
        for idx in random_range(self.start, self.stop, self.step):
            yield idx
def safe_print(*args, **kargs):
    """Thread-safe ``print``: prints while holding ``print_lock``.

    Accepts the same arguments as ``print``; ``flush`` defaults to True unless
    the caller passes it explicitly.
    """
    flush = kargs.pop("flush", True)
    with print_lock:
        print(*args, flush=flush, **kargs)
def concurrent_info(process_group=None):
    """Return ``(world_size, rank, nworkers, worker_id)`` used to shard dataset reads.

    ``world_size`` is bmtrain's world size divided by the tensor-parallel size and
    ``rank`` is ``bmt.config["topology"].tp_idx`` (presumably the index of this
    process's tensor-parallel group, so all ranks in one TP group read the same
    shard -- TODO confirm against bmtrain's topology). ``nworkers`` / ``worker_id``
    describe the current torch DataLoader worker, or ``(1, 0)`` in the main process.

    Args:
        process_group: accepted for call-site compatibility; currently ignored.
    """
    # world_size, rank = bmt.world_size(), bmt.rank()
    world_size = bmt.config["world_size"] // bmt.config["tp_size"]
    rank = bmt.config["topology"].tp_idx
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        # Single-process loading: this process is worker 0 of 1. (A worker id of 1
        # here made Range.subrange(split=1, nsplits=1) skip each shard's first entry.)
        nworkers, worker_id = 1, 0
    else:
        nworkers, worker_id = worker_info.num_workers, worker_info.id
    return world_size, rank, nworkers, worker_id
class IndexedDataset(Dataset):
    """Random-access dataset over ``<path>/data.jsonl`` addressed by byte offsets.

    The offsets ("bounds") are read from ``<path>/index`` (text, one integer per
    line) or, for very large indexes, from ``<path>/index.h5`` (HDF5 dataset
    ``"index"``). Entry ``k`` occupies bytes ``[bounds[k], bounds[k + 1])`` of
    ``data.jsonl``, so the dataset holds ``len(bounds) - 1`` entries. Entries are
    returned as raw, undecoded ``bytes`` (see ``__getitem__``).
    """

    def __init__(self, path, max_retry=1, retry_sleep=5):
        """
        Args:
            path: dataset directory containing ``data.jsonl`` and its index.
            max_retry: how many times a read failing with ``OSError`` is retried
                (so up to ``max_retry + 1`` attempts in total).
            retry_sleep: seconds to wait before each retry.
        """
        self.path = path
        self.max_retry = max_retry
        self.retry_sleep = retry_sleep
        self.bounds = None
        self.h5file = None
        # __len__ / __getitem__ need the offsets, so load them right away.
        self.build_index()

    def size(self):
        """Return the last index offset, i.e. the number of bytes of ``data.jsonl`` covered by the index."""
        return self.bounds[-1]

    def _close_h5(self):
        """Close the lazily-read HDF5 index file, if one is open."""
        # getattr: __del__ can run on an instance whose __init__ failed before h5file was set.
        h5file = getattr(self, "h5file", None)
        if h5file is not None:
            h5file.close()
            self.h5file = None

    def _build_index_h5(self):
        """Load the offsets from ``index.h5``.

        Index files larger than 100 MiB are kept open and read lazily through the
        h5py dataset; smaller ones are loaded into memory and closed immediately,
        to avoid keeping too many file handles open.
        """
        index_path = os.path.join(self.path, "index.h5")
        # Do not leak a handle opened by a previous call.
        self._close_h5()
        if os.path.getsize(index_path) > 104857600:  # 100 MiB
            self.h5file = h5py.File(index_path, "r")
            self.bounds = self.h5file["index"]
        else:
            with h5py.File(index_path, "r") as hf:
                self.bounds = np.array(hf["index"])

    def __del__(self):
        self._close_h5()

    def build_index(self):
        """Populate ``self.bounds``.

        Uses ``index.h5`` when the text ``index`` file is larger than 0.5 GiB and an
        ``index.h5`` exists; otherwise parses the text ``index`` file. The text
        ``index`` file must exist in both cases (its size is always checked).
        """
        begin = time.time()
        txt_size = os.path.getsize(os.path.join(self.path, "index"))
        if txt_size > 0.5 * 1024**3 and os.path.exists(os.path.join(self.path, "index.h5")):
            source = "h5"
            self._build_index_h5()
        else:
            source = "txt"
            self._build_index_txt()
        end = time.time()
        bmt.print_rank("build_index_{} from {}, using {:.2f}s".format(source, self.path, end - begin))

    def _build_index_txt(self):
        """Load the offsets from the text ``index`` file (one integer per line)."""
        with open(os.path.join(self.path, "index"), "r") as fin:
            self.bounds = [int(line) for line in fin]
        self.nlines = len(self.bounds)

    @staticmethod
    def _read_exactly(fin, size):
        """Read ``size`` bytes from ``fin``, stopping early only at EOF.

        Unbuffered (raw) file objects may return fewer bytes than requested from a
        single ``read`` call, so keep reading until the request is satisfied.
        """
        remaining = int(size)
        chunks = []
        while remaining > 0:
            chunk = fin.read(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def safe_read(self, i_or_s, offset, size):
        """Read ``size`` bytes at byte ``offset`` of ``data.jsonl``.

        ``OSError`` is retried up to ``max_retry`` times, sleeping ``retry_sleep``
        seconds in between, and then re-raised as ``OSError``. A ``ValueError``
        during I/O is reported and the read is skipped by returning None.
        ``i_or_s`` (index or slice) is only used in messages.
        """
        for retry in itertools.count():
            try:
                # Open the file per read so no long-lived handle puts pressure on alluxio;
                # buffering=0 avoids the buffered layer's overhead on open() and read().
                with open(os.path.join(self.path, "data.jsonl"), "rb", buffering=0) as fin:
                    fin.seek(int(offset))
                    return self._read_exactly(fin, size)
            except OSError as e:
                if retry >= self.max_retry:
                    raise OSError(f"reach maximum #retry: {retry}, the file system is broken.") from e
                safe_print(
                    f"retry loading {self.path}:{i_or_s} in {self.retry_sleep} seconds due to error: '{repr(e)}'"
                )
                time.sleep(self.retry_sleep)
            except ValueError as e:
                # reading error during python io, skip
                safe_print(f"skipping {self.path}:{i_or_s} due to error: '{repr(e)}'")
                return None

    def __repr__(self):
        return (
            f"IndexedDataset(path={self.path}, max_retry={self.max_retry}, retry_sleep={self.retry_sleep}) "
            f"with {len(self)} entries."
        )

    def __len__(self):
        return len(self.bounds) - 1

    def bound_idx(self, key, strict=False):
        """Map ``key`` into ``[0, len(self)]``, resolving negative indices.

        With ``strict=True`` an index outside ``[-len(self), len(self))`` raises
        IndexError; otherwise it is clamped (``len(self)`` itself is kept so it can
        serve as an exclusive slice end). Useful for tracing buggy entries.
        """
        if strict and not (-len(self) <= key < len(self)):
            raise IndexError(f"Index {key} out of range for '{self.path}'")
        key = min(max(-len(self), key), len(self))  # bound key within [-len(self), len(self)]
        # Remap negative ids to positive ones. 0 must not go through `%`: on an
        # empty dataset that is `0 % 0` and raises ZeroDivisionError.
        return key if key >= 0 else key % len(self)

    def __getitem__(self, key):
        """Return the raw bytes of one entry or a list of entries.

        Supports list-like indexing and slicing; strided slicing is not supported.
        ok: self[1], self[-1], self[1:3], self[-10:-5], self[-10:-5:1], self[:5]
        not ok: self[-10:-5:2], self[:100:3]

        Integer keys may be Python or NumPy integers and must be in range
        (IndexError otherwise). Slice bounds are clamped like list slicing; a slice
        is served by a single contiguous read and returned as a list of ``bytes``.
        Returns None if the read was skipped by ``safe_read``.
        """
        if isinstance(key, slice):
            if not (key.step == 1 or key.step is None):
                raise ValueError(f"slice step should be 1 or None, not {key.step}")
            start = self.bound_idx(0 if key.start is None else key.start)
            stop = max(self.bound_idx(len(self) if key.stop is None else key.stop), start)
            if stop == start:
                # early returning empty slice
                return list()
            offset, size = self.bounds[start], self.bounds[stop] - self.bounds[start]
            raw = self.safe_read(key, offset, size)
            if raw is None:
                return None
            # Split the single contiguous read back into per-entry byte strings.
            return [
                raw[s - offset : e - offset]
                for s, e in zip(self.bounds[start:stop], self.bounds[start + 1 : stop + 1])
            ]
        elif isinstance(key, (int, np.integer)):
            key = self.bound_idx(int(key), strict=True)
            offset, size = self.bounds[key], self.bounds[key + 1] - self.bounds[key]
            return self.safe_read(key, offset, size)
        raise TypeError(f"indices must be integers or slices, not {type(key)}")
class PrefetchDecodeDataset(IndexedDataset):
    """``IndexedDataset`` with safe decoding and multi-threaded, sharded prefetching.

    Entries are decoded with ``decode`` (JSON by default); entries that fail to
    decode are reported and treated as ``None``. ``iterate`` / ``sliced_iterate``
    shard the entries across distributed ranks, dataloader workers and loader
    threads, and yield ``(index, decoded_item)`` pairs in random order.
    """

    def __init__(self, *args, decode=json_decode, allow_repeat=False, **kargs):
        """
        Args:
            *args, **kargs: forwarded to ``IndexedDataset``.
            decode: callable turning the raw bytes of one entry into a sample.
            allow_repeat: if True, ``sliced_iterate`` tolerates more loaders than
                entries by letting groups of ranks read the same entries.
        """
        super().__init__(*args, **kargs)
        self.decode = decode
        self.allow_repeat = allow_repeat

    def safe_decode(self, i, raw):
        """Decode ``raw``; return None if it is None or cannot be decoded (the error is reported)."""
        if raw is None:
            return None
        try:
            return self.decode(raw)
        except Exception as e:
            # Deliberately best-effort: one corrupt line must not kill the iteration.
            safe_print(f"Skip decoding {self.path}:{i} due to error '{e}', raw bytes:\n{raw}")
            return None

    def __getitem__(self, key):
        """Same as ``IndexedDataset.__getitem__`` but returns decoded entries."""
        raw = super().__getitem__(key)
        if raw is None:
            return None
        # key should be either a slice or an integer as checked in IndexedDataset
        if isinstance(key, slice):
            # Resolve None / negative / out-of-range bounds the way the parent clamps
            # them, so each decoded entry is labelled with its real index.
            start, stop, _ = key.indices(len(self))
            return [self.safe_decode(i, r) for i, r in zip(range(start, stop), raw)]
        return self.safe_decode(key, raw)

    @staticmethod
    def _put_until_stopped(q, item, stop, poll_interval=1.0):
        """Put ``item`` on ``q``, giving up as soon as ``stop`` is set.

        A plain blocking ``q.put`` hangs forever once the consumer is gone and the
        queue is full; polling with a timeout lets loader threads exit.
        Returns True if the item was enqueued, False if ``stop`` was set first.
        """
        while not stop.is_set():
            try:
                q.put(item, timeout=poll_interval)
                return True
            except queue.Full:
                continue
        return False

    def loader(self, q, lid, keys, stop, used=None):
        """Concurrent prefetching worker, run in a daemon thread by ``_iterate``.

        For each key (an int index or a ``slice``) in ``keys`` it reads the raw
        bytes and puts ``(index, raw_bytes)`` pairs on ``q``; decoding is left to
        the consumer. Indices not returned by ``bitset_diff(set(index), used)``
        (presumably those already present in ``used``) are skipped. It always ends
        by putting ``StopIteration(lid)`` on ``q`` unless ``stop`` is set, even when
        a read raises, so the consumer never waits forever.

        Args:
            q: queue shared with the consumer.
            lid: loader id, echoed back in the final ``StopIteration``.
            keys: iterable of int indices or slices.
            stop: event set by the consumer to ask the loader to quit early.
            used: BitSet of indices to skip -- presumably already-consumed entries
                (e.g. restored from a checkpoint); confirm with callers.
        """
        if used is None:
            used = BitSet()
        try:
            for key in keys:
                if stop.is_set():
                    break
                # key is either a slice or an integer index
                index = range(key.start, key.stop) if isinstance(key, slice) else [key]
                unused = bitset_diff(set(index), used)
                if not unused:
                    # skip used slice / item
                    continue
                if not q.empty():
                    # avoid breaking the distributed file system with large io load
                    time.sleep(random.random() * 2)
                # read raw data with IndexedDataset.__getitem__, suspend decoding until we really need it
                raw = super().__getitem__(key)
                if raw is None:
                    continue
                # A slice always yields a list (even of length 1); an int yields bare bytes.
                entries = raw if isinstance(key, slice) else [raw]
                # filter used data
                items = [(i, s) for i, s in zip(index, entries) if i in unused]
                for item in items:
                    if not self._put_until_stopped(q, item, stop):
                        return
        finally:
            # signaling the end of iteration to the main thread
            self._put_until_stopped(q, StopIteration(lid), stop)

    def _iterate(self, key_groups, nprefetch=1000, used=None):
        """Run one ``loader`` thread per key group and yield decoded ``(index, item)`` pairs.

        ``nprefetch`` bounds the shared queue. Entries that fail to decode are
        skipped. Loader threads are asked to stop when iteration finishes or the
        generator is closed / garbage-collected early.
        """
        q = queue.Queue(maxsize=nprefetch)
        stop = threading.Event()
        alive = set()
        try:
            for lid, keys in enumerate(key_groups):
                loader = threading.Thread(target=self.loader, args=(q, lid, keys, stop, used), daemon=True)
                alive.add(lid)
                loader.start()
            # Each loader puts StopIteration(lid) after all of its items, so once every
            # loader has reported, nothing of interest is left in the queue.
            while alive:
                item = q.get()
                if isinstance(item, StopIteration):
                    alive.discard(item.value)
                    continue
                i, raw = item
                data = self.safe_decode(i, raw)
                if data is None:
                    continue
                yield i, data
        finally:
            # ask daemon loaders to stop
            stop.set()

    def iterate(self, nthreads=3, prefetch_sample=100, used=None, process_group=None):
        """Iterate single entries in random order, sharded across ranks, workers and threads.

        Args:
            nthreads: loader threads in this process.
            prefetch_sample: maximum number of prefetched entries.
            used: forwarded to ``loader``.
            process_group: kept for backward compatibility; currently unused.

        Raises:
            ValueError: if there are more concurrent loaders than entries.
        """
        world_size, rank, nworkers, worker_id = concurrent_info()
        nloaders = world_size * nworkers * nthreads
        if len(self) < nloaders:
            raise ValueError(
                f"more concurrent loaders ({nloaders}) than data entries ({len(self)}) in '{self.path}', "
                f"please constrain either "
                f"world_size={world_size}, num_workers={nworkers} or num_threads={nthreads}."
            )
        r = Range(0, len(self), 1)
        # split index among multi-gpu workers
        r = r.subrange(split=rank, nsplits=world_size)
        # split index among multi-process dataloader workers
        r = r.subrange(split=worker_id, nsplits=nworkers)
        # split index among multi-threaded loaders
        id_groups = [r.subrange(split=tid, nsplits=nthreads).random_iterate() for tid in range(nthreads)]
        return self._iterate(id_groups, nprefetch=prefetch_sample, used=used)

    def sliced_iterate(self, nthreads=1, prefetch_slice=3, slice_size=500, used=None):
        """Iterate contiguous slices of ``slice_size`` entries, slices visited in random order.

        Each slice is fetched with one read. If there are fewer slices than
        concurrent loaders, ``slice_size`` is reduced. If there are fewer entries
        than loaders, a ValueError is raised unless ``allow_repeat`` is set, in
        which case groups of consecutive ranks read the same entries one at a time.
        """
        world_size, rank, nworkers, worker_id = concurrent_info()
        nloaders = world_size * nworkers * nthreads
        if len(self) < nloaders:
            # An empty dataset cannot be repeated (math.ceil(nloaders / 0) below).
            if not self.allow_repeat or len(self) == 0:
                raise ValueError(
                    f"more concurrent loaders ({nloaders}) than data entries ({len(self)}) in '{self.path}', "
                    f"please constrain either "
                    f"world_size={world_size}, num_workers={nworkers} or num_threads={nthreads}."
                )
            duplicated_factor = math.ceil(nloaders / len(self))
            # In this case, slice size is 1: each start id below stands for exactly one
            # entry. Keeping the caller's slice_size would make consecutive start ids
            # produce overlapping slices and yield the same entries repeatedly.
            slice_size = 1
            r = Range(0, len(self), 1)
            # split index among grouped multi-gpu workers (duplicated_factor consecutive ranks share a shard)
            r = r.subrange(split=rank // duplicated_factor, nsplits=math.ceil(world_size / duplicated_factor))
            # split index among multi-process dataloader workers
            r = r.subrange(split=worker_id, nsplits=nworkers)
        else:
            nslices = int(math.ceil(len(self) / slice_size))
            if nslices < nloaders:
                safe_print(
                    f"fail to distribute {nslices} slices from '{self.path}' to {nloaders} concurrent loaders, "
                    f"reduce slice_size from {slice_size} to {len(self) // nloaders}."
                )
                slice_size = len(self) // nloaders
            # we only iterate through start ids as they uniquely mark each slice
            r = Range(0, len(self), slice_size)
            # split index among multi-gpu workers
            r = r.subrange(split=rank, nsplits=world_size)
            # split index among multi-process dataloader workers
            r = r.subrange(split=worker_id, nsplits=nworkers)
        # split index among multi-threaded loaders
        slice_groups = [
            (slice(s, s + slice_size) for s in r.subrange(tid, nthreads).random_iterate()) for tid in range(nthreads)
        ]
        return self._iterate(slice_groups, nprefetch=prefetch_slice * slice_size, used=used)
class IndexedDatasetBuilder:
    """Context manager that writes a dataset directory readable by ``IndexedDataset``.

    Each ``put`` appends one JSON-encoded line to ``<path>/data.jsonl`` and records
    its starting byte offset. On exit the offsets, followed by the final end
    offset, are written both to ``<path>/index.h5`` (dataset ``"index"``) and to
    ``<path>/index`` (text, one offset per line).
    """

    def __init__(self, path, overwrite=False):
        """
        Args:
            path: output directory (created on ``__enter__``).
            overwrite: if False, fail (AssertionError) when any output file exists.
        """
        self.path = path
        self.index_path = os.path.join(self.path, "index.h5")
        self.index_path_txt = os.path.join(self.path, "index")
        self.data_path = os.path.join(self.path, "data.jsonl")
        if not overwrite:
            # refuse to clobber an existing dataset
            for existing in (self.data_path, self.index_path, self.index_path_txt):
                assert not os.path.exists(existing), f"{existing} already exists, pass overwrite=True to replace it"
        self.fout = None
        self.bounds = []
        self.offset = 0

    def __enter__(self):
        os.makedirs(self.path, exist_ok=True)
        self.fout = open(self.data_path, "wb")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Close (and flush) the data first, so it is complete before an index refers to it
        # and the handle is released even if writing the index fails.
        self.fout.close()
        # The final offset closes the last entry: entry k spans [bounds[k], bounds[k + 1]).
        self.bounds.append(self.offset)
        with h5py.File(self.index_path, "w") as hf:
            hf.create_dataset("index", data=self.bounds)
        with open(self.index_path_txt, "w") as fout_txt:
            fout_txt.writelines(f"{s}\n" for s in self.bounds)

    def put(self, data: dict):
        """Append ``data`` as one JSON line and record its starting offset."""
        s = json_encode(data)
        if isinstance(s, str):
            # The stdlib json fallback returns str, msgspec returns bytes.
            s = s.encode("utf-8")
        s += b"\n"
        self.bounds.append(self.offset)
        self.fout.write(s)
        self.offset += len(s)
if __name__ == "__main__":
    # Smoke test: build a 100-entry dataset under ./swear, then print 10 random raw entries.
    with IndexedDatasetBuilder("swear", overwrite=True) as builder:
        for d in [{"input": f"screw it {i}", "output": f"for god's sake {i}"} for i in range(100)]:
            builder.put(d)
    dataset = IndexedDataset("swear")
    for _ in range(10):
        print(dataset[random.randint(0, len(dataset) - 1)])