# coding=utf-8
# Copyright 2020 The OpenBMB team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import bisect
import io
import json
import os
import random
import string
import struct
import time
from typing import List
from typing import Optional
from typing import Set

import bmtrain as bmt
import torch

from .serializer import PickleSerializer
from .serializer import Serializer


def _random_string():
    return "".join(random.choices(string.ascii_uppercase + string.digits, k=8))


_DEFAULT_BLOCK_SIZE = 16 << 20


class FileInfo:
    def __init__(
        self,
        file_name: str = "",
        block_begin: int = 0,
        block_end: int = 0,
        nbytes: int = 0,
        nlines: int = 0,
        mask: bool = False,
        block_size: int = _DEFAULT_BLOCK_SIZE,
    ) -> None:
        self.file_name = file_name
        self.block_begin = block_begin
        self.block_end = block_end
        self.nbytes = nbytes
        self.nlines = nlines
        self.mask = mask
        self.block_size = block_size

    def state_dict(self):
        return {
            "file_name": self.file_name,
            "block_begin": self.block_begin,
            "block_end": self.block_end,
            "nbytes": self.nbytes,
            "nlines": self.nlines,
            "mask": self.mask,
            "block_size": self.block_size,
        }

    def load_state_dict(self, d):
        self.file_name = d["file_name"]
        self.block_begin = d["block_begin"]
        self.block_end = d["block_end"]
        self.nbytes = d["nbytes"]
        self.nlines = d["nlines"]
        self.mask = d["mask"]
        self.block_size = d["block_size"]

    def dumps(self) -> str:
        return json.dumps(self.state_dict())

    def loads(self, data: str) -> "FileInfo":
        self.load_state_dict(json.loads(data))
        return self

    def dump(self, fp: io.TextIOWrapper) -> "FileInfo":
        fp.write(self.dumps())
        return self

    def load(self, fp: io.TextIOWrapper) -> "FileInfo":
        self.loads(fp.read())
        return self


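# Each FileInfo is stored as one JSON line of the dataset's meta.bin file. As an
# illustration (the field values here are made up), a single line might look like:
#   {"file_name": "data_part_1", "block_begin": 0, "block_end": 4,
#    "nbytes": 1048576, "nlines": 2000, "mask": false, "block_size": 16777216}

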
def _read_info_list(meta_path: str) -> List[FileInfo]:
    while True:
        try:
            # reset on every attempt so a failed read does not leave partial entries behind
            info: List[FileInfo] = []
            with open(meta_path, "r", encoding="utf-8") as f:
                for line in f.readlines():
                    line = line.strip()
                    if len(line) > 0:
                        info.append(FileInfo().loads(line))
            return info
        except Exception as e:
            print(
                "Error: reading info list in _read_info_list, meta_path={path}, err={err}".format(
                    path=meta_path, err=str(e)
                )
            )
            time.sleep(10)


def _write_info_list(meta_path: str, info: List[FileInfo]):
    base_path = os.path.dirname(meta_path)
    random_fname = os.path.join(base_path, ".meta.bin.%s" % _random_string())
    while True:
        try:
            with open(random_fname, "w", encoding="utf-8") as f:
                for v in info:
                    f.write(v.dumps() + "\n")
            # atomically replace the meta file with the freshly written copy
            os.rename(random_fname, meta_path)
            return
        except Exception as e:
            print("Error: writing info list! err={}".format(str(e)))
            time.sleep(10)


def _filtered_range(begin: int, end: int, rank: int, world_size: int, filter_set: Optional[Set[int]] = None):
    # advance `begin` to the first index in [begin, end) with index % world_size == rank
    begin = begin + (rank + (world_size - (begin % world_size))) % world_size

    if filter_set is not None:
        return [i for i in range(begin, end, world_size) if i in filter_set]
    else:
        return [i for i in range(begin, end, world_size)]


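# Illustrative example (numbers chosen only for demonstration):
#   _filtered_range(begin=10, end=30, rank=1, world_size=4) -> [13, 17, 21, 25, 29]
# Every returned index i satisfies begin <= i < end and i % world_size == rank.

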
class SafeFile:
    # A thin wrapper around a binary file object that transparently re-opens the file
    # and restores the last known offset when a read/tell/seek fails.
    def __init__(self, fname, mode):
        self.fname = None
        self.mode = None
        self._fp = None
        self.open_file(fname, mode)

    def read(self, size=-1):
        if self._fp is None:
            raise RuntimeError("Dataset is closed")
        try:
            res = self._fp.read(size)
            self.offset = self._fp.tell()
            return res
        except Exception as e:
            print("Error: reading blocks in {}! err {}".format(self.fname, str(e)))
            self.close()
            self.open_file(self.fname, self.mode, self.offset)
            return self.read(size)

    def tell(self):
        if self._fp is None:
            raise RuntimeError("Dataset is closed")
        try:
            res = self._fp.tell()
            self.offset = res
            return res
        except Exception as e:
            print("Error: telling blocks in {}! err {}".format(self.fname, str(e)))
            self.close()
            self.open_file(self.fname, self.mode, self.offset)
            return self.tell()

    def seek(self, offset, whence=0):
        if self._fp is None:
            raise RuntimeError("Dataset is closed")
        try:
            res = self._fp.seek(offset, whence)
            self.offset = self._fp.tell()
            return res
        except Exception as e:
            print("Error: seeking blocks in {}! err {}".format(self.fname, str(e)))
            self.close()
            self.open_file(self.fname, self.mode, self.offset)
            return self.seek(offset, whence)

    def close(self):
        if self._fp is not None:
            try:
                self._fp.close()
            except Exception as e:
                print("Error: closing blocks in {}! err {}".format(self.fname, str(e)))
            self._fp = None

    def open_file(self, fname, mode, offset=None):
        if not os.path.exists(fname):
            print("Dataset {} does not exist".format(fname))
            self.close()
            time.sleep(20)
            self.open_file(fname, mode, offset)
            return  # the recursive call has already opened the file; do not open it twice
        try:
            self.fname = fname
            self.mode = mode
            self._fp = open(fname, mode)
            if offset is not None:
                self._fp.seek(offset, io.SEEK_SET)
            self.offset = self._fp.tell()
        except Exception as e:
            print("Error: opening blocks in {}! err {}".format(self.fname, str(e)))
            self.close()
            time.sleep(20)
            self.open_file(fname, mode, offset)


class DistributedDataset:
    """Open dataset in readonly mode.

    `DistributedDataset` is used to read datasets in a distributed manner.
    Data in this dataset will be distributed evenly in blocks to each worker in the `distributed communicator`.

    **Note** When all data has been read, reading the dataset again will start over from the beginning.

    Args:
        path (str): Path to dataset.
        rank (int): Rank in distributed communicator. See: bmtrain.rank()
        world_size (int): Total workers in distributed communicator. See: bmtrain.world_size()
        serializer (Serializer): Serializer used to decode stored records. Default: PickleSerializer
        max_repeat_times (int): If set, reading raises EOFError once the dataset has been repeated this many times. Default: None
        shuffle (bool): Whether to shuffle the order in which blocks are read. Default: True

    Example:
        >>> dataset = DistributedDataset("/path/to/dataset")
        >>> for i in range(10):
        >>>     dataset.read()
    """  # noqa: E501

    def __init__(
        self,
        path: str,
        rank: int = 0,
        world_size: int = 1,
        serializer: Optional[Serializer] = None,
        max_repeat_times: Optional[int] = None,
        shuffle: bool = True,
    ) -> None:
        # config
        self._path = path
        self._rank = rank
        self._world_size = world_size
        self._max_repeat_times = max_repeat_times
        self._repeat_times = 0
        self._shuffle = shuffle

        if serializer is None:
            serializer = PickleSerializer()
        self.serializer = serializer

        # dataset meta
        self._unused_block: List[int] = []
        self._unused_block_offset: List[int] = []
        self._file_info: List[FileInfo] = []
        self._file_ends: List[int] = []
        self._total_blocks = 0
        self._nbytes = 0
        self._nlines = 0

        # states
        self._curr_block = None
        self._fp = None

        # cache
        self._last_mod_time = 0
        self._curr_fname = None

        self._update_states(fast_skip=False)
        self._repeat_times += 1

    def _update_states(self, fast_skip: bool = True):
        meta_path = os.path.join(self._path, "meta.bin")

        while True:
            try:
                mod_time = os.stat(meta_path).st_mtime
                break
            except Exception as e:
                print(
                    "Error: reading info list in DistributedDataset._update_states, "
                    "meta_path={path}, err={err}!".format(path=meta_path, err=str(e))
                )
                time.sleep(10)

        if self._last_mod_time < mod_time:
            # file changed
            self._last_mod_time = mod_time
        else:
            if fast_skip:
                return

        info: List[FileInfo] = []
        if os.path.exists(meta_path):
            info = _read_info_list(meta_path)
        old_len = len(self._file_info)
        if old_len > len(info):
            raise RuntimeError("Dataset meta file: changed unexpectedly")

        mask_changed = False
        for i in range(old_len):
            if self._file_info[i].file_name != info[i].file_name:
                raise RuntimeError("Dataset meta file: changed unexpectedly")
            if self._file_info[i].block_begin != info[i].block_begin:
                raise RuntimeError("Dataset meta file: changed unexpectedly")
            if self._file_info[i].block_end != info[i].block_end:
                raise RuntimeError("Dataset meta file: changed unexpectedly")
            if self._file_info[i].mask != info[i].mask:
                mask_changed = True

        if len(info) > 0 and info[0].block_begin != 0:
            raise RuntimeError("Dataset meta file: block error (0)")
        for i in range(len(info) - 1):
            if info[i].block_end != info[i + 1].block_begin:
                raise RuntimeError("Dataset meta file: block error (%d)" % (i + 1))

        if (old_len == len(info) and not mask_changed) and fast_skip:
            # fast skip
            return

        if len(info) > 0:
            total_blocks = info[-1].block_end
            self._nbytes = 0
            self._nlines = 0
            for v in info:
                self._nbytes += v.nbytes
                self._nlines += v.nlines
        else:
            total_blocks = 0
            self._nbytes = 0
            self._nlines = 0

        if total_blocks > 0:
            unused_block_set = set(self._unused_block)
            nw_unused_block: List[int] = []
            for i in range(len(info)):
                v = info[i]
                if not v.mask:
                    if i < old_len:
                        # known file: keep only the blocks of this rank that are still unread
                        nw_unused_block.extend(
                            _filtered_range(
                                v.block_begin,
                                v.block_end,
                                self._rank,
                                self._world_size,
                                unused_block_set,
                            )
                        )
                    else:
                        # newly appended file: all of this rank's blocks are unread
                        nw_unused_block.extend(
                            _filtered_range(v.block_begin, v.block_end, self._rank, self._world_size)
                        )

            # re-shuffle unused blocks
            if self._shuffle:
                random.shuffle(nw_unused_block)

            offset_dict = {block: offset for block, offset in zip(self._unused_block, self._unused_block_offset)}
            nw_unused_block_offset = [offset_dict[block] if block in offset_dict else 0 for block in nw_unused_block]
            self._unused_block = nw_unused_block
            self._unused_block_offset = nw_unused_block_offset

            self._file_ends = []
            for v in info:
                self._file_ends.append(v.block_end)
        else:
            self._unused_block = []
            self._unused_block_offset = []
            self._file_ends = []

        self._total_blocks = total_blocks
        self._file_info = info
        assert len(self._unused_block) == len(self._unused_block_offset)
        assert len(self._file_ends) == len(self._file_info)

    def _mask_file(self, f: FileInfo):
        # drop this file's blocks from the unread queue (used when the file cannot be opened)
        nw_unused_block: List[int] = []
        nw_unused_block_offset: List[int] = []
        for block_id, block_offset in zip(self._unused_block, self._unused_block_offset):
            if block_id < f.block_begin or block_id >= f.block_end:
                nw_unused_block.append(block_id)
                nw_unused_block_offset.append(block_offset)
        self._unused_block = nw_unused_block
        self._unused_block_offset = nw_unused_block_offset

    def _get_block_file(self, block_id: int):
        # find the file which contains this block
        file_idx = bisect.bisect_right(self._file_ends, block_id)
        return self._file_info[file_idx]

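    # Example (made-up numbers): with _file_ends == [4, 10], blocks 0-3 map to
    # _file_info[0] and blocks 4-9 map to _file_info[1], because bisect_right returns
    # the index of the first file whose (exclusive) block_end is greater than block_id.
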
    def _prepare_new_epoch(self):
        if self._max_repeat_times is not None:
            if self._repeat_times >= self._max_repeat_times:
                raise EOFError("End of dataset")
        nw_unused_block: List[int] = []
        for v in self._file_info:
            if not v.mask:
                nw_unused_block.extend(_filtered_range(v.block_begin, v.block_end, self._rank, self._world_size))
        if self._shuffle:
            random.shuffle(nw_unused_block)
        self._unused_block = nw_unused_block
        self._unused_block_offset = [0 for _ in nw_unused_block]
        self._repeat_times += 1

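    # A new pass over the data re-collects this rank's share of every unmasked file's
    # blocks (and reshuffles them when shuffle=True); once the dataset has been repeated
    # max_repeat_times times, _prepare_new_epoch raises EOFError to stop reading.
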
    def _get_next_block(self):
        self._update_states()
        if len(self._unused_block) == 0:
            self._prepare_new_epoch()
            if len(self._unused_block) == 0:
                raise RuntimeError("Empty dataset {}".format(self._path))

        mn_block: int = self._unused_block.pop()
        mn_block_offset: int = self._unused_block_offset.pop()
        return mn_block, mn_block_offset

    def _state_dict(self):
        self._update_states()
        num_unused_block = len(self._unused_block)
        if (self._fp is not None) and (self._curr_block is not None):
            curr_block = self._curr_block
            curr_f = self._get_block_file(curr_block)
            inblock_offset = self._fp.tell() - (curr_block - curr_f.block_begin) * curr_f.block_size
        else:
            curr_block = -1
            inblock_offset = 0

        return {
            "states": torch.tensor(self._unused_block, dtype=torch.long, device="cpu"),
            "offset": torch.tensor(self._unused_block_offset, dtype=torch.long, device="cpu"),
            "block": torch.tensor(
                [curr_block, inblock_offset, num_unused_block, self._repeat_times],
                dtype=torch.long,
                device="cpu",
            ),
        }

    def state_dict(self):
        """Returns a state dict representing the read states of the dataset.

        Example:
            >>> state = dataset.state_dict()
            >>> dataset.load_state_dict(state)
        """
        self._update_states()
        num_unused_block = len(self._unused_block)
        if (self._fp is not None) and (self._curr_block is not None):
            curr_block = self._curr_block
            curr_f = self._get_block_file(curr_block)
            inblock_offset = self._fp.tell() - (curr_block - curr_f.block_begin) * curr_f.block_size
        else:
            curr_block = -1
            inblock_offset = 0

        with torch.no_grad():
            if self._world_size > 1:
                gpu_num_unused_block = torch.tensor([num_unused_block], dtype=torch.long).cuda()
                max_unused_blocks = (
                    bmt.distributed.all_reduce(gpu_num_unused_block, op="max", comm=bmt.config["tp_zero_comm"])
                    .cpu()
                    .item()
                )
                gpu_states = torch.full((max_unused_blocks,), -1, dtype=torch.long).cuda()
                gpu_states[:num_unused_block] = torch.tensor(self._unused_block, dtype=torch.long).cuda()
                gpu_offset = torch.full((max_unused_blocks,), 0, dtype=torch.long).cuda()
                gpu_offset[:num_unused_block] = torch.tensor(self._unused_block_offset, dtype=torch.long).cuda()
                gpu_block = torch.tensor(
                    [curr_block, inblock_offset, num_unused_block, self._repeat_times],
                    dtype=torch.long,
                ).cuda()
                global_states = bmt.distributed.all_gather(
                    gpu_states, comm=bmt.config["tp_zero_comm"]
                ).cpu()  # (world_size, max_unused_blocks)
                global_offset = bmt.distributed.all_gather(
                    gpu_offset, comm=bmt.config["tp_zero_comm"]
                ).cpu()  # (world_size, max_unused_blocks)
                global_block = bmt.distributed.all_gather(
                    gpu_block, comm=bmt.config["tp_zero_comm"]
                ).cpu()  # (world_size, 4)
                return {"states": global_states, "offset": global_offset, "block": global_block}
            else:
                return {
                    "states": torch.tensor([self._unused_block], dtype=torch.long, device="cpu"),
                    "offset": torch.tensor([self._unused_block_offset], dtype=torch.long, device="cpu"),
                    "block": torch.tensor(
                        [[curr_block, inblock_offset, num_unused_block, self._repeat_times]],
                        dtype=torch.long,
                        device="cpu",
                    ),
                }

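    # The gathered state above has one row per rank:
    #   "states": ids of this rank's unread blocks, padded with -1 up to the longest list,
    #   "offset": the saved in-block offset for each of those blocks,
    #   "block":  [curr_block, inblock_offset, num_unused_block, repeat_times].
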
    def load_state_dict(self, state, strict: bool = True):
        """Load dataset state.

        Args:
            state (dict): dataset state dict.
            strict (bool): If `strict` is True, world size needs to be the same as when exported.

        Example:
            >>> state = dataset.state_dict()
            >>> dataset.load_state_dict(state)
        """
        block_states: torch.LongTensor = state["states"]
        block_info: torch.LongTensor = state["block"]
        if "offset" not in state:
            block_offset: torch.LongTensor = torch.zeros_like(block_states).long()
        else:
            block_offset: torch.LongTensor = state["offset"]

        if block_states.size(0) != self._world_size:
            if strict:
                raise ValueError("world_size changed (%d -> %d)" % (state["block"].size(0), self._world_size))
            else:
                self._curr_block = None
                self._fp = None
                self._curr_fname = None
                self._repeat_times = int(block_info[0, 3].item())
                offset_dict = {}
                for i in range(block_states.size(0)):
                    for block, offset in zip(block_states[i].tolist(), block_offset[i].tolist()):
                        offset_dict[block] = offset

                # re-shuffle unused blocks
                nw_unused_block: List[int] = []

                for i in range(block_states.size(0)):
                    _, _, num_unused_blocks, _ = block_info[i].tolist()
                    nw_unused_block.extend(
                        [
                            block_id
                            for block_id in block_states[i, :num_unused_blocks].tolist()
                            if block_id % self._world_size == self._rank
                        ]
                    )

                for i in range(block_states.size(0)):
                    curr_block, inblock_offset, num_unused_blocks, _ = block_info[i].tolist()
                    if curr_block < 0:
                        continue
                    if curr_block % self._world_size == self._rank:
                        nw_unused_block.append(curr_block)
                        offset_dict[curr_block] = inblock_offset

                curr_block, inblock_offset, _, self._repeat_times = block_info[self._rank].tolist()
                # if self._shuffle:
                #     random.shuffle(nw_unused_block)
                nw_unused_block_offset = [
                    offset_dict[block] if block in offset_dict else 0 for block in nw_unused_block
                ]
                self._unused_block = nw_unused_block
                self._unused_block_offset = nw_unused_block_offset

        else:
            curr_block, inblock_offset, num_unused_blocks, self._repeat_times = block_info[self._rank].tolist()
            if curr_block == -1:
                self._curr_block = None
                self._unused_block = []
                self._unused_block_offset = []
            else:
                while True:
                    try:
                        self._curr_block = curr_block
                        f_info = self._get_block_file(self._curr_block)
                        self._open_file(
                            f_info.file_name,
                            (self._curr_block - f_info.block_begin) * f_info.block_size + inblock_offset,
                        )
                        self._unused_block = block_states[self._rank, :num_unused_blocks].tolist()
                        self._unused_block_offset = block_offset[self._rank, :num_unused_blocks].tolist()
                        break
                    except Exception:
                        print("Error: reading blocks in {}".format(f_info.file_name))
                        time.sleep(10)
        # end
        self._update_states()

    def _get_file_path(self, fname):
        return os.path.join(self._path, fname)

    def _open_file(self, fname, offset):
        if self._curr_fname != fname:
            if self._fp is not None:
                self._fp.close()
                self._curr_fname = None
            # self._fp = open(self._get_file_path(fname), "rb")
            self._fp = SafeFile(self._get_file_path(fname), "rb")
            self._curr_fname = fname
        else:
            assert self._fp is not None, "Unexpected error"
        self._fp.seek(offset, io.SEEK_SET)  # move to block

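    # On-disk layout (as written by DatasetWriter below): every record is stored as a
    # one-byte magic b"\x1F", followed by a struct-packed "I" length header and the
    # serialized payload; the tail of each fixed-size block is padded with b"\x00",
    # which read() interprets as "end of block, move on to the next one".
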
    def read(self):
        """Read a piece of data from the dataset.

        Workers in different ranks will read different data.
        """
        if self._curr_block is None:
            next_block_id, next_block_offset = self._get_next_block()
            f_info = self._get_block_file(next_block_id)
            try:
                self._open_file(
                    f_info.file_name, (next_block_id - f_info.block_begin) * f_info.block_size + next_block_offset
                )
                self._curr_block = next_block_id
            except FileNotFoundError:
                print("Error: file {} not found, masking it and reading again!".format(f_info.file_name))
                self._mask_file(f_info)
                return self.read()  # read again

        if self._fp is None:
            raise RuntimeError("Dataset is not initialized")

        MAGIC = self._fp.read(1)
        if MAGIC == b"\x1F":
            # correct record header: read the length, then the payload
            size = struct.unpack("I", self._fp.read(4))[0]
            data = self._fp.read(size)
            return self.serializer.deserialize(data)
        elif MAGIC == b"\x00":
            # end of block
            self._curr_block = None
            return self.read()  # read next block
        else:
            raise ValueError("Invalid magic header")

    @property
    def nbytes(self):
        return self._nbytes


class SimpleDataset(DistributedDataset):
    def __init__(
        self,
        path: str,
        serializer: Optional[Serializer] = None,
        shuffle: bool = True,
    ) -> None:
        super().__init__(
            path,
            0,
            1,
            serializer=serializer,
            max_repeat_times=1,
            shuffle=shuffle,
        )

    def __iter__(self):
        while True:
            try:
                data = self.read()
            except EOFError:
                self._repeat_times = 0
                break
            yield data

    def __len__(self):
        return self._nlines

    def get_bytes(self):
        return self._nbytes


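# SimpleDataset reads a dataset from a single process (rank 0, world size 1) and
# iterates over it exactly once (max_repeat_times=1). A minimal usage sketch, with an
# illustrative path and a placeholder process() function:
#
#   for record in SimpleDataset("/path/to/dataset", shuffle=False):
#       process(record)

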
class DatasetWriter:
    def __init__(self, fname: str, block_size: int, serializer: Optional[Serializer] = None):
        self._fname = fname
        self._block_size = block_size
        self._fp = open(self._fname, "wb")
        self._inblock_offset = 0

        self._nbytes = 0
        self._nlines = 0
        self._nblocks = 1

        if serializer is None:
            serializer = PickleSerializer()
        self.serializer = serializer

    def write(self, data):
        """Write a piece of data into the dataset.

        Args:
            data (Any): Serialization will be done using the configured serializer (pickle by default).

        Example:
            >>> writer.write( "anything you want" )

        """
        byte_data = self.serializer.serialize(data)
        byte_data = struct.pack("I", len(byte_data)) + byte_data
        # the +2 below accounts for the one-byte record magic plus at least one byte
        # kept free for the end-of-block marker written by close()
        if self._inblock_offset + 2 + len(byte_data) > self._block_size:
            self._fp.write(b"\x00" * (self._block_size - self._inblock_offset))  # fill the remaining space with 0
            self._inblock_offset = 0
            self._nblocks += 1
            # we go to the next block
            if self._inblock_offset + 2 + len(byte_data) > self._block_size:
                raise ValueError("data is larger than block size")

        self._nbytes += len(byte_data)
        self._nlines += 1

        self._inblock_offset += 1 + len(byte_data)
        self._fp.write(b"\x1F")
        self._fp.write(byte_data)

    @property
    def nbytes(self):
        return self._nbytes

    @property
    def nblocks(self):
        return self._nblocks

    @property
    def nlines(self):
        return self._nlines

    def close(self):
        if not self._fp.closed:
            self._fp.write(b"\x00" * (self._block_size - self._inblock_offset))
            self._fp.close()


class DatasetBuilder:
    def __init__(
        self,
        path: str,
        dbname: str,
        block_size=_DEFAULT_BLOCK_SIZE,
        serializer: Optional[Serializer] = None,
    ) -> None:
        self._block_size = block_size
        self._path = path
        self._dbname = dbname
        self._writer: Optional[DatasetWriter] = None
        if serializer is None:
            serializer = PickleSerializer()
        self.serializer = serializer

        if not os.path.exists(self._path):
            os.makedirs(self._path)

        meta_path = os.path.join(self._path, "meta.bin")

        info: List[FileInfo] = []
        if os.path.exists(meta_path):
            info = _read_info_list(meta_path)

        for v in info:
            if v.file_name == dbname:
                raise ValueError("Dataset name exists")

        self._db_path = os.path.join(self._path, self._dbname)
        if os.path.exists(self._db_path):
            raise ValueError("File exists `%s`" % self._db_path)

    def __enter__(self):
        self._writer = DatasetWriter(self._db_path, self._block_size, self.serializer)
        return self._writer

    def __exit__(self, exc_type, exc_value, exc_traceback):
        if self._writer is None:
            raise RuntimeError("Unexpected call to __exit__")

        self._writer.close()
        if exc_type is not None:
            print("Error while writing file")
            if os.path.exists(self._db_path):
                os.unlink(self._db_path)
        else:
            meta_path = os.path.join(self._path, "meta.bin")
            info: List[FileInfo] = []
            if os.path.exists(meta_path):
                info = _read_info_list(meta_path)
            last_block = 0
            if len(info) > 0:
                last_block = info[-1].block_end
            info.append(
                FileInfo(
                    self._dbname,
                    last_block,
                    last_block + self._writer.nblocks,
                    self._writer.nbytes,
                    self._writer.nlines,
                    False,
                    self._block_size,
                )
            )

            # atomic write to meta file
            _write_info_list(meta_path, info)

        self._writer = None


def build_dataset(
    path: str,
    dbname: str,
    block_size: int = _DEFAULT_BLOCK_SIZE,
    serializer: Optional[Serializer] = None,
):
    """Open the dataset in write mode and return a writer.

    Args:
        path (str): Path to dataset.
        dbname (str): The name of the file to which the data will be written. The `dbname` needs to be unique in this `dataset`.
        block_size (int): Size of each block in bytes. All files in the same dataset should have the same block size. Default: 16MB

    Example:
        >>> with build_dataset("/path/to/dataset", "data_part_1") as writer:
        >>>     for i in range(10):
        >>>         writer.write( { "anything you want" } )
    """  # noqa: E501
    return DatasetBuilder(path, dbname, block_size=block_size, serializer=serializer)
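
# Putting the pieces together (illustrative sketch; the path, the record fields, and
# the use of bmtrain's rank()/world_size() helpers are examples, not requirements):
#
#   with build_dataset("/path/to/dataset", "data_part_1") as writer:
#       for i in range(10):
#           writer.write({"input": "...", "target": "..."})
#
#   dataset = DistributedDataset(
#       "/path/to/dataset", rank=bmt.rank(), world_size=bmt.world_size()
#   )
#   sample = dataset.read()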