import bisect import io import json import os import random import string import struct import time from typing import List from typing import Optional from typing import Set import bmtrain as bmt import torch from .serializer import PickleSerializer from .serializer import Serializer def _random_string(): return "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) _DEFAULT_BLOCK_SIZE = 16 << 20 class FileInfo: def __init__( self, file_name: str = "", block_begin: int = 0, block_end: int = 0, nbytes: int = 0, nlines: int = 0, mask: bool = False, block_size: int = _DEFAULT_BLOCK_SIZE, ) -> None: self.file_name = file_name self.block_begin = block_begin self.block_end = block_end self.nbytes = nbytes self.nlines = nlines self.mask = mask self.block_size = block_size def state_dict(self): return { "file_name": self.file_name, "block_begin": self.block_begin, "block_end": self.block_end, "nbytes": self.nbytes, "nlines": self.nlines, "mask": self.mask, "block_size": self.block_size, } def load_state_dict(self, d): self.file_name = d["file_name"] self.block_begin = d["block_begin"] self.block_end = d["block_end"] self.nbytes = d["nbytes"] self.nlines = d["nlines"] self.mask = d["mask"] self.block_size = d["block_size"] def dumps(self) -> str: return json.dumps(self.state_dict()) def loads(self, data: str) -> "FileInfo": self.load_state_dict(json.loads(data)) return self def dump(self, fp: io.TextIOWrapper) -> "FileInfo": fp.write(self.dumps()) return self def load(self, fp: io.TextIOWrapper) -> "FileInfo": self.loads(fp.read()) return self def _read_info_list(meta_path: str) -> List[FileInfo]: info: List[FileInfo] = [] while True: try: with open(meta_path, "r", encoding="utf-8") as f: for line in f.readlines(): line = line.strip() if len(line) > 0: info.append(FileInfo().loads(line)) return info except Exception as e: print( "Error: reading info list in _read_info_list!, meta_path={path}, err={err}".format( path=meta_path, err=str(e) ) ) time.sleep(10) def _write_info_list(meta_path: str, info: List[FileInfo]): base_path = os.path.dirname(meta_path) random_fname = os.path.join(base_path, ".meta.bin.%s" % _random_string()) while True: try: with open(random_fname, "w", encoding="utf-8") as f: for v in info: f.write(v.dumps() + "\n") os.rename(random_fname, meta_path) return except Exception: print("Error: writing info list!") time.sleep(10) def _filtered_range(begin: int, end: int, rank: int, world_size: int, filter_set: Optional[Set[int]] = None): begin = begin + (rank + (world_size - (begin % world_size))) % world_size if filter_set is not None: return [i for i in range(begin, end, world_size) if i in filter_set] else: return [i for i in range(begin, end, world_size)] class SafeFile: def __init__(self, fname, mode): self.fname = None self.mode = None self._fp = None self.open_file(fname, mode) def read(self, size=-1): if self._fp is None: raise RuntimeError("Dataset is closed") try: res = self._fp.read(size) self.offset = self._fp.tell() return res except Exception as e: print("Error: reading blocks in {}! err {}".format(self.fname, str(e))) self.close() self.open_file(self.fname, self.mode, self.offset) return self.read(size) def tell(self): if self._fp is None: raise RuntimeError("Dataset is closed") try: res = self._fp.tell() self.offset = res return res except Exception as e: print("Error: telling blocks in {}! err {}".format(self.fname, str(e))) self.close() self.open_file(self.fname, self.mode, self.offset) return self.tell() def seek(self, offset, whence=0): if self._fp is None: raise RuntimeError("Dataset is closed") try: res = self._fp.seek(offset, whence) self.offset = self._fp.tell() return res except Exception as e: print("Error: seeking blocks in {}! err {}".format(self.fname, str(e))) self.close() self.open_file(self.fname, self.mode, self.offset) return self.seek(offset, whence) def close(self): if self._fp is not None: try: self._fp.close() except Exception as e: print("Error: closing blocks in {}! err {}".format(self.fname, str(e))) self._fp = None def open_file(self, fname, mode, offset=None): if not os.path.exists(fname): print("Dataset {} does not exist".format(fname)) self.close() time.sleep(20) self.open_file(fname, mode, offset) try: self.fname = fname self.mode = mode self._fp = open(fname, mode) if offset is not None: self._fp.seek(offset, io.SEEK_SET) self.offset = self._fp.tell() except Exception as e: print("Error: opening blocks in {}! err {}".format(self.fname, str(e))) self.close() time.sleep(20) self.open_file(fname, mode, offset) class DistributedDataset: """Open dataset in readonly mode. `DistributeDataset` is used to read datasets in a distributed manner. Data in this dataset will be distributed evenly in blocks to each worker in the `distributed communicator`. **Note** When all data has been read, reading dataset again will revert back to the first data. Args: path (str): Path to dataset. rank (int): Rank in distributed communicator. See: bmtrain.rank() world_size (int): Total workers in distributed communicator. See: bmtrain.world_size() block_size (int): Size of each block in bytes. All files in the same dataset should have the same block size. Default: 16MB Example: >>> dataset = DistributedDataset("/path/to/dataset") >>> for i in range(10): >>> dataset.read() """ # noqa: E501 def __init__( self, path: str, rank: int = 0, world_size: int = 1, serializer: Optional[Serializer] = None, max_repeat_times: Optional[int] = None, shuffle: bool = True, ) -> None: # config self._path = path self._rank = rank self._world_size = world_size self._max_repeat_times = max_repeat_times self._repeat_times = 0 self._shuffle = shuffle if serializer is None: serializer = PickleSerializer() self.serializer = serializer # dataset meta self._unused_block: List[int] = [] self._unused_block_offset: List[int] = [] self._file_info: List[FileInfo] = [] self._file_ends: List[int] = [] self._total_blocks = 0 self._nbytes = 0 self._nlines = 0 # states self._curr_block = None self._fp = None # cache self._last_mod_time = 0 self._curr_fname = None self._update_states(fast_skip=False) self._repeat_times += 1 def _update_states(self, fast_skip: bool = True): meta_path = os.path.join(self._path, "meta.bin") while True: try: mod_time = os.stat(meta_path).st_mtime break except Exception as e: print( "Error: reading info list in DistributedDataset._update_states, " "meta_path={path}, err={err}!".format(path=meta_path, err=str(e)) ) time.sleep(10) if self._last_mod_time < mod_time: # file changed self._last_mod_time = mod_time else: if fast_skip: return info: List[FileInfo] = [] if os.path.exists(meta_path): info = _read_info_list(meta_path) old_len = len(self._file_info) if old_len > len(info): raise RuntimeError("Dataset meta file: changed unexpectly") mask_changed = False for i in range(old_len): if self._file_info[i].file_name != info[i].file_name: raise RuntimeError("Dataset meta file: changed unexpectly") if self._file_info[i].block_begin != info[i].block_begin: raise RuntimeError("Dataset meta file: changed unexpectly") if self._file_info[i].block_end != info[i].block_end: raise RuntimeError("Dataset meta file: changed unexpectly") if self._file_info[i].mask != info[i].mask: mask_changed = True if info[0].block_begin != 0: raise RuntimeError("Dataset meta file: block error (0)") for i in range(len(info) - 1): if info[i].block_end != info[i + 1].block_begin: raise RuntimeError("Dataset meta file: block error (%d)" % (i + 1)) if (old_len == len(info) and not mask_changed) and fast_skip: # fast skip return if len(info) > 0: total_blocks = info[-1].block_end self._nbytes = 0 self._nlines = 0 for v in info: self._nbytes += v.nbytes self._nlines += v.nlines else: total_blocks = 0 self._nbytes = 0 self._nlines = 0 if total_blocks > 0: unused_block_set = set(self._unused_block) nw_unused_block: List[int] = [] for i in range(len(info)): v = info[i] if not v.mask: if i < old_len: nw_unused_block.extend( _filtered_range( v.block_begin, v.block_end, self._rank, self._world_size, unused_block_set, ) ) else: nw_unused_block.extend( _filtered_range(v.block_begin, v.block_end, self._rank, self._world_size) ) # re-shuffle unused blocks if self._shuffle: random.shuffle(nw_unused_block) offset_dict = {block: offset for block, offset in zip(self._unused_block, self._unused_block_offset)} nw_unused_block_offset = [offset_dict[block] if block in offset_dict else 0 for block in nw_unused_block] self._unused_block = nw_unused_block self._unused_block_offset = nw_unused_block_offset self._file_ends = [] for v in info: self._file_ends.append(v.block_end) else: self._unused_block = [] self._unused_block_offset = [] self._file_ends = [] self._total_blocks = total_blocks self._file_info = info assert len(self._unused_block) == len(self._unused_block_offset) assert len(self._file_ends) == len(self._file_info) def _mask_file(self, f: FileInfo): nw_unused_block: List[int] = [] nw_unused_block_offset: List[int] = [] for block_id, block_offset in zip(self._unused_block, self._unused_block_offset): if block_id < f.block_begin or block_id >= f.block_end: nw_unused_block.append(block_id) nw_unused_block_offset.append(block_offset) self._unused_block = nw_unused_block self._unused_block_offset = nw_unused_block_offset def _get_block_file(self, block_id: int): # find block in which file file_idx = bisect.bisect_right(self._file_ends, block_id) return self._file_info[file_idx] def _prepare_new_epoch(self): if self._max_repeat_times is not None: if self._repeat_times >= self._max_repeat_times: raise EOFError("End of dataset") nw_unused_block: List[int] = [] for v in self._file_info: if not v.mask: nw_unused_block.extend(_filtered_range(v.block_begin, v.block_end, self._rank, self._world_size)) if self._shuffle: random.shuffle(nw_unused_block) self._unused_block = nw_unused_block self._unused_block_offset = [0 for _ in nw_unused_block] self._repeat_times += 1 def _get_next_block(self): self._update_states() if len(self._unused_block) == 0: self._prepare_new_epoch() if len(self._unused_block) == 0: raise RuntimeError("Empty dataset {}".format(self._path)) mn_block: int = self._unused_block.pop() mn_block_offset: int = self._unused_block_offset.pop() return mn_block, mn_block_offset def _state_dict(self): self._update_states() num_unused_block = len(self._unused_block) if (self._fp is not None) and (self._curr_block is not None): curr_block = self._curr_block curr_f = self._get_block_file(curr_block) inblock_offset = self._fp.tell() - (curr_block - curr_f.block_begin) * curr_f.block_size else: curr_block = -1 inblock_offset = 0 return { "states": torch.tensor(self._unused_block, dtype=torch.long, device="cpu"), "offset": torch.tensor(self._unused_block_offset, dtype=torch.long, device="cpu"), "block": torch.tensor( [curr_block, inblock_offset, num_unused_block, self._repeat_times], dtype=torch.long, device="cpu", ), } def state_dict(self): """Returns a state dict representing the read states of the dataset. Example: >>> state = dataset.state_dict() >>> dataset.load_state_dict(state) """ self._update_states() num_unused_block = len(self._unused_block) if (self._fp is not None) and (self._curr_block is not None): curr_block = self._curr_block curr_f = self._get_block_file(curr_block) inblock_offset = self._fp.tell() - (curr_block - curr_f.block_begin) * curr_f.block_size else: curr_block = -1 inblock_offset = 0 with torch.no_grad(): if self._world_size > 1: gpu_num_unused_block = torch.tensor([num_unused_block], dtype=torch.long).cuda() max_unused_blocks = bmt.distributed.all_reduce(gpu_num_unused_block, op="max").cpu().item() gpu_states = torch.full((max_unused_blocks,), -1, dtype=torch.long).cuda() gpu_states[:num_unused_block] = torch.tensor(self._unused_block, dtype=torch.long).cuda() gpu_offset = torch.full((max_unused_blocks,), 0, dtype=torch.long).cuda() gpu_offset[:num_unused_block] = torch.tensor(self._unused_block_offset, dtype=torch.long).cuda() gpu_block = torch.tensor( [curr_block, inblock_offset, num_unused_block, self._repeat_times], dtype=torch.long, ).cuda() global_states = bmt.distributed.all_gather(gpu_states).cpu() # (world_size, max_unused_blocks) global_offset = bmt.distributed.all_gather(gpu_offset).cpu() # (world_size, max_unused_blocks) global_block = bmt.distributed.all_gather(gpu_block).cpu() # (world_size, 4) return {"states": global_states, "offset": global_offset, "block": global_block} else: return { "states": torch.tensor([self._unused_block], dtype=torch.long, device="cpu"), "offset": torch.tensor([self._unused_block_offset], dtype=torch.long, device="cpu"), "block": torch.tensor( [[curr_block, inblock_offset, num_unused_block, self._repeat_times]], dtype=torch.long, device="cpu", ), } def load_state_dict(self, state, strict: bool = True): """Load dataset state. Args: state (dict): dataset state dict. strict (bool): If `strict` is True, world size needs to be the same as when exported. Example: >>> state = dataset.state_dict() >>> """ block_states: torch.LongTensor = state["states"] block_info: torch.LongTensor = state["block"] if "offset" not in state: block_offset: torch.LongTensor = torch.zeros_like(block_states).long() else: block_offset: torch.LongTensor = state["offset"] if block_states.size(0) != self._world_size: if strict: raise ValueError("world_size changed (%d -> %d)" % (state["block"].size(0), self._world_size)) else: self._curr_block = None self._fp = None self._curr_fname = None self._repeat_times = int(block_info[0, 3].item()) offset_dict = {} for i in range(block_states.size(0)): for block, offset in zip(block_states[i].tolist(), block_offset[i].tolist()): offset_dict[block] = offset # re-shuffle unused blocks nw_unused_block: List[int] = [] for i in range(block_states.size(0)): _, _, num_unused_blocks, _ = block_info[i].tolist() nw_unused_block.extend( [ block_id for block_id in block_states[i, :num_unused_blocks].tolist() if block_id % self._world_size == self._rank ] ) for i in range(block_states.size(0)): curr_block, inblock_offset, num_unused_blocks, _ = block_info[i].tolist() if curr_block < 0: continue if curr_block % self._world_size == self._rank: nw_unused_block.append(curr_block) offset_dict[curr_block] = inblock_offset curr_block, inblock_offset, _, self._repeat_times = block_info[self._rank].tolist() # if self._shuffle: # random.shuffle(nw_unused_block) nw_unused_block_offset = [ offset_dict[block] if block in offset_dict else 0 for block in nw_unused_block ] self._unused_block = nw_unused_block self._unused_block_offset = nw_unused_block_offset else: curr_block, inblock_offset, num_unused_blocks, self._repeat_times = block_info[self._rank].tolist() if curr_block == -1: self._curr_block = None self._unused_block = [] self.unused_block_offset = [] else: while True: try: self._curr_block = curr_block f_info = self._get_block_file(self._curr_block) self._open_file( f_info.file_name, (self._curr_block - f_info.block_begin) * f_info.block_size + inblock_offset, ) self._unused_block = block_states[self._rank, :num_unused_blocks].tolist() self.unused_block_offset = block_offset[self._rank, :num_unused_blocks].tolist() break except Exception: print("Error: reading blocks in {}".format(f_info.file_name)) time.sleep(10) # end self._update_states() def _get_file_path(self, fname): return os.path.join(self._path, fname) def _open_file(self, fname, offset): if self._curr_fname != fname: if self._fp is not None: self._fp.close() self._curr_fname = None # self._fp = open(self._get_file_path(fname), "rb") self._fp = SafeFile(self._get_file_path(fname), "rb") self._curr_fname = fname else: assert self._fp is not None, "Unexpected error" self._fp.seek(offset, io.SEEK_SET) # move to block def read(self): """Read a piece of data from dataset. Workers in different ranks will read different data. """ if self._curr_block is None: next_block_id, next_block_offset = self._get_next_block() f_info = self._get_block_file(next_block_id) try: self._open_file( f_info.file_name, (next_block_id - f_info.block_begin) * f_info.block_size + next_block_offset ) self._curr_block = next_block_id except FileNotFoundError: print("ERR: reading again!") self._mask_file(f_info) return self.read() # read again if self._fp is None: raise RuntimeError("Dataset is not initialized") MAGIC = self._fp.read(1) if MAGIC == b"\x1F": # correct size = struct.unpack("I", self._fp.read(4))[0] data = self._fp.read(size) return self.serializer.deserialize(data) elif MAGIC == b"\x00": # end of block self._curr_block = None return self.read() # read next block else: raise ValueError("Invalid magic header") @property def nbytes(self): return self._nbytes class SimpleDataset(DistributedDataset): def __init__( self, path: str, serializer: Optional[Serializer] = None, shuffle: bool = True, ) -> None: super().__init__( path, 0, 1, serializer=serializer, max_repeat_times=1, shuffle=shuffle, ) def __iter__(self): while True: try: data = self.read() except EOFError: self._repeat_times = 0 break yield data def __len__(self): return self._nlines def get_bytes(self): return self._nbytes class DatasetWriter: def __init__(self, fname: str, block_size: int, serializer: Optional[Serializer] = None): self._fname = fname self._block_size = block_size self._fp = open(self._fname, "wb") self._inblock_offset = 0 self._nbytes = 0 self._nlines = 0 self._nblocks = 1 if serializer is None: serializer = PickleSerializer() self.serializer = serializer def write(self, data): """Write a piece of data into dataset. Args: data (Any): Serialization will be done using pickle. Example: >>> writer.write( "anything you want" ) """ byte_data = self.serializer.serialize(data) byte_data = struct.pack("I", len(byte_data)) + byte_data if self._inblock_offset + 2 + len(byte_data) > self._block_size: self._fp.write(b"\x00" * (self._block_size - self._inblock_offset)) # fill the remaining space with 0 self._inblock_offset = 0 self._nblocks += 1 # we go to the next block if self._inblock_offset + 2 + len(byte_data) > self._block_size: raise ValueError("data is larger than block size") self._nbytes += len(byte_data) self._nlines += 1 self._inblock_offset += 1 + len(byte_data) self._fp.write(b"\x1F") self._fp.write(byte_data) @property def nbytes(self): return self._nbytes @property def nblocks(self): return self._nblocks @property def nlines(self): return self._nlines def close(self): if not self._fp.closed: self._fp.write(b"\x00" * (self._block_size - self._inblock_offset)) self._fp.close() class DatasetBuilder: def __init__( self, path: str, dbname: str, block_size=_DEFAULT_BLOCK_SIZE, serializer: Optional[Serializer] = None, ) -> None: self._block_size = block_size self._path = path self._dbname = dbname if serializer is None: serializer = PickleSerializer() self.serializer = serializer if not os.path.exists(self._path): os.makedirs(self._path) meta_path = os.path.join(self._path, "meta.bin") info: List[FileInfo] = [] if os.path.exists(meta_path): info = _read_info_list(meta_path) for v in info: if v.file_name == dbname: raise ValueError("Dataset name exists") self._db_path = os.path.join(self._path, self._dbname) if os.path.exists(self._db_path): raise ValueError("File exists `%s`" % self._db_path) def __enter__(self): self._writer = DatasetWriter(self._db_path, self._block_size, self.serializer) return self._writer def __exit__(self, exc_type, exc_value, exc_traceback): if self._writer is None: raise RuntimeError("Unexpected call to __exit__") self._writer.close() if exc_type is not None: print("Error while writing file") if os.path.exists(self._db_path): os.unlink(self._db_path) else: meta_path = os.path.join(self._path, "meta.bin") info: List[FileInfo] = [] if os.path.exists(meta_path): info = _read_info_list(meta_path) last_block = 0 if len(info) > 0: last_block = info[-1].block_end info.append( FileInfo( self._dbname, last_block, last_block + self._writer.nblocks, self._writer.nbytes, self._writer.nlines, False, self._block_size, ) ) # atomic write to meta file _write_info_list(meta_path, info) self._writer = None def build_dataset( path: str, dbname: str, block_size: int = _DEFAULT_BLOCK_SIZE, serializer: Optional[Serializer] = None, ): """Open the dataset in write mode and returns a writer. Args: path (str): Path to dataset. dbname (str): The name of the file to which the data will be written. The `dbname` needs to be unique in this `dataset`. block_size (int): Size of each block in bytes. All files in the same dataset should have the same block size. Default: 16MB Example: >>> with build_dataset("/path/to/dataset", "data_part_1") as writer: >>> for i in range(10): >>> writer.write( { "anything you want" } ) """ # noqa: E501 return DatasetBuilder(path, dbname, block_size=block_size, serializer=serializer)