forked from jiuyuan/CPM-9G-8B
422 lines
13 KiB
Python
422 lines
13 KiB
Python
|
import json
|
||
|
import math
|
||
|
import os
|
||
|
import random
|
||
|
import shutil
|
||
|
import struct
|
||
|
from queue import Queue
|
||
|
from threading import Thread
|
||
|
from typing import Iterable
|
||
|
from typing import List
|
||
|
from typing import Optional
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from ..utils.log import logger
|
||
|
from .distributed_dataset import _DEFAULT_BLOCK_SIZE
|
||
|
from .distributed_dataset import _random_string
|
||
|
from .distributed_dataset import _read_info_list
|
||
|
from .distributed_dataset import _write_info_list
|
||
|
from .distributed_dataset import build_dataset
|
||
|
from .distributed_dataset import FileInfo
|
||
|
from .distributed_dataset import SimpleDataset
|
||
|
from .serializer import RawSerializer
|
||
|
|
||
|
try:
|
||
|
from tqdm import tqdm
|
||
|
|
||
|
support_tqdm = True
|
||
|
except ModuleNotFoundError:
|
||
|
support_tqdm = False
|
||
|
|
||
|
# Default size, in bytes, of one shuffle bucket used by shuffle_dataset: 1 GiB (2**30).
_DEFAULT_SHUFFLE_BUCKET_SIZE = 1024 * 1024 * 1024
|
||
|
|
||
|
|
||
|
def _shuffle_scatter(ds, tmp_files: List[str], progress_bar: bool) -> None:
    """Step 1 of shuffle_dataset: append every record of ``ds`` to a randomly chosen bucket file."""
    num_buckets = len(tmp_files)
    f_tmp = []
    try:
        # Opened inside the try so already-opened handles are closed if a later open fails.
        for fname in tmp_files:
            f_tmp.append(open(fname, "wb"))
        iterator = tqdm(ds, desc="Shuffle step 1/2") if progress_bar else ds
        for data in iterator:
            bucket_id = int(random.random() * num_buckets)
            # Length-prefixed record; _shuffle_read_bucket decodes the same layout.
            f_tmp[bucket_id].write(struct.pack("I", len(data)) + data)
    finally:
        for fp in f_tmp:
            fp.close()  # no-op for an already-closed file


def _shuffle_read_bucket(fname: str) -> List[bytes]:
    """Step 2 helper: load every length-prefixed record stored in one bucket file."""
    header_size = struct.calcsize("I")  # must match the "I" format used when writing
    records: List[bytes] = []
    with open(fname, "rb") as fp:
        while True:
            header = fp.read(header_size)
            if not header:  # clean EOF
                break
            (len_data,) = struct.unpack("I", header)
            records.append(fp.read(len_data))
    return records


def shuffle_dataset(
    path_src: str,
    path_tgt: str,
    block_size: int = _DEFAULT_BLOCK_SIZE,
    bucket_size: int = _DEFAULT_SHUFFLE_BUCKET_SIZE,
    progress_bar: bool = False,
    output_name: Optional[str] = None,
):
    """Shuffle one distributed dataset, write results to another dataset.

    Two-pass bucket shuffle:

    1. every record is appended to a uniformly chosen temporary bucket file, using
       ``ceil(nbytes / bucket_size)`` buckets, so a bucket holds on average at most
       ``bucket_size`` bytes;
    2. each bucket is loaded into memory, shuffled, and appended to the output dataset.

    Args:
        path_src (str): path to source dataset; temporary bucket files are created here
        path_tgt (str): path to write results
        block_size (int): dataset block size (default: ``_DEFAULT_BLOCK_SIZE``)
        bucket_size (int): approximate bytes per shuffle bucket (default: 1GB); one
            bucket at a time is held in memory during step 2
        progress_bar (bool): show progress bar (requires ``tqdm``)
        output_name (Optional[str]): name of the dataset file written in ``path_tgt``;
            a random ``"<random>.shuffle"`` name is used when ``None``

    Raises:
        RuntimeError: if ``progress_bar`` is set but ``tqdm`` is not installed.

    Example:
        >>> shuffle_dataset("/path/to/source", "/path/to/output")
    """
    if progress_bar and not support_tqdm:
        raise RuntimeError("Requires `tqdm` to enable progress bar.")

    ds = SimpleDataset(path_src, serializer=RawSerializer())
    # Ceil division; keep at least one bucket so an empty dataset never indexes an
    # empty bucket list.
    num_buckets = max(1, (ds.nbytes + bucket_size - 1) // bucket_size)

    tmp_files = [os.path.join(path_src, ".tmp.%s" % _random_string()) for _ in range(num_buckets)]

    try:
        # Step 1: write to bucket randomly
        _shuffle_scatter(ds, tmp_files, progress_bar)

        # Step 2: shuffle inside bucket
        if output_name is None:
            output_name = "%s.shuffle" % _random_string()
        with build_dataset(
            path_tgt,
            output_name,
            block_size=block_size,
            serializer=RawSerializer(),
        ) as writer:
            iterator = tqdm(tmp_files, desc="Shuffle step 2/2") if progress_bar else tmp_files
            for fname in iterator:
                data_in_bucket = _shuffle_read_bucket(fname)
                random.shuffle(data_in_bucket)
                for data in data_in_bucket:
                    writer.write(data)
                # Drop this bucket before reading the next one so that at most one
                # bucket is resident in memory at a time.
                del data_in_bucket
                os.unlink(fname)
    finally:
        # cleanup any bucket files left behind (e.g. after an error)
        for fname in tmp_files:
            if os.path.exists(fname):
                os.unlink(fname)
|
||
|
|
||
|
|
||
|
def compact_dataset(path: str):
    """Compact the dataset, removing entries whose files were deleted.

    Entries of ``<path>/meta.bin`` whose data file no longer exists inside ``path``
    are dropped. The block ranges of the remaining entries are renumbered so they
    are contiguous starting from block 0, and ``meta.bin`` is rewritten.

    **Note** This may affect the existing dataset state dict.

    Args:
        path (str): path to dataset

    Raises:
        ValueError: if ``path`` contains no ``meta.bin``.

    Example:
        >>> compact_dataset("/path/to/dataset")
    """
    meta_path = os.path.join(path, "meta.bin")
    if not os.path.exists(meta_path):
        raise ValueError("Dataset not exists")
    info: List[FileInfo] = _read_info_list(meta_path)

    nw_info: List[FileInfo] = []
    curr_block = 0
    for v in info:
        # ``file_name`` is relative to the dataset directory (merge_dataset copies data
        # files to ``os.path.join(dst, file_name)``), so resolve it against ``path``
        # rather than the process working directory. os.path.join leaves absolute
        # names untouched.
        if not os.path.exists(os.path.join(path, v.file_name)):
            continue  # file is deleted: drop its entry
        num_file_block = v.block_end - v.block_begin
        nw_info.append(
            FileInfo(
                v.file_name,
                curr_block,
                curr_block + num_file_block,
                v.nbytes,
                v.nlines,
                v.mask,
                v.block_size,
            )
        )
        curr_block += num_file_block

    _write_info_list(meta_path, nw_info)
|
||
|
|
||
|
|
||
|
def mask_dataset(path: str, dbname: str, mask: bool = True):
    """Mask one file in dataset. Blocks in masked files won't be read later.

    Sets the ``mask`` flag of every ``meta.bin`` entry whose ``file_name`` equals
    ``dbname``, then rewrites ``meta.bin``. A warning is logged when no entry matches.

    Args:
        path (str): path to dataset
        dbname (str): file name in this dataset which you want to mask
        mask (bool): True for mask, False for unmask

    Raises:
        ValueError: if ``path`` contains no ``meta.bin``.

    Example:
        >>> mask_dataset("/path/to/dataset", "data_part_1", mask=True)
    """
    meta_path = os.path.join(path, "meta.bin")
    if not os.path.exists(meta_path):
        raise ValueError("Dataset not exists")
    info: List[FileInfo] = _read_info_list(meta_path)

    matched = False
    for v in info:
        if v.file_name == dbname:
            v.mask = mask
            matched = True
    if not matched:
        # An unmatched name (e.g. a typo) would otherwise be a silent no-op.
        logger.warning(f"mask_dataset: no file named {dbname!r} in dataset {path}")

    _write_info_list(meta_path, info)
|
||
|
|
||
|
|
||
|
def merge_dataset(dst: str, src: str):
    """Append every data file of dataset ``src`` to dataset ``dst``.

    Data files listed in ``src``'s ``meta.bin`` are copied into ``dst``. A file whose
    name is already taken in ``dst`` is renamed with the first free ``_<idx>`` suffix.
    ``dst``'s ``meta.bin`` is then rewritten: ``dst``'s existing entries come first,
    followed by ``src``'s, with contiguous block ranges starting at 0. ``src`` itself
    is not modified.

    Args:
        dst (str): path to destination dataset (modified in place)
        src (str): path to source dataset

    Raises:
        ValueError: if either dataset contains no ``meta.bin``.

    Example:
        >>> merge_dataset("/path/to/dst", "/path/to/src")
    """
    meta_path_src = os.path.join(src, "meta.bin")
    meta_path_dst = os.path.join(dst, "meta.bin")

    if not os.path.exists(meta_path_src):
        raise ValueError("Dataset not exists")
    info_src: List[FileInfo] = _read_info_list(meta_path_src)

    if not os.path.exists(meta_path_dst):
        raise ValueError("Dataset not exists")
    info_dst: List[FileInfo] = _read_info_list(meta_path_dst)

    nw_info: List[FileInfo] = []
    curr_block = 0

    def _append(file_name: str, v: FileInfo) -> None:
        # Re-base v's block range so it directly follows the previously appended entry.
        nonlocal curr_block
        num_file_block = v.block_end - v.block_begin
        nw_info.append(
            FileInfo(
                file_name,
                curr_block,
                curr_block + num_file_block,
                v.nbytes,
                v.nlines,
                v.mask,
                v.block_size,
            )
        )
        curr_block += num_file_block

    for v in info_dst:
        _append(v.file_name, v)

    # A name is taken if the file exists in dst OR dst's metadata still lists it. An
    # entry can outlive its file until compact_dataset runs, and reusing its name would
    # make two metadata entries point at the same file.
    used_names = {v.file_name for v in info_dst}

    def _is_taken(name: str) -> bool:
        return name in used_names or os.path.exists(os.path.join(dst, name))

    for v in info_src:
        nw_fname = v.file_name
        if _is_taken(nw_fname):
            idx = 0
            while _is_taken("{}_{}".format(v.file_name, idx)):
                idx += 1
            nw_fname = "{}_{}".format(v.file_name, idx)

        shutil.copy(os.path.join(src, v.file_name), os.path.join(dst, nw_fname))
        used_names.add(nw_fname)
        _append(nw_fname, v)

    _write_info_list(meta_path_dst, nw_info)
|
||
|
|
||
|
|
||
|
def to_cpm(src_data, dst_path, dst_name):
    """Convert JSON-lines data into a shuffled binary dataset.

    Every non-blank line of ``src_data`` is parsed with ``json.loads`` and written to a
    temporary dataset at ``<dst_path>_tmp``. ``src_data`` may be a single file, or a
    directory whose regular files are read in sorted name order. The temporary dataset
    is then shuffled into ``dst_path`` under the name ``dst_name``. The temporary
    directory is always removed, even on failure.

    Args:
        src_data (str): path to a JSON-lines file, or a directory of such files
        dst_path (str): output dataset directory (created if missing)
        dst_name (str): name of the output file inside ``dst_path``
    """
    if not os.path.exists(dst_path):
        os.makedirs(dst_path)

    logger.info(f"src_data: {src_data}")
    logger.info(f"dst_path: {dst_path}")
    logger.info(f"dst_name: {dst_name}")

    tmp_dst_path = dst_path.rstrip("/") + "_tmp"
    if os.path.exists(tmp_dst_path):
        # This directory is owned by to_cpm (it is deleted below). Leftovers from an
        # interrupted run would otherwise presumably be read back by shuffle_dataset
        # together with the fresh data.
        shutil.rmtree(tmp_dst_path)
    os.makedirs(tmp_dst_path)

    try:
        logger.info(f"write binary into: {tmp_dst_path}")
        with build_dataset(tmp_dst_path, dst_name) as dataset:
            if os.path.isdir(src_data):
                # sorted for a deterministic input order across filesystems
                filenames = [os.path.join(src_data, name) for name in sorted(os.listdir(src_data))]
            else:
                filenames = [src_data]

            n_filenames = len(filenames)
            for idx, filename in enumerate(filenames):
                logger.info(f"deal: [{n_filenames} -> {idx}] {filename}")
                if not os.path.exists(filename):
                    logger.error(f"not exist: {filename}")
                    continue
                if os.path.isdir(filename):
                    logger.warning(f"skip directory: {filename}")
                    continue

                with open(filename, "r", encoding="utf-8") as fin:
                    for line in fin:
                        line = line.strip()
                        if not line:
                            continue  # blank lines (e.g. a trailing newline) are not JSON
                        dataset.write(json.loads(line))

        logger.info(f"shuffle binary data from {tmp_dst_path} to {dst_path}")
        # Only request a progress bar when tqdm is importable; shuffle_dataset raises
        # RuntimeError otherwise, which would discard all the work above.
        shuffle_dataset(tmp_dst_path, dst_path, progress_bar=support_tqdm, output_name=dst_name)
    finally:
        if os.path.exists(tmp_dst_path):
            shutil.rmtree(tmp_dst_path)
|
||
|
|
||
|
|
||
|
def random_range(start, stop=None, step=None):
    """Yield every value of ``range(start, stop, step)`` exactly once, in random order.

    Same call interface as the builtin ``range``. Adapted from
    https://stackoverflow.com/a/53551417. ``random.shuffle(list)`` and
    ``random.sample(list, len(list))`` must materialize the whole sequence first, which
    causes a long initialization period. This generator instead walks a full-period
    linear congruential generator over the smallest power of two >= the range length,
    skipping out-of-range values, so it starts immediately and keeps only a few ints
    of state. Empty ranges (e.g. ``random_range(5, 0)``) yield nothing.

    Raises:
        ValueError: if ``step`` is zero (raised when iteration starts).
    """
    if stop is None:
        start, stop = 0, start
    if step is None:
        step = 1
    if step == 0:
        raise ValueError("random_range() arg 3 must not be zero")

    # Number of values in the range: ceil((stop - start) / step) using exact floor
    # division (float division loses precision for large bounds). It is <= 0 for
    # empty ranges such as random_range(5, 0), matching range(5, 0).
    maximum = int(-((start - stop) // step))
    if maximum <= 0:
        return

    # Seed range with a random integer.
    value = random.randint(0, maximum)
    # Construct an offset, multiplier, and modulus for a linear congruential generator.
    # These generators are cyclic and non-repeating when they maintain the properties:
    #
    # 1) "modulus" and "offset" are relatively prime.
    # 2) ["multiplier" - 1] is divisible by all prime factors of "modulus".
    # 3) ["multiplier" - 1] is divisible by 4 if "modulus" is divisible by 4.

    # Pick a random odd-valued offset (coprime with the power-of-two modulus).
    offset = random.randint(0, maximum) * 2 + 1
    # Pick a multiplier 1 greater than a multiple of 4.
    multiplier = 4 * (maximum // 4) + 1
    # Smallest power of two >= maximum, computed exactly with integers. math.log2
    # rounds for very large values and could pick a modulus that is too small.
    modulus = 1 << (maximum - 1).bit_length()

    # Track how many random numbers have been returned.
    found = 0
    while found < maximum:
        # Only values inside [0, maximum) map onto the requested range.
        if value < maximum:
            found += 1
            yield value * step + start
        # Calculate the next value in the sequence.
        value = (value * multiplier + offset) % modulus
|
||
|
|
||
|
|
||
|
class Range(object):
    """Arithmetic progression described by ``(start, stop, step)``, like ``range``.

    Besides plain iteration, it can be walked in a random non-repeating order
    (``random_iterate``) and split into interleaved shards (``subrange``).
    """

    def __init__(self, start, stop, step):
        self.start, self.stop, self.step = start, stop, step

    def __repr__(self):
        return "Range({}, {}, {})".format(self.start, self.stop, self.step)

    def iterate(self):
        """Yield the values of the progression in order."""
        for value in range(self.start, self.stop, self.step):
            yield value

    def list(self):
        """Return the values of the progression as a new list."""
        return [*range(self.start, self.stop, self.step)]

    def subrange(self, split, nsplits):
        """Return shard ``split`` of ``nsplits`` interleaved (strided) shards.

        Shard ``k`` starts at the k-th element and takes every ``nsplits``-th element
        from there, e.g. [0, 2, 4, 6, 8] splits into [0, 4, 8] and [2, 6].
        """
        first = self.start + self.step * split
        stride = self.step * nsplits
        return Range(first, self.stop, stride)

    def random_iterate(self):
        """Yield every value of the progression exactly once, in random order."""
        for value in random_range(self.start, self.stop, self.step):
            yield value
|
||
|
|
||
|
|
||
|
class CudaPrefetcher(Iterable):
    """
    Wrap around a batch iterator for asynchronously copying data to GPU to hide memcpy latency.

    Each batch is expected to be a dict (keys presumably name the model inputs; confirm
    with callers). While the caller works on the current batch, every ``torch.Tensor``
    value of the next batch is copied to the current CUDA device on a side stream.
    Non-tensor values are passed through unchanged. Per PyTorch docs, host tensors must
    be in pinned memory for ``non_blocking=True`` to actually overlap the copy.
    Iteration stops (StopIteration) once the wrapped loader is exhausted.
    """

    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.preload()

    def preload(self):
        """Fetch the next batch and start copying its tensors to the GPU on ``self.stream``.

        Sets ``self.data`` to ``None`` when the wrapped loader is exhausted.
        """
        try:
            self.data = next(self.loader)
        except StopIteration:
            self.data = None
            return
        with torch.cuda.stream(self.stream):
            for key in self.data.keys():
                if isinstance(self.data[key], torch.Tensor):
                    self.data[key] = self.data[key].cuda(non_blocking=True)

    def __next__(self):
        if self.data is None:
            # Loader exhausted: follow the iterator protocol instead of yielding None forever.
            raise StopIteration
        current = torch.cuda.current_stream()
        # Make the consumer's stream wait for the pending host-to-device copies.
        current.wait_stream(self.stream)
        data = self.data
        for value in data.values():
            if isinstance(value, torch.Tensor) and value.is_cuda:
                # The copy was allocated on self.stream but is consumed on `current`.
                # Record that use so the caching allocator does not hand the memory out
                # again before work queued on `current` has finished with it.
                value.record_stream(current)
        self.preload()
        return data

    def __iter__(self):
        return self
|
||
|
|
||
|
|
||
|
class ThreadedPrefetcher(Thread):
    """Wrap around a data iterator to hide I/O latency with a daemon thread.

    A background thread drains ``iterable`` into a bounded queue of at most ``prefetch``
    items; iterating this object pops items from that queue. An exception raised by
    ``iterable`` is re-raised in the consumer after the items produced before it.
    Once the end of the stream has been reached, further ``next()`` calls keep raising
    StopIteration.
    """

    def __init__(self, iterable, prefetch=10):
        super(ThreadedPrefetcher, self).__init__()
        self.queue = Queue(maxsize=prefetch)
        self.iterable = iterable
        # Set once the end-of-stream marker has been consumed, so later next() calls
        # raise StopIteration instead of blocking forever on the then-empty queue.
        self._prefetch_exhausted = False
        self.daemon = True
        self.start()

    def run(self):
        try:
            for data in self.iterable:
                self.queue.put(data)
        except Exception as exception:
            # forwarded to the consumer thread, which re-raises it
            self.queue.put(exception)
        finally:
            # end-of-stream marker, always the last item put on the queue
            self.queue.put(StopIteration())

    def __next__(self):
        if self._prefetch_exhausted:
            raise StopIteration
        item = self.queue.get()
        if isinstance(item, StopIteration):
            self._prefetch_exhausted = True
        if isinstance(item, Exception):
            raise item
        return item

    def __iter__(self):
        return self
|