#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
#
# @author: ouzebin <ouzebin@zhihu.com>
# @date: 2023/09/27
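"""Indexed pretraining data pipeline for FM9G.

Loads dataset configs, tokenizes and segments samples, mixes tasks by weight
(optionally with van der Corput low-discrepancy sampling), packs unpadded
batches for flash attention, and prefetches batches to the GPU.
"""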
import copy
import ctypes
import functools
import importlib
import json
import logging
import os
import random
from collections import defaultdict
from collections import OrderedDict
from multiprocessing import Lock
from multiprocessing import Process
from multiprocessing.shared_memory import SharedMemory
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Union
import bmtrain as bmt
import numpy as np
import torch
from numpy.typing import NDArray
from fm9g.dataset import PrefetchDecodeDataset
from fm9g.utils.bitset import BitSet
from fm9g.utils.vdc_sampling import van_der_corput
from fm9g.utils.vdc_sampling import van_der_corput_sampling_gen

logger = logging.getLogger(__name__)

IGNORE_TGT = -100  # sentinel target value masked out of the loss


def load_dataset_cfgs(cfg_path, cfg_json_str=None):
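    """Load dataset config entries from a JSON file (or a JSON string) and fill in defaults.

    Each entry must provide "dataset_name", "task_name" and "path"; "weight",
    "oversize_rule", "transforms" and "incontext_weight" are optional and
    default to 1.0, "segment", None and [1.0] respectively.
    """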
if cfg_json_str is not None:
cfgs = json.loads(cfg_json_str)
else:
with open(cfg_path, "r", encoding="utf-8") as fin:
cfgs = json.load(fin)
transform_basedir = os.path.dirname(os.path.abspath(cfg_path))
path_dict = None
platform_config_path = os.getenv("PLATFORM_CONFIG_PATH")
try:
with open(platform_config_path, "r") as f:
platform_cfg = json.load(f)
path_dict = platform_cfg["dataset_map"]
if bmt.rank() == 0:
logger.info(f"Loaded jeeves platform config from '{platform_config_path}', update dataset paths...")
except Exception as e:
if bmt.rank() == 0:
logger.info(f"Failing to load jeeves platform config '{platform_config_path}', error message:\n{str(e)}")
task_name2dataset_name = dict()
for idx, cfg in enumerate(cfgs):
assert "dataset_name" in cfg and isinstance(cfg["dataset_name"], str)
assert "task_name" in cfg and isinstance(cfg["task_name"], str)
        # task names must be unique; fail loudly on duplicates
if cfg["task_name"] in task_name2dataset_name:
            raise ValueError(
                f"task_name '{cfg['task_name']}' in dataset '{cfg['dataset_name']}' "
                f"has already been used in '{task_name2dataset_name[cfg['task_name']]}'."
            )
task_name2dataset_name[cfg["task_name"]] = cfg["dataset_name"]
assert "path" in cfg and isinstance(cfg["path"], str)
# if path_dict is not None:
# cfg["path"] = os.path.join(path_dict[cfg["dataset_name"]], cfg["path"])
# dealing with optional configs
if "weight" in cfg:
assert isinstance(cfg["weight"], (float, int))
else:
cfg["weight"] = 1.0
if "oversize_rule" in cfg:
assert cfg["oversize_rule"] in ("drop", "head", "segment")
else:
cfg["oversize_rule"] = "segment"
if "transforms" in cfg:
assert isinstance(cfg["transforms"], str)
# dealing with relative path
if not cfg["transforms"].startswith("/"):
cfg["transforms"] = os.path.join(transform_basedir, cfg["transforms"])
if not cfg["transforms"]:
cfg["transforms"] = None
else:
cfg["transforms"] = None
if "incontext_weight" in cfg:
assert isinstance(cfg["incontext_weight"], (list, tuple))
else:
cfg["incontext_weight"] = [1.0]
cfg["id"] = idx
    # datasets and iterators are built later from these configs
return cfgs
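

# A minimal example of one config entry accepted by load_dataset_cfgs (hypothetical
# values; "dataset_name", "task_name" and "path" are the only required keys):
#
#   {
#       "dataset_name": "wiki_zh",
#       "task_name": "wiki_zh_pretrain",
#       "path": "/data/wiki_zh",
#       "abs_weight": 0.2,
#       "oversize_rule": "segment",
#       "transforms": "transforms/wiki.py",
#       "incontext_weight": [1.0]
#   }

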
def data2ids(data, tokenizer, max_length):
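    """Tokenize a generic pretraining sample and yield (src_ids, tgt_ids) chunks of at most max_length."""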
text = "\n".join(
[
data.get("title", "").strip(),
data.get("question", "").strip(),
data.get("answer", "").strip(),
data.get("abstract", "").strip(),
data.get("text", "").strip(),
data.get("code", "").strip(),
]
).strip()
if not text:
logger.warning(f"Warning: skip invalid sample without valid fields: {data}")
yield from ()
return
    # max_length=1e12 effectively disables truncation and silences the tokenizer's length warning
    ids = (
        [tokenizer.bos_id]
        + tokenizer.encode(text, max_length=int(1e12), truncation=True)
        + [tokenizer.eos_id]
    )
src_ids = ids[0:-1]
tgt_ids = ids[0:-1] # do not shift because it'll be shifted during loss calculation.
if len(src_ids) > max_length:
for st in range(0, len(src_ids), max_length):
yield src_ids[st : st + max_length], tgt_ids[st : st + max_length]
else:
        yield src_ids, tgt_ids


def cricket_data2ids(data, tokenizer, max_length: int, oversize_rule="segment", do_compact=False):
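    """Tokenize an (input, output) sample incrementally and yield (src_ids, tgt_ids) pairs.

    Oversized samples follow `oversize_rule`: "drop" discards them, "head" keeps
    only the first max_length tokens, and "segment" splits them into max_length
    chunks. With do_compact, a long instruction input may be squeezed by keeping
    only its head and tail.
    """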
assert oversize_rule in ("drop", "head", "segment")
if data is None:
yield from ()
return
if "output" not in data or not data["output"]:
yield from ()
return
if "input" not in data or data["input"] is None:
data["input"] = ""
src_ids = [tokenizer.bos_id]
tgt_ids = []
has_input = False
is_segment_reenter = False
    # Tokenize incrementally so a very long document does not stall the pipeline
    MAX_CHUNK_LENGTH = max_length * 10
for part in ("input", "output"):
l, r = 0, min(MAX_CHUNK_LENGTH, len(data[part]))
while l < len(data[part]):
try:
current_slice = data[part][l:r]
if not current_slice:
break
#token_ids = tokenizer.encode(current_slice, add_special_tokens=False)
token_ids = tokenizer.encode(current_slice)
            except Exception:
                # tokenization failed on this slice; skip the whole sample
                yield from ()
                return
if part == "input":
if len(token_ids) > 0:
has_input = True
                if len(token_ids) >= max_length - 2:  # the input alone must fit within max_length
yield from ()
return
src_ids.extend(token_ids)
tgt_ids.extend([IGNORE_TGT] * len(token_ids))
l = r
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
else:
if len(token_ids) + len(tgt_ids) >= max_length:
if oversize_rule == "drop":
yield from ()
return
elif oversize_rule == "head":
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
src_ids.extend(selected_token_ids[:-1])
tgt_ids.extend(selected_token_ids)
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids[:max_length], tgt_ids[:max_length]
return
elif oversize_rule == "segment":
instruction_rest_space = max_length - 1 - len(token_ids)
if has_input: # is instruction data
                            if (
                                do_compact
                                and len(src_ids) >= 128  # avoid losing instruction info when the input is short
                                and instruction_rest_space / len(src_ids) > 0.8
                            ):  # the instruction can be squeezed into max_length
inputs_len = len(src_ids)
keep_len = instruction_rest_space // 2
src_ids = src_ids[:keep_len] + src_ids[inputs_len - keep_len :]
tgt_ids = [IGNORE_TGT] * (len(src_ids) - 1)
src_ids.extend(token_ids)
tgt_ids.extend(token_ids)
tgt_ids.append(tokenizer.eos_id)
assert len(src_ids) < max_length, f"len src_ids: {len(src_ids)}"
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids, tgt_ids
                            else:  # compaction not possible; fall back to the head rule
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
src_ids.extend(selected_token_ids[:-1])
tgt_ids.extend(selected_token_ids)
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids[:max_length], tgt_ids[:max_length]
return
else: # normal segment
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
src_ids.extend(selected_token_ids)
tgt_ids.extend(selected_token_ids)
assert len(src_ids) == max_length + 1, f"len src_ids: {len(src_ids)}"
assert len(tgt_ids) == max_length, f"len tgt_ids: {len(tgt_ids)}"
yield src_ids[:max_length], tgt_ids[:max_length]
src_ids = src_ids[max_length:]
tgt_ids = tgt_ids[max_length:]
                        # slide the raw-text window past the characters just consumed
consumed_str = tokenizer.decode(selected_token_ids)
l += len(consumed_str)
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
is_segment_reenter = True
else:
if (is_segment_reenter and len(token_ids) > 8) or (
not is_segment_reenter and len(token_ids) > 0
): # is segmented LM data
src_ids.extend(token_ids)
tgt_ids.extend(token_ids)
tgt_ids.append(tokenizer.eos_id)
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids, tgt_ids
else:
yield from ()
return


class SegmentedDataset(torch.utils.data.IterableDataset):
def __init__(
self,
cfg,
tokenizer,
max_length=1024,
transform_func=None,
nthreads=1,
prefetch_slice=3,
slice_size=500,
do_compact=False,
):
super(SegmentedDataset, self).__init__()
self.segment = functools.partial(
cricket_data2ids, tokenizer=tokenizer, max_length=max_length, do_compact=do_compact
)
self.cfg = cfg
self.max_length = max_length
self.nthreads = nthreads
self.transform_func = transform_func
self.prefetch_slice = prefetch_slice
self.slice_size = slice_size
self.abs_weight = cfg.get("abs_weight", None)
self.task_name = cfg["task_name"]
self.dataset_name = cfg["dataset_name"]
self.oversize_rule = cfg["oversize_rule"]
self.dataset = PrefetchDecodeDataset(path=cfg["path"], allow_repeat=cfg.get("allow_repeat", True))
self.exhausted = False
self.iterator = None
self.counter = 0
self.allow_repeat = cfg.get("allow_repeat", True)
self.used = BitSet()
self.init_ave_tokens()
    def init_ave_tokens(self):
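        """Initialize this task's average-token statistic in POSIX shared memory.

        The value is keyed by task name and rank, so worker processes of the
        same rank share one running average per task.
        """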
try:
shm = SharedMemory(name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
except FileNotFoundError:
bmt.print_rank(
"Create Shared Memory {}".format(f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
)
shm = SharedMemory(
create=True,
size=ctypes.sizeof(ctypes.c_float),
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}',
)
        # map a ctypes float onto the shared-memory buffer
shared_value = ctypes.c_float.from_buffer(shm.buf)
_ave_tokens = self.cfg.get(
"avg_tokens", self.cfg.get("ave_tokens", self.cfg.get("ave_tokens_per_line", -1))
)
        if _ave_tokens > self.max_length:
            bmt.print_rank(
                "Warning: avg_tokens {} is larger than max_length {}, clamping to max_length".format(
                    _ave_tokens, self.max_length
                )
            )
            _ave_tokens = self.max_length
        # release the reference to shared_value before closing the buffer
        del shared_value
        # now the shared memory can be closed safely
        shm.close()
bmt.print_rank("Init ave_tokens for task {}: {}".format(self.task_name, self.ave_tokens))
    @property
    def ave_tokens(self):
        existing_shm = SharedMemory(name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
tmp = shared_value.value
del shared_value
existing_shm.close()
return tmp
def ave_tokens_update(self, length):
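        """Fold a new segment length into the shared running average (EMA, decay 0.98; initialized on first update)."""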
        existing_shm = SharedMemory(name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
if shared_value.value < 0:
shared_value.value = float(length)
else:
shared_value.value = 0.98 * shared_value.value + 0.02 * length
del shared_value
existing_shm.close()
def size(self):
return self.dataset.size()
def __iter__(self):
self.iterate()
return self
def reset(self):
self.exhausted = False
if self.iterator is not None:
self.iterator.close()
self.iterator = None
self.used = BitSet()
print("Rank {}, Reset dataset:{} done.".format(bmt.rank(), self.dataset_name))
def transform(self, data: dict) -> dict:
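        """Apply the user transform, sampling the number of in-context examples from cfg["incontext_weight"]."""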
weight = np.array(self.cfg["incontext_weight"], dtype=np.float32)
weight = weight / weight.sum()
num_incontext = np.random.choice(weight.shape[0], p=weight)
return self.transform_func(data, num_incontext, random.Random())
    def segment_iterate(self):
        for index, data in self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used):
            for src_ids, tgt_ids in self.segment(self.transform(data)):
                self.ave_tokens_update(len(src_ids))  # track the running average segment length
                yield src_ids, tgt_ids, index

    def iterate(self):
        # make the dataset itself an iterator
        self.iterator = self.segment_iterate()
def __next__(self):
# advance the task iterator
if self.iterator is None:
self.iterate()
try:
return next(self.iterator)
except StopIteration:
self.exhausted = True
return None
def load_state_dict(self, state_dict):
if state_dict.get("exhausted", False):
self.exhausted = True
self.used = BitSet()
else:
used = state_dict.get("used", BitSet())
if len(used) == len(self.dataset):
self.exhausted = True
self.used = BitSet()
else:
self.exhausted = False
self.used = used
self.ave_tokens_update(state_dict.get("ave_tokens", -1))
def state_dict(self):
if len(self.used) == len(self.dataset):
return dict(exhausted=True, used=BitSet(), ave_tokens=self.ave_tokens)
else:
return dict(exhausted=False, used=self.used, ave_tokens=self.ave_tokens)
def update_state(self, indice):
self.used.update(indice)


class MixedIndexedDataset(torch.utils.data.IterableDataset):
def __init__(
self,
cfg_path: str,
cfg_json_str,
tokenizer,
max_length,
weight_by_size: bool = True,
nthreads=5,
prefetch_slice=100,
parallel_loading=False,
vdc_sampling=False,
update_weights_frequency=1,
seed=42,
):
super(MixedIndexedDataset, self).__init__()
self.set_seed(seed + bmt.rank())
self.weight_by_size = weight_by_size
self.tokenizer = tokenizer
self.eos_token_id = self.tokenizer.eos_id
self.bos_token_id = self.tokenizer.bos_id
self.path2transform = dict()
self.task_dict = OrderedDict()
self.nthreads = nthreads
self.prefetch_slice = prefetch_slice
# useful for indexing
self.tasks = []
self.names = []
        # number of non-exhausted tasks; iteration stops when this reaches 0
        self.remain = 0
self.max_length = max_length
self.vdc_sampling = vdc_sampling
if self.vdc_sampling:
            self._vdc_values = [van_der_corput(i) for i in range(10**6)]  # precompute to a resolution of 1e-6
self.vdc_gen = van_der_corput_sampling_gen(self._vdc_values)
self.update_weights_frequency = update_weights_frequency
cfgs = load_dataset_cfgs(cfg_path, cfg_json_str)
_sum_weight = sum([cfg["abs_weight"] for cfg in cfgs])
_weights = {cfg["task_name"]: cfg["abs_weight"] / _sum_weight for cfg in cfgs}
bmt.print_rank("Absolute Weight of DataSet {}".format(_weights))
if parallel_loading:
self.parallel_load(cfgs, max_workers=None)
else:
self.sequential_load(cfgs)
self.weights = None
self.update_weights()
def set_seed(self, seed):
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
def load_task(self, cfg):
logger.info(f"Loading {cfg['path']}")
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
task = SegmentedDataset(
cfg,
self.tokenizer,
self.max_length,
transform_func=transform_func,
nthreads=self.nthreads,
prefetch_slice=self.prefetch_slice,
do_compact=cfg.get("do_compact", False), # dataset level do_compact
)
return task
    def sequential_load(self, cfgs):
        self.cfgs = cfgs
        for cfg in cfgs:
            # Python 3.7+ dicts preserve insertion order
            task = self.load_task(cfg)
            self.task_dict[task.task_name] = task
            self.tasks.append(task)
            self.names.append(task.task_name)
            self.remain += 1
        self.weights = None
        self.update_weights()
def load_state_dict(self, state_dict):
missing_keys = []
for name, task in self.task_dict.items():
if name in state_dict:
task.load_state_dict(state_dict[name])
else:
missing_keys.append(name)
self.update_weights()
return missing_keys
def save_state_dict(self, path):
state_dict = {}
for name, task in self.task_dict.items():
_state_dict = task.state_dict()
if isinstance(_state_dict["used"], BitSet):
bitset = _state_dict["used"]
_file_name = bitset.save(path)
_state_dict["used"] = _file_name
state_dict[name] = _state_dict
else:
state_dict[name] = task.state_dict()
torch.save(state_dict, path)
logger.info("Dataset state saved")
def update_states(self, task_ids, indice):
is_dict = isinstance(indice, dict)
uniq = torch.unique(task_ids)
for idx in uniq:
idx = idx.item()
indexes = indice[idx] if is_dict else indice[task_ids == idx].tolist()
self.tasks[idx].update_state(indexes)
def get_transform_func(self, module_name: str, transform_script_path):
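        """Load (and hot-reload) the `transform` function from a user script.

        The script is re-executed whenever its mtime increases, so transforms
        can be edited during a long run without restarting training.
        """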
if transform_script_path is None:
# allow null transform
return lambda data, num_incontext, rand: data
module_name = "fm9g_live.transforms.{}".format(module_name)
if transform_script_path not in self.path2transform:
loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
spec = importlib.util.spec_from_loader(loader.name, loader)
if spec is None:
raise RuntimeError("Spec is none! {}".format(module_name))
mod = importlib.util.module_from_spec(spec)
self.path2transform[transform_script_path] = {
"loader": loader,
"module": mod,
"last_mtime": 0,
}
transform_script_info = self.path2transform[transform_script_path]
curr_mtime = float(transform_script_info["loader"].path_stats(transform_script_path)["mtime"])
if curr_mtime > transform_script_info["last_mtime"]:
transform_script_info["last_mtime"] = curr_mtime
transform_script_info["loader"].exec_module(transform_script_info["module"])
transform_func = getattr(transform_script_info["module"], "transform", None)
if transform_func is None:
                raise NotImplementedError("No transform function found in script '{}'".format(transform_script_path))
return transform_func
def update_weights(self):
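        """Recompute the task sampling distribution.

        With abs_weight configured, a task's weight is abs_weight divided by its
        average tokens per segment, so sampling probability tracks the desired
        share of tokens rather than of samples. Otherwise static weights are
        used, optionally scaled by dataset size. Exhausted tasks get weight 0.
        """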
task0 = self.tasks[0]
        if task0.abs_weight is not None:  # this config specifies absolute sampling ratios
weights = []
for task in self.tasks:
if task.exhausted:
weights.append(0)
else:
if task.ave_tokens == -1:
weights.append(task.abs_weight / self.max_length)
else:
weights.append(task.abs_weight / task.ave_tokens)
weights = np.array(weights)
else:
            weights = np.array([0 if task.exhausted else task.cfg["weight"] for task in self.tasks])
if self.weight_by_size:
sizes = np.array([task.size() for task in self.tasks], dtype=np.float32)
weights *= sizes
self.weights = weights / weights.sum()
def __iter__(self):
for task in self.tasks:
task.iterate()
return self
def __next__(self):
step = 1
while True:
if self.remain == 0:
print("Rank {}, All task exhaust !!!!".format(bmt.rank()))
raise StopIteration
if self.vdc_sampling:
idx = next(self.vdc_gen)(self.weights)
else:
idx = np.random.choice(len(self.weights), p=self.weights)
data = next(self.tasks[idx])
if step % self.update_weights_frequency == 0:
self.update_weights()
            if data is None:
                if self.tasks[idx].allow_repeat:
                    print("Rank {}, dataset {} exhausted, repeating...".format(bmt.rank(), self.tasks[idx].dataset_name))
                    self.tasks[idx].reset()
                else:
                    print("Rank {}, dataset {} exhausted, not repeating.".format(bmt.rank(), self.tasks[idx].dataset_name))
                    self.tasks[idx].exhausted = True
                    self.remain -= 1
                continue
step += 1
return dict(
task_id=idx,
input=data[0],
target=data[1],
index=data[2],
is_long=self.tasks[idx].cfg.get("is_long", False),
)


class UnpadBatchedMixedDataset(torch.utils.data.IterableDataset):
def __init__(self, mixed_dataset, batch_size, max_length, pose_prob=0.0, pose_scaling_factor=1.0, compact=False):
self.max_total_length = batch_size * max_length
        self.batch_size = 1  # all sequences are packed into a single unpadded row
        # setting compact=True concatenates segments originating from the same input
        # into one long sequence; the relative order of segments must be preserved
        # in mixed_dataset, e.g.,
        # - ok: task1_seg1, task2_seg1, task1_seg2, task1_seg3
        # - not ok: task1_seg1, task1_seg3, task2_seg1, task1_seg2
self.compact = compact
self.total_length = 0
self.task2seqs = defaultdict(list)
self.mixed_dataset = mixed_dataset
self._max_length = max_length
self._pose_prob = pose_prob
self._pose_scaling_factor = pose_scaling_factor
if self._pose_prob > 0.0:
self._scaled_max_length = int(self.max_total_length * self._pose_scaling_factor)
else:
self._scaled_max_length = max_length
def put(self, sample):
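        """Buffer one segment; with compact=True, merge it into the task's last unfinished sequence."""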
self.total_length += len(sample["target"])
task_id = sample["task_id"]
if self.compact and self.task2seqs[task_id]:
last = self.task2seqs[task_id][-1]
if last["target"][-1] != self.mixed_dataset.eos_token_id:
                # concatenate sequential segments for longer-context modeling
last["input"].extend(sample["input"])
last["target"].extend(sample["target"])
return
self.task2seqs[task_id].append(sample)
def _pose_preprocess(
self,
input_ids: NDArray[np.int32],
) -> NDArray[np.int32]:
"""[PoSE](https://arxiv.org/abs/2309.10400v2)
GitHub implementation: https://github.com/dwzhu-pku/PoSE/blob/master/src/train_pose.py#L156
"""
len_chunk = min(len(input_ids), self._max_length)
len_input = len(input_ids)
# Chunk input randomly to fit max_length if needed
lt1 = 0
        rt1 = random.randint(0, (len_chunk + 1) // 2)  # first chunk contains at most half of the tokens
        rt2 = random.randint(lt1 + len_chunk, len_input)  # second chunk can shift randomly when max_length is not filled
lt2 = rt2 - (len_chunk - (rt1 - lt1)) # assure all tokens are used
chunked_input_ids = np.concatenate([input_ids[lt1:rt1], input_ids[lt2:rt2]], axis=-1)
# Generate PoSE position ids
position_ids = np.arange(len(chunked_input_ids), dtype=np.int32)
len_position_ids = len(position_ids)
lt = 0
rt = random.randint(lt, self._scaled_max_length - len_position_ids)
position_ids[: rt1 - lt1] += lt
position_ids[rt1 - lt1 :] += rt
return position_ids
def pop(self):
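        """Pack all buffered segments into one flattened, unpadded row and emit a batch dict with flash-attention style cu_seqlens."""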
indexes = defaultdict(list)
lengths = []
inputs = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
targets = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=IGNORE_TGT)
task_ids = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=-1)
position_ids = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
span_begin = 0
for samples in self.task2seqs.values():
while samples:
sample = samples.pop()
span_end = span_begin + len(sample["input"])
inputs[0, span_begin:span_end] = torch.tensor(sample["input"], dtype=torch.int32)
targets[0, span_begin:span_end] = torch.tensor(sample["target"], dtype=torch.int32)
task_ids[0, span_begin:span_end] = torch.tensor(sample["task_id"], dtype=torch.int32)
if not sample["is_long"] and self._pose_prob > 0.0 and random.uniform(0, 1) < self._pose_prob:
_span_position_ids = self._pose_preprocess(sample["input"])
else:
_span_position_ids = np.arange(len(sample["input"]), dtype=np.int32)
position_ids[0, span_begin:span_end] = torch.from_numpy(_span_position_ids)
# position_ids[0, span_begin:span_end] = torch.arange(len(sample["input"]), dtype=torch.int32)
lengths.append(len(sample["target"]))
indexes[int(sample["task_id"])].append(sample["index"])
self.total_length -= len(sample["target"])
span_begin = span_end
cu_seqlens = torch.cat(
[torch.tensor([0] + lengths).cumsum(dim=-1), torch.tensor([self.max_total_length], dtype=torch.int32)],
dim=0,
).int()
batch = {
"inputs": inputs,
"targets": targets,
"task_ids": task_ids,
"indexes": indexes,
# adhere to flash attention interface
"cu_seqlens": cu_seqlens,
"max_seqlen": int(torch.max(cu_seqlens[1:] - cu_seqlens[:-1])),
"lengths": torch.tensor(sum(lengths)).int(),
"task_names": self.mixed_dataset.names,
"position_ids": position_ids,
}
return batch
def will_be_full(self, sample):
return self.total_length + len(sample["target"]) > self.max_total_length
def __iter__(self):
for sample in self.mixed_dataset:
if self.will_be_full(sample):
yield self.pop()
self.put(sample)


class CudaPrefetcher(Iterable):
"""
Wrap around a batch iterator for asynchornously copying data to gpu to shield memcpy latency.
"""
def __init__(self, loader, tp_size=1, tp_rank=0):
self.loader = iter(loader)
self.tp_size = tp_size
self.tp_rank = tp_rank
self.stream = torch.cuda.Stream()
self.preload()
def preload(self):
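        """Fetch the next batch and stage it on a side CUDA stream.

        Under tensor parallelism, only tp_rank 0 reads from the loader; the
        batch is then broadcast to the other TP ranks through files in /dev/shm.
        """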
try:
if self.tp_size > 1:
if self.tp_rank == 0:
data = next(self.loader)
print("Rank {}, Preload data done.".format(bmt.rank()))
d = {}
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "wb") as fb:
for key in data.keys():
if isinstance(data[key], torch.Tensor):
np_cur_data = data[key].cpu().numpy()
bs = np_cur_data.tobytes()
fb.write(bs)
d[key] = ["TORCH", str(np_cur_data.dtype), len(bs)] + list(np_cur_data.shape)
elif isinstance(data[key], np.ndarray):
bs = data[key].tobytes()
fb.write(bs)
d[key] = ["NUMPY", str(data[key].dtype), len(bs)] + list(data[key].shape)
else:
d[key] = data[key]
try:
_ = json.dumps(d)
except TypeError:
print(d)
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "w") as f:
json.dump(d, f)
bmt.synchronize()
if self.tp_rank != 0:
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "r") as f:
data = json.load(f)
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "rb") as fb:
bs = fb.read()
offset = 0
for key in data.keys():
if isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "NUMPY":
nw_offset = offset + data[key][2]
data[key] = np.frombuffer(bs[offset:nw_offset], dtype=data[key][1]).reshape(
data[key][3:]
)
offset = nw_offset
elif isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "TORCH":
nw_offset = offset + data[key][2]
data[key] = torch.from_numpy(
np.frombuffer(bs[offset:nw_offset], dtype=data[key][1])
.reshape(data[key][3:])
.copy()
)
offset = nw_offset
self.data = data
else:
self.data = next(self.loader)
except StopIteration:
self.data = None
return
with torch.cuda.stream(self.stream):
for key in self.data.keys():
if isinstance(self.data[key], torch.Tensor):
self.data[key] = self.data[key].cuda(non_blocking=True)
def __next__(self):
torch.cuda.current_stream().wait_stream(self.stream)
for key in self.data.keys():
if isinstance(self.data[key], torch.Tensor):
self.data[key].record_stream(torch.cuda.current_stream())
data = copy.deepcopy(self.data)
self.preload()
return data
def __iter__(self):
return self
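

# A minimal usage sketch of how these pieces compose (hypothetical tokenizer and
# config path; the real entry point lives in the training scripts):
#
#   mixed = MixedIndexedDataset("datasets.json", None, tokenizer, max_length=4096)
#   packer = UnpadBatchedMixedDataset(mixed, batch_size=4, max_length=4096)
#   loader = torch.utils.data.DataLoader(packer, batch_size=None, num_workers=1)
#   prefetcher = CudaPrefetcher(loader)
#   batch = next(prefetcher)  # dict with "inputs", "targets", "cu_seqlens", ...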