CPM-9G-8B/9G-Train/cpm/dataset/utils.py

422 lines
13 KiB
Python
Raw Normal View History

2024-02-27 14:33:33 +08:00
import json
import math
import os
import random
import shutil
import struct
from queue import Queue
from threading import Thread
from typing import Iterable
from typing import List
from typing import Optional
import torch
from ..utils.log import logger
from .distributed_dataset import _DEFAULT_BLOCK_SIZE
from .distributed_dataset import _random_string
from .distributed_dataset import _read_info_list
from .distributed_dataset import _write_info_list
from .distributed_dataset import build_dataset
from .distributed_dataset import FileInfo
from .distributed_dataset import SimpleDataset
from .serializer import RawSerializer
try:
from tqdm import tqdm
support_tqdm = True
except ModuleNotFoundError:
support_tqdm = False
_DEFAULT_SHUFFLE_BUCKET_SIZE = 1 << 30
def shuffle_dataset(
    path_src: str,
    path_tgt: str,
    block_size: int = _DEFAULT_BLOCK_SIZE,
    bucket_size: int = _DEFAULT_SHUFFLE_BUCKET_SIZE,
    progress_bar: bool = False,
    output_name: Optional[str] = None,
):
    """Shuffle one distributed dataset, write results to another dataset.

    Two-pass bucket shuffle: every record of the source is first appended to a
    uniformly chosen temporary bucket file (created inside ``path_src``), then each
    bucket is loaded into memory, shuffled and appended to the output dataset. Only
    one bucket's worth of records is held in memory at a time.

    Args:
        path_src (str): path to source dataset
        path_tgt (str): path to write results
        block_size (int): block size of the output dataset (default: ``_DEFAULT_BLOCK_SIZE``)
        bucket_size (int): approximate bytes per temporary bucket (default: 1GB)
        progress_bar (bool): show progress bar (requires ``tqdm``)
        output_name (Optional[str]): name of the output file inside ``path_tgt``;
            a random ``<random>.shuffle`` name is used when None
    Raises:
        RuntimeError: if ``progress_bar`` is requested but ``tqdm`` is not installed.
    Example:
        >>> shuffle_dataset("/path/to/source", "/path/to/output")
    """
    if progress_bar and not support_tqdm:
        raise RuntimeError("Requires `tqdm` to enable progress bar.")
    ds = SimpleDataset(path_src, serializer=RawSerializer())
    num_buckets = (ds.nbytes + bucket_size - 1) // bucket_size
    tmp_files = [os.path.join(path_src, ".tmp.%s" % _random_string()) for _ in range(num_buckets)]
    # Record framing inside bucket files: native unsigned-int length prefix + raw bytes.
    len_struct = struct.Struct("I")
    try:
        # Step 1: write to bucket randomly
        f_tmp = []
        try:
            # Open inside the try so files opened before a failing open() still get closed.
            for fname in tmp_files:
                f_tmp.append(open(fname, "wb"))
            iterator = tqdm(ds, desc="Shuffle step 1/2") if progress_bar else ds
            for data in iterator:
                bucket_id = int(random.random() * num_buckets)
                f_tmp[bucket_id].write(len_struct.pack(len(data)) + data)
        finally:
            for fp in f_tmp:
                fp.close()

        # Step 2: shuffle inside bucket
        if output_name is None:
            output_name = "%s.shuffle" % _random_string()
        with build_dataset(
            path_tgt,
            output_name,
            block_size=block_size,
            serializer=RawSerializer(),
        ) as writer:
            iterator = tqdm(tmp_files, desc="Shuffle step 2/2") if progress_bar else tmp_files
            for fname in iterator:
                data_in_bucket = []
                with open(fname, "rb") as fp:
                    while True:
                        raw_len = fp.read(len_struct.size)
                        if not raw_len:
                            break  # EOF
                        (len_data,) = len_struct.unpack(raw_len)
                        data_in_bucket.append(fp.read(len_data))
                random.shuffle(data_in_bucket)
                for data in data_in_bucket:
                    writer.write(data)
                # Free the disk space as soon as the bucket has been consumed.
                os.unlink(fname)
    finally:
        # cleanup: remove bucket files left behind by an error
        for fname in tmp_files:
            if os.path.exists(fname):
                os.unlink(fname)
def compact_dataset(path: str):
    """Compact the dataset, removing entries whose data files were deleted.

    Surviving entries keep their order and are renumbered so that their block
    ranges are contiguous starting from block 0; ``meta.bin`` is rewritten in place.
    **Note** This may affect the existing dataset state dict.

    Args:
        path (str): path to dataset
    Raises:
        ValueError: if ``path`` contains no ``meta.bin``.
    Example:
        >>> compact_dataset("/path/to/dataset")
    """
    meta_path = os.path.join(path, "meta.bin")
    if not os.path.exists(meta_path):
        raise ValueError("Dataset not exists")
    info: List[FileInfo] = _read_info_list(meta_path)

    nw_info: List[FileInfo] = []
    curr_block = 0
    for v in info:
        # File names in meta.bin are relative to the dataset directory (merge_dataset
        # resolves them the same way); os.path.join keeps absolute names unchanged.
        if not os.path.exists(os.path.join(path, v.file_name)):
            continue  # file is deleted
        num_file_block = v.block_end - v.block_begin
        nw_info.append(
            FileInfo(
                v.file_name,
                curr_block,
                curr_block + num_file_block,
                v.nbytes,
                v.nlines,
                v.mask,
                v.block_size,
            )
        )
        curr_block += num_file_block
    _write_info_list(meta_path, nw_info)
def mask_dataset(path: str, dbname: str, mask: bool = True):
    """Set or clear the mask flag of a file recorded in a dataset's ``meta.bin``.

    Blocks in masked files won't be read later. Every entry whose ``file_name``
    equals ``dbname`` is updated; when nothing matches, the metadata is rewritten
    unchanged.

    Args:
        path (str): path to dataset
        dbname (str): file name in this dataset which you want to mask
        mask (bool): True for mask, False for unmask
    Raises:
        ValueError: if ``path`` contains no ``meta.bin``.
    Example:
        >>> mask_dataset("/path/to/dataset", "data_part_1", mask=True)
    """
    meta_file = os.path.join(path, "meta.bin")
    if not os.path.exists(meta_file):
        raise ValueError("Dataset not exists")
    entries: List[FileInfo] = _read_info_list(meta_file)
    for entry in (e for e in entries if e.file_name == dbname):
        entry.mask = mask
    _write_info_list(meta_file, entries)
def merge_dataset(dst: str, src: str):
    """Append every file of dataset ``src`` to dataset ``dst``.

    Data files are copied from ``src`` into ``dst``. A copied file whose name
    already exists in ``dst`` receives the first free ``_<n>`` suffix
    (``name_0``, ``name_1``, ...). The ``dst`` entries are renumbered to
    contiguous block ranges starting at 0, followed by the ``src`` entries, and
    ``dst``'s ``meta.bin`` is rewritten. ``src`` is only read from.

    Args:
        dst (str): path to the dataset that receives the files
        src (str): path to the dataset whose files are appended
    Raises:
        ValueError: if either directory contains no ``meta.bin``.
    """
    src_meta = os.path.join(src, "meta.bin")
    dst_meta = os.path.join(dst, "meta.bin")
    if not os.path.exists(src_meta):
        raise ValueError("Dataset not exists")
    src_entries: List[FileInfo] = _read_info_list(src_meta)
    if not os.path.exists(dst_meta):
        raise ValueError("Dataset not exists")
    dst_entries: List[FileInfo] = _read_info_list(dst_meta)

    merged: List[FileInfo] = []
    next_block = 0

    def _append(name, entry):
        # Re-register `entry` under `name` at the next free block range.
        nonlocal next_block
        n_blocks = entry.block_end - entry.block_begin
        merged.append(
            FileInfo(
                name,
                next_block,
                next_block + n_blocks,
                entry.nbytes,
                entry.nlines,
                entry.mask,
                entry.block_size,
            )
        )
        next_block += n_blocks

    for entry in dst_entries:
        _append(entry.file_name, entry)

    for entry in src_entries:
        target_name = entry.file_name
        target_path = os.path.join(dst, target_name)
        if os.path.exists(target_path):
            # Find the first unused "_<n>" suffix to avoid overwriting.
            suffix = 0
            while os.path.exists("{}_{}".format(target_path, suffix)):
                suffix += 1
            target_path = "{}_{}".format(target_path, suffix)
            target_name = "{}_{}".format(target_name, suffix)
        shutil.copy(os.path.join(src, entry.file_name), target_path)
        _append(target_name, entry)

    _write_info_list(dst_meta, merged)
def to_cpm(src_data, dst_path, dst_name):
    """Convert JSON-lines text into a shuffled binary dataset.

    Every non-empty line of the input is parsed with ``json.loads`` and written to
    a temporary dataset at ``<dst_path>_tmp``. That dataset is then shuffled into
    ``dst_path`` under the name ``dst_name``. The temporary directory is removed
    afterwards, including when an error occurs.

    Args:
        src_data: a JSON-lines file, or a directory whose regular files (read in
            sorted name order) are all JSON-lines files
        dst_path: output dataset directory (created if missing)
        dst_name: name of the output file inside ``dst_path`` (also used for the
            temporary dataset)
    """
    os.makedirs(dst_path, exist_ok=True)
    logger.info(f"src_data: {src_data}")
    logger.info(f"dst_path: {dst_path}")
    logger.info(f"dst_name: {dst_name}")
    tmp_dst_path = dst_path.rstrip("/") + "_tmp"
    os.makedirs(tmp_dst_path, exist_ok=True)
    try:
        logger.info(f"write binary into: {tmp_dst_path}")
        with build_dataset(tmp_dst_path, dst_name) as dataset:
            if os.path.isdir(src_data):
                # Sorted for a reproducible order; sub-directories are skipped because
                # open() would fail on them.
                filenames = sorted(
                    os.path.join(src_data, name)
                    for name in os.listdir(src_data)
                    if os.path.isfile(os.path.join(src_data, name))
                )
            else:
                filenames = [src_data]
            n_filenames = len(filenames)
            for idx, filename in enumerate(filenames):
                logger.info(f"deal: [{n_filenames} -> {idx}] {filename}")
                if not os.path.exists(filename):
                    logger.error(f"not exist: {filename}")
                    continue
                with open(filename, "r", encoding="utf-8") as fin:
                    for line in fin:
                        line = line.strip()
                        if not line:
                            # Blank lines (e.g. extra newlines at end of file) would make json.loads raise.
                            continue
                        dataset.write(json.loads(line))
        logger.info(f"shuffle binary data from {tmp_dst_path} to {dst_path}")
        shuffle_dataset(tmp_dst_path, dst_path, progress_bar=True, output_name=dst_name)
    finally:
        # The temporary dataset is scratch space; never leave stale data behind for a rerun.
        if os.path.exists(tmp_dst_path):
            shutil.rmtree(tmp_dst_path)
def random_range(start, stop=None, step=None):
    """
    Generator of non-repeated random permutation with the same interface as python
    `range`. Obtained from https://stackoverflow.com/a/53551417

    The random.shuffle(list) and random.sample(list, len(list)) require
    materializing the lists, which results in a long initialization period. Here the
    order is produced lazily with O(1) memory: a full-period linear congruential
    generator runs over the next power of two >= the range length, and outputs that
    fall outside the range are skipped.

    Randomness comes from the `random` module, so `random.seed` makes the order
    reproducible. The multiplier depends only on the length, so only a limited
    family of orders is reachable; this is not a uniformly drawn permutation.
    """
    if stop is None:
        start, stop = 0, start
    if step is None:
        step = 1
    # Number of elements, via exact ceil-division (-(a // b) == ceil(-a / b)) instead
    # of float math, which loses precision on huge ranges. An empty direction such
    # as (5, 0, 1) gives 0, matching `range`.
    maximum = max(0, int(-((start - stop) // step)))
    if maximum == 0:
        # early return with empty range
        return
    # Seed range with a random integer (values >= maximum are simply skipped).
    value = random.randint(0, maximum)
    # Construct an offset, multiplier, and modulus for a linear
    # congruential generator. These generators are cyclic and
    # non-repeating when they maintain the properties:
    #
    # 1) "modulus" and "offset" are relatively prime.
    # 2) ["multiplier" - 1] is divisible by all prime factors of "modulus".
    # 3) ["multiplier" - 1] is divisible by 4 if "modulus" is divisible by 4.
    # Pick a random odd-valued offset (coprime with a power-of-two modulus).
    offset = random.randint(0, maximum) * 2 + 1
    # Pick a multiplier 1 greater than a multiple of 4.
    multiplier = 4 * (maximum // 4) + 1
    # Smallest power of two >= maximum, computed exactly with integer arithmetic.
    modulus = 1 << (maximum - 1).bit_length()
    # Track how many random numbers have been returned.
    found = 0
    while found < maximum:
        # If this is a valid value, yield it mapped onto the requested range.
        if value < maximum:
            found += 1
            yield (value * step) + start
        # Calculate the next value in the sequence.
        value = (value * multiplier + offset) % modulus
class Range:
    """An arithmetic progression described by ``start``, ``stop`` and ``step``.

    The parameters have the same meaning as for Python's ``range``. The elements
    are only produced on demand, in sequential or in random order.
    """

    def __init__(self, start, stop, step):
        self.start, self.stop, self.step = start, stop, step

    def __repr__(self):
        return "Range({}, {}, {})".format(self.start, self.stop, self.step)

    def iterate(self):
        """Yield the elements in ``range`` order."""
        yield from range(self.start, self.stop, self.step)

    def list(self):
        """Return all elements as a list, in ``range`` order."""
        return [*range(self.start, self.stop, self.step)]

    def subrange(self, split, nsplits):
        """Return the ``split``-th of ``nsplits`` strided (interleaved) sub-ranges.

        E.g. ``Range(0, 10, 2)`` (0, 2, 4, 6, 8) with ``nsplits=2`` splits into
        (0, 4, 8) for ``split=0`` and (2, 6) for ``split=1``.
        """
        first = self.start + self.step * split
        stride = self.step * nsplits
        return Range(first, self.stop, stride)

    def random_iterate(self):
        """Yield the elements once each, in random order, via ``random_range``."""
        yield from random_range(self.start, self.stop, self.step)
class CudaPrefetcher(Iterable):
    """
    Wrap around a batch iterator for asynchronously copying data to gpu to shield memcpy latency.

    While the caller works on the current batch, the next batch is fetched and each
    of its ``torch.Tensor`` values is copied to the current CUDA device on a
    dedicated side stream. Batches must support ``.keys()`` and item access
    (dict-like); non-tensor values are passed through unchanged. Iteration stops
    with ``StopIteration`` once the wrapped loader is exhausted.
    NOTE(review): ``non_blocking`` copies only overlap when the host tensors are in
    pinned memory; confirm the loader pins them.
    """

    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.preload()

    def preload(self):
        """Fetch the next batch and issue its device copies on ``self.stream``.

        Sets ``self.data`` to None when the loader is exhausted.
        """
        try:
            self.data = next(self.loader)
        except StopIteration:
            self.data = None
            return
        with torch.cuda.stream(self.stream):
            for key in self.data.keys():
                if isinstance(self.data[key], torch.Tensor):
                    self.data[key] = self.data[key].cuda(non_blocking=True)

    def __next__(self):
        if self.data is None:
            # Loader exhausted: follow the iterator protocol instead of returning None forever.
            raise StopIteration
        current = torch.cuda.current_stream()
        # Make the consumer stream wait for the copies issued in preload().
        current.wait_stream(self.stream)
        data = self.data
        # These tensors were allocated on self.stream but are used on the current
        # stream; record that use so the caching allocator won't recycle them early.
        for key in data.keys():
            value = data[key]
            if isinstance(value, torch.Tensor):
                value.record_stream(current)
        self.preload()
        return data

    def __iter__(self):
        return self
class ThreadedPrefetcher(Thread):
    """Wrap around a data iterator to shield io latency with a daemon thread.

    A background daemon thread, started by ``__init__``, drains ``iterable`` into a
    bounded queue holding at most ``prefetch`` items. The consumer pulls items
    through the iterator protocol. An exception raised by ``iterable`` is re-raised
    in the consumer at the position where it occurred, after which iteration stops.
    Once exhausted, every further ``next()`` raises ``StopIteration`` rather than blocking.
    NOTE(review): an item that is itself an ``Exception`` instance is indistinguishable
    from a producer error and will be raised.
    """

    def __init__(self, iterable, prefetch=10):
        super().__init__()
        self.queue = Queue(maxsize=prefetch)
        self.iterable = iterable
        self.daemon = True
        self.start()

    def run(self):
        """Producer loop executed in the background thread."""
        try:
            for data in self.iterable:
                self.queue.put(data)  # blocks while the queue is full
        except Exception as exception:
            # Hand the error to the consumer instead of dying silently in this thread.
            self.queue.put(exception)
        finally:
            # End-of-stream marker.
            self.queue.put(StopIteration())

    def __next__(self):
        item = self.queue.get()
        if isinstance(item, StopIteration):
            # Put the marker back so later calls also stop instead of blocking forever
            # on an empty queue (the producer has finished, so there is room).
            self.queue.put(item)
            raise StopIteration
        if isinstance(item, Exception):
            raise item
        return item

    def __iter__(self):
        return self