# forked from jiuyuan/CPM-9G-8B
import importlib.machinery
import importlib.util
import json
import multiprocessing
import os
import random
import time
import types
from collections import OrderedDict
from queue import Empty
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import bmtrain as bmt
import numpy as np
import torch
from numpy.typing import NDArray
from typing_extensions import TypedDict

from ...dataset import DistributedDataset
from ..tokenizers import CPM9GTokenizer


class _MixedDatasetConfig(TypedDict):
    weight: float
    path: str
    transforms: Union[List[Dict[str, Any]], str]
    task_name: str
    dataset_name: str
    incontext_weight: List[float]
    lines: int
    dataset: DistributedDataset


CPM9GInputType = Union[str, Dict[str, "CPM9GInputType"]]


class _TransformFuncDict(TypedDict):
    loader: importlib.machinery.SourceFileLoader
    module: types.ModuleType
    last_m: float


_TransformFunction = Callable[[CPM9GInputType, int, random.Random], CPM9GInputType]


class CPM9GBatch(TypedDict):
    inputs: NDArray[np.int32]
    length: NDArray[np.int32]
    context: NDArray[np.bool_]
    sample_ids: NDArray[np.int32]
    spans: NDArray[np.int32]
    cu_seqlens: NDArray[np.int32]
    max_seqlen: int
    position_ids: NDArray[np.int32]
    target: NDArray[np.int32]
    task_ids: NDArray[np.int32]
    task_names: List[str]
    raw_data: List[Any]


def convert_data_to_id(tokenizer: CPM9GTokenizer, data: Any):
    """Tokenize one sample into ``[bos] input output [eos]``.

    Returns the int32 id array and an int8 mask that is 1 over the context
    segment (bos + input) and 0 over the output segment (output + eos).
    """
    input_ids = tokenizer.encode(data["input"])
    output_ids = tokenizer.encode(data["output"])
    ids = [tokenizer.bos_id] + input_ids + output_ids + [tokenizer.eos_id]
    ids = np.array(ids, dtype=np.int32)
    context = np.zeros((ids.shape[0],), dtype=np.int8)
    context[: len(input_ids) + 1] = 1
    return ids, context
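

# Example (hypothetical token ids): if encode("Q") -> [5] and encode("A") -> [9],
# then for data = {"input": "Q", "output": "A"}:
#   ids     = [bos, 5, 9, eos]
#   context = [1,   1, 0, 0]
# pack_batch later computes loss targets only where context == 0.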


def _dataset_identity(c: _MixedDatasetConfig):
    return "{}.{}".format(c["task_name"], c["dataset_name"])


class _MixedDatasetBatchPacker:
    def __init__(
        self,
        batch_size: int,
        max_length: int,
        tokenizer: CPM9GTokenizer,
        unpad: bool = False,
    ) -> None:
        self._batch_size = batch_size
        self._max_length = max_length
        self._clip_length = max_length
        self.tokenizer = tokenizer
        self._transform_func_table: Dict[str, _TransformFuncDict] = {}
        self._unpad = unpad
        if unpad:
            # unpad mode: concatenate everything into one long row instead of
            # padding batch_size separate rows
            self._max_length = max_length * batch_size
            self._batch_size = 1

        self._inputs: List[NDArray[np.int32]] = []
        self._context: List[NDArray[np.int8]] = []
        self._sample_ids: List[NDArray[np.int32]] = []
        self._spans: List[List[int]] = []
        self._task_ids: List[List[str]] = []
        self._raw_data: List[List[Any]] = []

    def __len__(self):
        return len(self._inputs)

    def apply_transform(
        self,
        data: CPM9GInputType,
        transform: Union[Dict[str, Any], Callable[[CPM9GInputType], CPM9GInputType], None],
    ) -> CPM9GInputType:
        if transform is None:
            return data
        if not isinstance(transform, dict):
            return transform(data)  # transform function

        mapping_list: List[Tuple[str, str]] = []

        def _walk_transform_dict(data: Union[Dict[str, Any], str], prefix: str = ""):
            if isinstance(data, dict):
                for k, v in data.items():
                    if len(prefix) > 0:
                        _walk_transform_dict(v, prefix + "." + k)
                    else:
                        _walk_transform_dict(v, k)
            else:
                assert isinstance(data, str), "Invalid transform {}".format(data)
                mapping_list.append((prefix, data))

        _walk_transform_dict(transform)

        expanded_mapping_list: List[Tuple[str, Any]] = []

        def _expand_mapping(data: CPM9GInputType, stars: List[str], path: List[str], target: List[str]):
            if len(path) == 0:
                num_stars = 0
                for it in target:
                    if it == "*":
                        num_stars += 1
                if num_stars != len(stars):
                    raise ValueError("Invalid transform {}".format(".".join(target)))

                nw_tgt = []
                num_stars = 0
                for it in target:
                    if it == "*":
                        nw_tgt.append(stars[num_stars])
                        num_stars += 1
                    else:
                        nw_tgt.append(it)
                expanded_mapping_list.append((".".join(nw_tgt), data))
            else:
                if not isinstance(data, dict):
                    raise ValueError("Invalid data {}".format(data))
                if path[0] == "*":
                    for k, v in data.items():
                        _expand_mapping(v, stars + [k], path[1:], target)
                else:
                    _expand_mapping(data[path[0]], stars, path[1:], target)

        # expand mapping list
        for tgt, src in mapping_list:
            if src.startswith("$"):
                # copy from src
                _expand_mapping(data, [], src[1:].split("."), tgt.split("."))
            else:
                if "*" in tgt:
                    raise ValueError("Constant value is not allowed to have `*` in prefix")
                expanded_mapping_list.append((tgt, src))

        ret = {}
        for tgt, val in expanded_mapping_list:
            tgt = tgt.split(".")
            cur = ret
            while len(tgt) > 1:
                # create intermediate dicts so nested targets do not KeyError
                cur = cur.setdefault(tgt[0], {})
                tgt = tgt[1:]
            cur[tgt[0]] = val
        return ret
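
    # A sketch of how dict transforms are interpreted (the field names here are
    # hypothetical). "$"-prefixed values copy from a source path, plain strings
    # are constants, and "*" in a path expands over dict keys:
    #   transform = {"input": "$q", "output": "$a", "source": "demo"}
    #   data      = {"q": "...", "a": "..."}
    #   result    = {"input": data["q"], "output": data["a"], "source": "demo"}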

    def data_to_id(self, data: Any):
        return convert_data_to_id(self.tokenizer, data)

    def _ensure_transform_function(self, module_name: str, transform_script_path: str) -> _TransformFunction:
        module_name = "cpm_live.transforms.{}".format(module_name)
        if transform_script_path not in self._transform_func_table:
            loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
            spec = importlib.util.spec_from_loader(loader.name, loader)
            if spec is None:
                raise RuntimeError("spec is none! {}".format(module_name))
            mod = importlib.util.module_from_spec(spec)
            self._transform_func_table[transform_script_path] = {
                "loader": loader,
                "module": mod,
                "last_m": 0.0,
            }

        transform_script_info = self._transform_func_table[transform_script_path]
        curr_m_time = float(transform_script_info["loader"].path_stats(transform_script_path)["mtime"])
        if curr_m_time > transform_script_info["last_m"]:
            # script changed on disk: hot-reload it
            transform_script_info["last_m"] = curr_m_time
            transform_script_info["loader"].exec_module(transform_script_info["module"])
        transform_func = getattr(transform_script_info["module"], "transform", None)
        if transform_func is None:

            def _empty_transform_func(data: CPM9GInputType, num_sample: int, r: random.Random):
                raise NotImplementedError("Transform func for dataset {} not implemented".format(module_name))

            return _empty_transform_func
        else:
            return transform_func
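
    # A transform script is re-executed whenever its mtime advances, so edits
    # are picked up by a running job. A minimal script (a sketch; the "q"/"a"
    # fields are hypothetical) only needs a module-level ``transform``:
    #
    #   def transform(data, num_sample, r):
    #       return {"input": data["q"], "output": data["a"]}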

    def build_instance(self, config: _MixedDatasetConfig):
        _sample_weight = np.array(config["incontext_weight"], dtype=np.float32)
        _sample_weight = _sample_weight / _sample_weight.sum()
        num_incontext = np.random.choice(_sample_weight.shape[0], p=_sample_weight)
        ds = config["dataset"]
        transforms = config["transforms"]
        if isinstance(transforms, str):
            if not os.path.exists(transforms):
                raise RuntimeError("transform script {} does not exist".format(transforms))
            # load transform script
            transform_func = self._ensure_transform_function(_dataset_identity(config), transforms)
            seed = random.random()

            def _transform(data: CPM9GInputType):
                # fixed seed: every sample of this instance sees the same RNG stream
                r = random.Random(seed)
                return transform_func(data, num_incontext, r)

            transform = _transform
        elif len(transforms) == 0:
            transform = None
        else:
            transform = transforms[np.random.choice(len(transforms))]

        raw_data = {}
        while True:
            inp = ds.read()
            inp = self.apply_transform(inp, transform)
            # if None, skip this one
            if inp is None:
                continue
            input_ids, context = self.data_to_id(inp)
            if input_ids.shape[0] > self._clip_length:
                continue  # too long
            input_ids = input_ids[: self._clip_length]
            context = context[: self._clip_length]
            raw_data["input"] = inp
            raw_data["samples"] = []
            break

        # append up to num_incontext extra samples as in-context examples
        sample_ids = np.zeros(input_ids.shape, dtype=np.int32)
        i = 0
        while True:
            if i == num_incontext or input_ids.shape[0] >= self._clip_length:
                break  # early break
            sample = ds.read()
            sample = self.apply_transform(sample, transform)
            # if sample is None, skip this one
            if sample is None:
                continue
            sample_input_ids, _ = self.data_to_id(sample)
            if input_ids.shape[0] + sample_input_ids.shape[0] > self._clip_length:
                break  # too long, break
            raw_data["samples"].append(sample)
            input_ids = np.concatenate([input_ids, sample_input_ids], axis=0)
            context = np.concatenate([context, np.ones(sample_input_ids.shape, dtype=np.int8)], axis=0)
            sample_ids = np.concatenate([sample_ids, np.full(sample_input_ids.shape, i + 1, dtype=np.int32)], axis=0)

            i += 1

        return (
            input_ids,
            context,
            sample_ids,
            raw_data,
        )

    def pack_batch(self, force: bool = False, unilm=False) -> CPM9GBatch:
        # pack batch
        if len(self._inputs) < self._batch_size:
            if not force:
                raise RuntimeError("Batch insufficient")
            batch_size = len(self._inputs)
        else:
            batch_size = self._batch_size
        max_length = self._max_length  # self._spans[0][-1] if self._unpad else self._max_length
        inputs = np.zeros((batch_size, max_length), dtype=np.int32)
        context = np.zeros((batch_size, max_length), dtype=np.int8)
        sample_ids = np.zeros((batch_size, max_length), dtype=np.int32)
        tgt = np.full((batch_size, max_length), -100, dtype=np.int32)

        spans = np.zeros((batch_size, max_length), dtype=np.int32)
        length = np.zeros((batch_size,), dtype=np.int32)
        task_ids = np.full((batch_size, max_length), -1, dtype=np.int32)
        # cu_seqlens/max_seqlen describe the first row only; they are intended
        # for unpad mode, where batch_size is 1
        if self._spans[0][-1] != max_length:
            cu_seqlens = np.array([0] + self._spans[0] + [max_length], dtype=np.int32)
        else:
            cu_seqlens = np.array([0] + self._spans[0], dtype=np.int32)
        max_seqlen = int(np.max(cu_seqlens[1:] - cu_seqlens[:-1]))
        position_ids = np.zeros((batch_size, max_length), dtype=np.int32)

        all_task_names: Set[str] = set()
        for i in range(batch_size):
            for task_name in self._task_ids[i]:
                all_task_names.add(task_name)
        task_names: List[str] = list(all_task_names)
        task_name_to_id = {name: i for i, name in enumerate(task_names)}

        raw_data_list: List[Any] = []
        for i in range(batch_size):
            instance_length = self._inputs[i].shape[0]
            inputs[i, :instance_length] = self._inputs[i]
            sample_ids[i, :instance_length] = self._sample_ids[i]
            if unilm:
                context[i, :instance_length] = self._context[i]

            span_begin = 0
            for span_id, (span_end, task_name) in enumerate(zip(self._spans[i], self._task_ids[i])):
                spans[i, span_begin:span_end] = span_id
                position_ids[i, span_begin:span_end] = np.arange(span_end - span_begin)
                task_ids[i, span_begin:span_end] = task_name_to_id[task_name]
                span_begin = span_end
            length[i] = instance_length
            raw_data_list.extend(self._raw_data[i])

            # next-token targets: predict token j at position j - 1, only in the
            # non-context region and never across a bos/eos boundary
            for j in range(instance_length):
                if j > 1 and self._context[i][j] == 0:
                    if self._inputs[i][j] != self.tokenizer.bos_id and self._inputs[i][j - 1] != self.tokenizer.eos_id:
                        tgt[i, j - 1] = self._inputs[i][j]

        # drop the packed instances, keep the remainder for the next batch
        self._inputs = self._inputs[batch_size:]
        self._context = self._context[batch_size:]
        self._sample_ids = self._sample_ids[batch_size:]
        self._spans = self._spans[batch_size:]
        self._task_ids = self._task_ids[batch_size:]
        self._raw_data = self._raw_data[batch_size:]
        return {
            "inputs": inputs,
            "length": length,
            "context": context > 0,
            "sample_ids": sample_ids,
            "spans": spans,
            "cu_seqlens": cu_seqlens,
            "max_seqlen": max_seqlen,
            "position_ids": position_ids,
            "target": tgt,
            "task_ids": task_ids,
            "task_names": task_names,
            "raw_data": raw_data_list,
        }
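
    # Shape summary of a packed batch (B = batch rows, L = max_length):
    #   inputs / sample_ids / spans / position_ids / task_ids: int32 (B, L)
    #   context: bool (B, L); target: int32 (B, L), -100 at ignored positions
    #   length: int32 (B,); cu_seqlens: int32 1-D; max_seqlen: int
    #   task_names / raw_data: Python lists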

    def add_data(self, config: _MixedDatasetConfig) -> Optional[CPM9GBatch]:
        (
            input_ids,
            context,
            sample_ids,
            raw_data,
        ) = self.build_instance(config)

        # add to batch: best-fit bin packing over the partially filled rows
        best_fit: Union[None, int] = None
        best_fit_space: Union[None, int] = None
        for i in range(len(self._inputs)):
            space = self._max_length - self._inputs[i].shape[0]
            if input_ids.shape[0] <= space:
                if best_fit_space is None:
                    best_fit = i
                    best_fit_space = space
                elif best_fit_space > space:
                    # tighter fit found
                    best_fit = i
                    best_fit_space = space
        if best_fit is None:
            # add a new instance
            self._inputs.append(input_ids)
            self._context.append(context)
            self._sample_ids.append(sample_ids)
            self._spans.append([input_ids.shape[0]])
            self._task_ids.append([config["task_name"]])
            self._raw_data.append([raw_data])
        else:
            # add to existing instance
            self._inputs[best_fit] = np.concatenate([self._inputs[best_fit], input_ids], axis=0)
            self._context[best_fit] = np.concatenate([self._context[best_fit], context], axis=0)
            self._sample_ids[best_fit] = np.concatenate([self._sample_ids[best_fit], sample_ids], axis=0)
            self._spans[best_fit].append(self._inputs[best_fit].shape[0])
            self._task_ids[best_fit].append(config["task_name"])
            self._raw_data[best_fit].append(raw_data)

        if len(self._inputs) > self._batch_size:
            return self.pack_batch()
        else:
            return None  # not ready


class _MixedDatasetConfigManager:
    def __init__(self, config_path: str) -> None:
        self._config_path: str = config_path
        self._config: Union[List[_MixedDatasetConfig], None] = None
        self._last_m = 0.0

    def changed(self):
        """Return True if the config file was modified and reloaded successfully."""
        while True:
            try:
                m_time = os.stat(self._config_path).st_mtime
                if m_time > self._last_m:
                    # try to load new config
                    try:
                        with open(self._config_path, "r", encoding="utf-8") as f:
                            self._config = json.load(f)
                    except Exception:
                        # failed to load config
                        return False
                    # new config loaded
                    self._last_m = m_time
                    return True
                return False
            except Exception:
                print("Error: reading info list! {}".format(self._config_path))
                time.sleep(30)

    def get_config(self) -> List[_MixedDatasetConfig]:
        if self._config is None:
            if not self.changed():
                raise RuntimeError("Failed to load config")
            if self._config is None:
                raise RuntimeError("Failed to load config")
        return self._config
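

# The config file is a JSON list of dataset entries. A minimal sketch (paths
# and names here are hypothetical; "weight", "transforms" and
# "incontext_weight" are optional and default to 1.0, [] and [1.0]):
#
#   [
#     {
#       "task_name": "qa",
#       "dataset_name": "demo",
#       "path": "data/demo",
#       "weight": 1.0,
#       "transforms": "transforms/demo.py",
#       "incontext_weight": [1.0]
#     }
#   ]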


def _mixed_dataset_process(
    config_path: str,
    q_cmd: multiprocessing.Queue,
    q_cmd_out: multiprocessing.Queue,
    q_data: multiprocessing.Queue,
    rank: int,
    world_size: int,
    packer: _MixedDatasetBatchPacker,
    max_repeat_times: Optional[int] = None,
    random_state=None,
):
    # advance the numpy RNG by `rank` draws so ranks do not sample in lockstep,
    # then restore the parent's python RNG state if one was passed
    for _ in range(rank):
        np.random.random()
    if random_state is not None:
        random.setstate(random_state)
    # ignore SIGINT
    import signal

    signal.signal(signal.SIGINT, signal.SIG_IGN)
    config_base_path = os.path.dirname(os.path.abspath(config_path))

    # fix libgomp threading deadlock after fork(), use single-thread in data process.
    # REF: https://github.com/pytorch/pytorch/issues/17199
    torch.set_num_threads(1)

    def _convert_to_abs_path(transform_path: str):
        if transform_path.startswith("/"):
            return transform_path
        else:
            return os.path.join(config_base_path, transform_path)

    def _init_dataset_config(c: _MixedDatasetConfig):
        # open the dataset and fill in the optional config fields
        ds = DistributedDataset(
            _convert_to_abs_path(c["path"]),
            rank,
            world_size,
            max_repeat_times=max_repeat_times,
        )
        c["lines"] = ds._nlines
        c["dataset"] = ds
        if "weight" not in c:
            c["weight"] = 1.0
        if "transforms" not in c:
            c["transforms"] = []
        elif isinstance(c["transforms"], str):
            c["transforms"] = _convert_to_abs_path(c["transforms"])
        if "incontext_weight" not in c:
            c["incontext_weight"] = [1.0]

    def _build_sample_weights(config: List[_MixedDatasetConfig]):
        if len(config) == 0:
            return np.array([], dtype=np.float32)
        weights = [c["weight"] * c["lines"] for c in config]
        weights = np.array(weights, dtype=np.float32)
        sm_weight = weights.sum()
        if sm_weight > 0:
            weights = weights / sm_weight
            return weights
        else:
            raise RuntimeError("Empty datasets")

    cfg_mgr = _MixedDatasetConfigManager(config_path)
    config = cfg_mgr.get_config()

    for c in config:
        _init_dataset_config(c)

    weights = _build_sample_weights(config)

    should_stop = False
    should_start = False

    while not should_stop:
        # update config first
        if cfg_mgr.changed():
            path_ds_map: Dict[str, _MixedDatasetConfig] = {}
            nw_path_set: Set[str] = set()

            # load new config
            nw_config = cfg_mgr.get_config()

            # build path -> dataset map
            for c in config:
                path_ds_map[_dataset_identity(c)] = c

            # add new datasets
            for c in nw_config:
                if _dataset_identity(c) in path_ds_map:
                    # update values only
                    if "weight" in c:
                        path_ds_map[_dataset_identity(c)]["weight"] = c["weight"]
                    if "transforms" in c:
                        if isinstance(c["transforms"], str):
                            path_ds_map[_dataset_identity(c)]["transforms"] = _convert_to_abs_path(c["transforms"])
                        else:
                            path_ds_map[_dataset_identity(c)]["transforms"] = c["transforms"]
                    if "incontext_weight" in c:
                        path_ds_map[_dataset_identity(c)]["incontext_weight"] = c["incontext_weight"]
                else:
                    # new dataset
                    _init_dataset_config(c)
                    path_ds_map[_dataset_identity(c)] = c
                nw_path_set.add(_dataset_identity(c))

            # remove unused datasets
            for c in config:
                if _dataset_identity(c) not in nw_path_set:
                    del path_ds_map[_dataset_identity(c)]

            config: List[_MixedDatasetConfig] = []
            for c in nw_config:
                config.append(path_ds_map[_dataset_identity(c)])
            del path_ds_map
            del nw_path_set
            del nw_config

            weights = _build_sample_weights(config)

        # get cmds
        while True:
            try:
                cmd = q_cmd.get_nowait()
            except Empty:
                break
            if cmd == "stop":
                should_stop = True
                q_cmd_out.put(True)
                break
            elif cmd == "state_dict":
                ret = OrderedDict()
                for c in config:
                    ds_name = _dataset_identity(c)
                    ret[ds_name] = c["dataset"]._state_dict()
                q_cmd_out.put(ret)
            elif cmd == "load_state_dict":
                state_dict = q_cmd.get()
                missing = []
                for idx, c in enumerate(config):
                    ds_name = _dataset_identity(c)
                    print(f"loading {idx}/{len(config)} {ds_name}")
                    if ds_name in state_dict:
                        c["dataset"].load_state_dict(state_dict[ds_name], strict=False)
                    else:
                        # new dataset
                        missing.append(ds_name)
                q_cmd_out.put(missing)
            elif cmd == "start":
                should_start = True
                q_cmd_out.put(True)
            else:
                raise RuntimeError("Unknown command: {}".format(cmd))

        if should_stop:
            break

        if not should_start:
            # wait for start cmd
            time.sleep(1)
            continue

        if len(config) == 0:
            # no dataset available
            time.sleep(1)
            continue

        if q_data.full():
            # queue full
            time.sleep(1)
            continue

        # sample a dataset
        ds_id: int = 0

        while True:
            ds_id = np.random.choice(weights.shape[0], p=weights)
            if config[ds_id]["dataset"]._nlines != config[ds_id]["lines"]:
                # dataset size changed: refresh line counts and resample
                for c in config:
                    c["lines"] = c["dataset"]._nlines
                weights = _build_sample_weights(config)
                continue
            else:
                break

        batch = packer.add_data(config[ds_id])
        if batch is not None:
            # new batch coming
            q_data.put(batch)

    # clean queue
    while True:
        try:
            q_data.get_nowait()
        except Empty:
            break
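

# Command protocol between MixedDataset (the parent process) and the worker
# above, as implemented by the loop over q_cmd:
#   "stop"            -> ack True on q_cmd_out, drain q_data, exit
#   "state_dict"      -> put an OrderedDict of per-dataset states on q_cmd_out
#   "load_state_dict" -> read a state dict from q_cmd, reply with the list of
#                        dataset names missing from it
#   "start"           -> ack True and begin producing batches on q_data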


class MixedDataset:
    """Mixes several datasets in a background process and streams packed
    batches over a bounded queue."""

    def __init__(
        self,
        config_path: str,
        batch_size: int,
        max_length: int,
        tokenizer: CPM9GTokenizer,
        unpad: bool = False,
        max_repeat_times: Optional[int] = None,
    ) -> None:
        self._q_cmd = multiprocessing.Queue()
        self._q_cmd_out = multiprocessing.Queue()
        self._q_data = multiprocessing.Queue(maxsize=2)
        self._packer = _MixedDatasetBatchPacker(batch_size, max_length, tokenizer, unpad)
        self._p = multiprocessing.Process(
            target=_mixed_dataset_process,
            args=(
                config_path,
                self._q_cmd,
                self._q_cmd_out,
                self._q_data,
                bmt.rank(),
                bmt.world_size(),
                self._packer,
                max_repeat_times,
                random.getstate(),
            ),
        )
        self._p.start()
        self._closed = False

    def close(self):
        if not self._closed:
            self._closed = True
            self._q_cmd.put("stop")
            assert self._q_cmd_out.get(), "Failed to stop process"
            self._p.join()

    @property
    def closed(self):
        return self._closed

    def start(self):
        self._q_cmd.put("start")
        return self._q_cmd_out.get()

    def state_dict(self):
        self._q_cmd.put("state_dict")
        states = self._q_cmd_out.get()
        if not isinstance(states, OrderedDict):
            raise RuntimeError("Invalid state dict {}".format(states))
        if bmt.world_size() == 1:
            for val in states.values():
                val["states"].unsqueeze_(0)
                val["block"].unsqueeze_(0)
            return states

        # gather every rank's dataset state so a checkpoint saved on one rank
        # covers the whole job
        ret = OrderedDict()
        for k, v in states.items():
            num_unused_block = v["states"].size(0)
            gpu_num_unused_block = torch.tensor([num_unused_block], dtype=torch.long).cuda()
            max_unused_blocks = bmt.distributed.all_reduce(gpu_num_unused_block, op="max").cpu().item()
            if max_unused_blocks == 0:
                max_unused_blocks = 1
            gpu_states = torch.full((max_unused_blocks,), -1, dtype=torch.long).cuda()
            gpu_states[:num_unused_block] = v["states"].cuda()

            gpu_block = v["block"].cuda()
            global_states = bmt.distributed.all_gather(gpu_states).cpu()  # (world_size, max_unused_blocks)
            global_block = bmt.distributed.all_gather(gpu_block).cpu()  # (world_size, 4)
            ret[k] = {"states": global_states, "block": global_block}
        return ret

    def load_state_dict(self, data: OrderedDict, strict: bool = False):
        self._q_cmd.put("load_state_dict")
        self._q_cmd.put(data)
        missing = self._q_cmd_out.get()
        if strict:
            if len(missing) > 0:
                raise RuntimeError("Missing dataset state: {}".format(missing))
        return missing

    def get(self) -> CPM9GBatch:
        ret: CPM9GBatch = self._q_data.get(timeout=300)  # type: ignore
        if not isinstance(ret, dict):
            raise RuntimeError("Invalid data {}".format(ret))
        return ret

    def __iter__(self):
        while True:
            yield self.get()

    def __del__(self):
        if not self.closed:
            try:
                self.close()
            except Exception:
                pass
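

# Typical usage (a sketch; constructing the tokenizer and the datasets.json
# config depends on the surrounding project; iteration never ends on its own,
# so break on your own schedule):
#
#   dataset = MixedDataset("datasets.json", batch_size=4, max_length=2048,
#                          tokenizer=tokenizer)
#   dataset.start()
#   for batch in dataset:
#       ...  # feed batch["inputs"], batch["target"], ... to the model
#   dataset.close()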