forked from jiuyuan/CPM-9G-8B
add fm9g 2b and 8b models
This commit is contained in:
parent
03c55e1fee
commit
cfd2fca57c
|
@ -1,16 +0,0 @@
|
||||||
{
|
|
||||||
"vocab_size": 119696,
|
|
||||||
"dropout_p": 0.0,
|
|
||||||
"eps": 1e-05,
|
|
||||||
"half": true,
|
|
||||||
"use_flash_attn": true,
|
|
||||||
"flash_attn_mask_shape": "2d",
|
|
||||||
"dim_model": 4096,
|
|
||||||
"dim_ff": 12288,
|
|
||||||
"dim_head": 128,
|
|
||||||
"num_heads": 32,
|
|
||||||
"num_kv_heads": 32,
|
|
||||||
"num_layers": 48,
|
|
||||||
"activate_fn": "silu",
|
|
||||||
"scale": false
|
|
||||||
}
|
|
|
@ -1,14 +0,0 @@
|
||||||
{
|
|
||||||
"vocab_size": 119696,
|
|
||||||
"dropout_p": 0.0,
|
|
||||||
"eps": 1e-05,
|
|
||||||
"half": true,
|
|
||||||
"dim_model": 4096,
|
|
||||||
"dim_ff": 11008,
|
|
||||||
"dim_head": 128,
|
|
||||||
"num_heads": 32,
|
|
||||||
"num_kv_heads": 32,
|
|
||||||
"num_layers": 32,
|
|
||||||
"activate_fn": "silu",
|
|
||||||
"scale": false
|
|
||||||
}
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,12 +0,0 @@
|
||||||
[
|
|
||||||
{
|
|
||||||
"dataset_name": "wikipedia",
|
|
||||||
"task_name": "wikipedia",
|
|
||||||
"weight": 1.0,
|
|
||||||
"path": "path/to/data",
|
|
||||||
"incontext_weight": [
|
|
||||||
1.0
|
|
||||||
],
|
|
||||||
"transforms": "/home/USERNAME/cpm9g/apps/cpm9g/config/datasets/wikipedia/script_cpmc.py"
|
|
||||||
}
|
|
||||||
]
|
|
|
@ -1,9 +0,0 @@
|
||||||
import random
|
|
||||||
|
|
||||||
|
|
||||||
def rand(n: int, r: random.Random):
|
|
||||||
return int(r.random() * n)
|
|
||||||
|
|
||||||
|
|
||||||
def transform(data, num_sample: int, r: random.Random):
|
|
||||||
return {"input": "", "output": data["text"]}
|
|
|
@ -1,485 +0,0 @@
|
||||||
import inspect
|
|
||||||
import json
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from collections import defaultdict
|
|
||||||
from typing import Any
|
|
||||||
from typing import Dict
|
|
||||||
from typing import List
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import bmtrain as bmt
|
|
||||||
import torch
|
|
||||||
|
|
||||||
sys.path.insert(0, "/home/wangshuo1/code/9G-Train")
|
|
||||||
from cpm.arguments import get_args
|
|
||||||
from cpm.cpm9g.models import CPM9G
|
|
||||||
from cpm.cpm9g.models import CPM9GConfig
|
|
||||||
from cpm.cpm9g.tokenizers import CPM9GTokenizer
|
|
||||||
from cpm.cpm9g.training_tasks import MixedDataset
|
|
||||||
from cpm.utils import allgather_objects
|
|
||||||
from cpm.utils import exporter
|
|
||||||
from cpm.utils import logger
|
|
||||||
from cpm.utils import LogManager
|
|
||||||
|
|
||||||
|
|
||||||
def get_tokenizer(args):
|
|
||||||
tokenizer = CPM9GTokenizer(path=args.vocab)
|
|
||||||
return tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def get_model(args):
|
|
||||||
config = CPM9GConfig.from_json_file(args.model_config)
|
|
||||||
config.tp = 1 if args.tp != 1 else 0
|
|
||||||
if args.flash == "none":
|
|
||||||
config.use_flash_attn = False
|
|
||||||
else:
|
|
||||||
config.use_flash_attn = True
|
|
||||||
if args.flash == "1d":
|
|
||||||
config.flash_attn_mask_shape = "1d"
|
|
||||||
else:
|
|
||||||
config.flash_attn_mask_shape = "2d"
|
|
||||||
if args.flash == "triton":
|
|
||||||
config.flash_impl = "triton"
|
|
||||||
elif args.flash == "cuda":
|
|
||||||
config.flash_impl = "cuda"
|
|
||||||
model = CPM9G(config)
|
|
||||||
if args.load is not None:
|
|
||||||
bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
|
|
||||||
bmt.load(model, args.load)
|
|
||||||
else:
|
|
||||||
bmt.print_rank("args.load is None, start to initialize parameters")
|
|
||||||
bmt.init_parameters(model)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def get_optimizer(args, model):
|
|
||||||
for name, para in model.named_parameters():
|
|
||||||
# if not ('input_embedding' in name or 'lm_head' in name):
|
|
||||||
# para.requires_grad_(False)
|
|
||||||
bmt.print_rank(name, para.requires_grad)
|
|
||||||
if args.offload:
|
|
||||||
optimizer = bmt.optim.AdamOffloadOptimizer(
|
|
||||||
model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
optimizer = bmt.optim.AdamOptimizer(model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
|
||||||
if args.load is not None and args.load_grad:
|
|
||||||
start = time.time()
|
|
||||||
print(
|
|
||||||
sum([1 if re.search(r"-{}.rank-\d+.opt".format(args.start_step), i) else 0 for i in os.listdir(args.save)])
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
sum([1 if re.search(r"-{}.rank-\d+.opt".format(args.start_step), i) else 0 for i in os.listdir(args.save)])
|
|
||||||
== bmt.world_size()
|
|
||||||
):
|
|
||||||
file_name = os.path.join(
|
|
||||||
args.save,
|
|
||||||
args.save_name + "-{}.rank-{}.opt".format(args.start_step, bmt.rank()),
|
|
||||||
)
|
|
||||||
print(file_name)
|
|
||||||
if os.path.exists(file_name):
|
|
||||||
print("start to load grad ckpt {}".format(file_name))
|
|
||||||
states = torch.load(file_name)
|
|
||||||
optimizer.load_state_dict(states)
|
|
||||||
logger.info("load grad in {:.2f}s".format(time.time() - start))
|
|
||||||
return optimizer
|
|
||||||
|
|
||||||
|
|
||||||
class Cosine(bmt.lr_scheduler.WarmupLRScheduler):
|
|
||||||
r"""
|
|
||||||
After a warmup period during which learning rate increases linearly between 0 and the start_lr,
|
|
||||||
The decay period performs :math:`\text{lr}=\text{start_lr}\times \dfrac{1+\cos \left( \pi \cdot \dfrac{\text{num_iter}-\text{warmup_iter}}{\text{end_iter}-\text{warmup_iter}}\right)}{2}`
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_lr_warmup(self, num_iter) -> float:
|
|
||||||
return self.start_lr * num_iter / self.warmup_iter
|
|
||||||
|
|
||||||
def get_lr_decay(self, num_iter) -> float:
|
|
||||||
progress = (num_iter - self.warmup_iter) / max(1, (self.end_iter - self.warmup_iter))
|
|
||||||
return max(self.start_lr * 0.1, self.start_lr * (0.1 + 0.45 * (1.0 + math.cos(progress * math.pi))))
|
|
||||||
|
|
||||||
|
|
||||||
def get_learning_rate_scheduler(args, optimizer):
|
|
||||||
if args.lr_decay_iters is None:
|
|
||||||
args.lr_decay_iters = args.train_iters
|
|
||||||
# lr_scheduler = bmt.lr_scheduler.Noam(
|
|
||||||
lr_scheduler = Cosine(
|
|
||||||
optimizer,
|
|
||||||
start_lr=args.lr,
|
|
||||||
warmup_iter=args.warmup_iters,
|
|
||||||
end_iter=args.lr_decay_iters,
|
|
||||||
num_iter=args.start_step,
|
|
||||||
)
|
|
||||||
return lr_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
def setup_model_and_optimizer(args):
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
model = get_model(args)
|
|
||||||
logger.info("load model in {:.2f}s".format(time.time() - start))
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
tokenizer = get_tokenizer(args)
|
|
||||||
bmt.synchronize()
|
|
||||||
logger.info("load tokenizer in {:.2f}s".format(time.time() - start))
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
optimizer = get_optimizer(args, model)
|
|
||||||
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
|
|
||||||
bmt.synchronize()
|
|
||||||
logger.info("load lr_scheduler in {:.2f}s".format(time.time() - start))
|
|
||||||
|
|
||||||
return tokenizer, model, optimizer, lr_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
def initialize():
|
|
||||||
args = get_args(pretrain=True)
|
|
||||||
bmt.init_distributed(seed=args.seed, zero_level=3)
|
|
||||||
if args.save is not None:
|
|
||||||
os.makedirs(args.save, exist_ok=True)
|
|
||||||
if args.load is not None:
|
|
||||||
if args.start_step == 0:
|
|
||||||
args.start_step = (int)(re.search("(\d+).pt", args.load)[1])
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
def see_memory(detail=False):
|
|
||||||
if detail:
|
|
||||||
res = torch.cuda.memory_summary()
|
|
||||||
else:
|
|
||||||
res = (
|
|
||||||
round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024), 2),
|
|
||||||
round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024), 2),
|
|
||||||
)
|
|
||||||
torch.cuda.reset_peak_memory_stats()
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def add_mem_time(info, mem_usage, tim_usage):
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
bmt.synchronize()
|
|
||||||
mem_usage[info] = see_memory()
|
|
||||||
tim_usage[info] = time.time()
|
|
||||||
return mem_usage, tim_usage
|
|
||||||
|
|
||||||
|
|
||||||
class LossSpikeDetector:
|
|
||||||
def __init__(self, log_path: str) -> None:
|
|
||||||
self._last_loss: Dict[str, float] = {}
|
|
||||||
self._last_data: List[Any] = [None]
|
|
||||||
self._log_path = log_path
|
|
||||||
|
|
||||||
def update_data(self, data: Any):
|
|
||||||
self._last_data.append(data)
|
|
||||||
if len(self._last_data) > 2:
|
|
||||||
self._last_data = self._last_data[-2:]
|
|
||||||
|
|
||||||
def update_loss(self, iteration: int, loss_map: Dict[str, float]):
|
|
||||||
loss_spike_result = []
|
|
||||||
for task, loss in loss_map.items():
|
|
||||||
if task in self._last_loss:
|
|
||||||
if loss > self._last_loss[task] * 3:
|
|
||||||
# loss spike!
|
|
||||||
loss_spike_result.append(
|
|
||||||
{
|
|
||||||
"prev": self._last_loss[task],
|
|
||||||
"curr": loss,
|
|
||||||
"task": task,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
self._last_loss[task] = float(loss)
|
|
||||||
if len(loss_spike_result) > 0:
|
|
||||||
self._write_log(iteration, self._last_data[-1], loss_spike_result)
|
|
||||||
|
|
||||||
def _write_log(self, iteration: int, data: Any, result: List[Dict[str, Any]]):
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
with open(self._log_path, "a", encoding="utf-8") as fp:
|
|
||||||
fp.write("=" * 20)
|
|
||||||
fp.write("\nloss spike at {}\n".format(iteration))
|
|
||||||
fp.write("{}\n".format(json.dumps(result, indent=4, ensure_ascii=False)))
|
|
||||||
fp.write("data: \n")
|
|
||||||
for d in data:
|
|
||||||
fp.write("{}\n".format(json.dumps(d, indent=4, ensure_ascii=False)))
|
|
||||||
fp.write("\n\n")
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
print("cannot output log to the file {}", self._log_path)
|
|
||||||
|
|
||||||
|
|
||||||
def pretrain(
|
|
||||||
args,
|
|
||||||
tokenizer: CPM9GTokenizer,
|
|
||||||
model: CPM9G,
|
|
||||||
optimizer,
|
|
||||||
lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
|
|
||||||
):
|
|
||||||
average_time = bmt.utils.AverageRecorder()
|
|
||||||
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100)
|
|
||||||
optim_manager = bmt.optim.OptimManager(
|
|
||||||
loss_scale=None if args.bf16 else args.loss_scale,
|
|
||||||
loss_scale_steps=args.loss_scale_steps,
|
|
||||||
loss_scale_factor=2,
|
|
||||||
max_loss_scale=args.max_loss_scale,
|
|
||||||
min_loss_scale=args.min_loss_scale,
|
|
||||||
)
|
|
||||||
optim_manager.add_optimizer(optimizer, lr_scheduler)
|
|
||||||
|
|
||||||
start_step = args.start_step
|
|
||||||
lsd = LossSpikeDetector("./log/debug/spile.%d.log" % bmt.rank())
|
|
||||||
|
|
||||||
if args.tensorboard is not None and bmt.rank() == 0:
|
|
||||||
import distutils.version # noqa: F401
|
|
||||||
|
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
|
|
||||||
if not os.path.exists(args.tensorboard):
|
|
||||||
os.makedirs(args.tensorboard)
|
|
||||||
writer = SummaryWriter(log_dir=args.tensorboard)
|
|
||||||
|
|
||||||
if args.log_dir is not None and bmt.rank() == 0:
|
|
||||||
log_mgr = LogManager(args.log_dir)
|
|
||||||
|
|
||||||
global_token_pass = 0.0
|
|
||||||
global_world_size = bmt.world_size()
|
|
||||||
dataloader = MixedDataset(args.dataset, args.batch_size, args.max_length, tokenizer, unpad=(args.flash == "cuda"))
|
|
||||||
|
|
||||||
if args.load is not None:
|
|
||||||
dataset_states_path = args.load.replace(".pt", ".data")
|
|
||||||
if os.path.exists(dataset_states_path):
|
|
||||||
start = time.time()
|
|
||||||
bmt.print_rank("start to load data ckpt")
|
|
||||||
dataset_states = torch.load(dataset_states_path)
|
|
||||||
logger.info("load data ckpt in {:.2f}s".format(time.time() - start))
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
missing = dataloader.load_state_dict(dataset_states)
|
|
||||||
logger.info("load state dict in {:.2f}s".format(time.time() - start))
|
|
||||||
if len(missing) > 0:
|
|
||||||
bmt.print_rank("Missing keys when loading dataset states: ", missing)
|
|
||||||
else:
|
|
||||||
bmt.print_rank("cannot find data ckpt {}".format(dataset_states_path))
|
|
||||||
|
|
||||||
dataloader.start()
|
|
||||||
bmt.print_rank("finish dataset start")
|
|
||||||
try:
|
|
||||||
total = 0
|
|
||||||
hash = {}
|
|
||||||
for iteration, data in enumerate(dataloader):
|
|
||||||
iteration = iteration + start_step + 1
|
|
||||||
input_ids = torch.from_numpy(data["inputs"]).cuda().to(torch.int32)
|
|
||||||
input_length = torch.from_numpy(data["length"]).cuda().to(torch.int32)
|
|
||||||
targets = torch.from_numpy(data["target"]).cuda().to(torch.int32)
|
|
||||||
task_ids = torch.from_numpy(data["task_ids"]).cuda().to(torch.int32)
|
|
||||||
task_names = data["task_names"]
|
|
||||||
lsd.update_data(data["raw_data"])
|
|
||||||
if args.flash == "cuda":
|
|
||||||
cu_seqlens = torch.from_numpy(data["cu_seqlens"]).cuda().to(torch.int32)
|
|
||||||
max_seqlen = data["max_seqlen"]
|
|
||||||
position_ids = torch.from_numpy(data["position_ids"]).cuda().to(torch.int32)
|
|
||||||
else:
|
|
||||||
input_ids = torch.from_numpy(data["inputs"]).cuda().to(torch.int32)
|
|
||||||
input_context = torch.zeros_like(input_ids).cuda().bool()
|
|
||||||
input_span = torch.from_numpy(data["spans"]).cuda().to(torch.int32)
|
|
||||||
|
|
||||||
# ===========
|
|
||||||
optim_manager.zero_grad()
|
|
||||||
# torch.cuda.empty_cache()
|
|
||||||
mem_usage = {}
|
|
||||||
tim_usage = {}
|
|
||||||
mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)
|
|
||||||
|
|
||||||
# bmt.print_rank(torch.cuda.max_memory_allocated())
|
|
||||||
# ===========
|
|
||||||
if args.flash == "cuda":
|
|
||||||
logits, _ = model(
|
|
||||||
input_ids,
|
|
||||||
cu_seqlens=cu_seqlens,
|
|
||||||
max_seqlen=max_seqlen,
|
|
||||||
position_ids=position_ids,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits, _ = model(
|
|
||||||
input_ids,
|
|
||||||
input_length,
|
|
||||||
input_context,
|
|
||||||
input_span,
|
|
||||||
)
|
|
||||||
mem_usage, tim_usage = add_mem_time("forward_1", mem_usage, tim_usage)
|
|
||||||
loss = loss_func(logits.view(-1, logits.size(-1)), targets.view(-1))
|
|
||||||
global_loss = bmt.sum_loss(loss).item()
|
|
||||||
mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)
|
|
||||||
|
|
||||||
# bmt.print_rank(torch.cuda.max_memory_allocated())
|
|
||||||
# ===========
|
|
||||||
optim_manager.backward(loss)
|
|
||||||
mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)
|
|
||||||
|
|
||||||
# bmt.print_rank(torch.cuda.max_memory_allocated())
|
|
||||||
# ===========
|
|
||||||
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
|
|
||||||
optim_manager.step()
|
|
||||||
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
|
|
||||||
# bmt.print_rank(torch.cuda.max_memory_allocated())
|
|
||||||
|
|
||||||
# ==========
|
|
||||||
iter_time = tim_usage["optim"] - tim_usage["init"]
|
|
||||||
average_time.record(iter_time)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
task_num = len(task_names)
|
|
||||||
targets_tmp = targets.expand(task_num, -1, -1)
|
|
||||||
task = torch.arange(task_num, dtype=torch.int32, device="cuda")[:, None, None]
|
|
||||||
targets_tmp = torch.where(
|
|
||||||
task_ids == task,
|
|
||||||
targets_tmp,
|
|
||||||
torch.scalar_tensor(-100, dtype=torch.int32, device="cuda"),
|
|
||||||
)
|
|
||||||
|
|
||||||
task_loss_map: Dict[str, float] = {}
|
|
||||||
task_loss_tot: Dict[str, float] = {}
|
|
||||||
for i in range(task_num):
|
|
||||||
task_loss_map[task_names[i]] = loss_func(
|
|
||||||
logits.view(-1, logits.size(-1)), targets_tmp[i, :].view(-1)
|
|
||||||
).item()
|
|
||||||
task_loss_tot[task_names[i]] = (targets_tmp[i, :].view(-1) >= 0).sum().float().item()
|
|
||||||
gatherd_task_loss_map: List[Dict[str, float]] = allgather_objects(task_loss_map)
|
|
||||||
gatherd_task_loss_tot: List[Dict[str, float]] = allgather_objects(task_loss_tot)
|
|
||||||
|
|
||||||
global_task_loss_map: Dict[str, Union[List[float], float]] = {}
|
|
||||||
global_task_loss_tot: Dict[str, Union[List[float], float]] = {}
|
|
||||||
|
|
||||||
for idx, local_task_loss_map in enumerate(gatherd_task_loss_map):
|
|
||||||
for task_name, task_loss in local_task_loss_map.items():
|
|
||||||
if task_name not in global_task_loss_map:
|
|
||||||
global_task_loss_map[task_name] = []
|
|
||||||
global_task_loss_map[task_name].append(task_loss)
|
|
||||||
for task_name, task_tot in gatherd_task_loss_tot[idx].items():
|
|
||||||
if task_name not in global_task_loss_tot:
|
|
||||||
global_task_loss_tot[task_name] = []
|
|
||||||
global_task_loss_tot[task_name].append(task_tot)
|
|
||||||
|
|
||||||
task_loss_map = {}
|
|
||||||
for task_name in sorted(list(global_task_loss_map.keys())):
|
|
||||||
avg_loss = 0.0
|
|
||||||
sum_token = sum(global_task_loss_tot[task_name])
|
|
||||||
for loss, token in zip(global_task_loss_map[task_name], global_task_loss_tot[task_name]):
|
|
||||||
avg_loss += loss * token / sum_token
|
|
||||||
task_loss_map[task_name] = avg_loss
|
|
||||||
|
|
||||||
local_total_rate = torch.Tensor([input_length.float().mean() / args.max_length]).cuda()
|
|
||||||
local_total_rate = bmt.sum_loss(local_total_rate).item()
|
|
||||||
global_token_pass += global_world_size * local_total_rate * args.max_length * args.batch_size
|
|
||||||
avg_time = average_time.value
|
|
||||||
lsd.update_loss(iteration, task_loss_map)
|
|
||||||
|
|
||||||
for task_id in data["task_ids"]:
|
|
||||||
for task in task_id:
|
|
||||||
if task != -1:
|
|
||||||
if not data["task_names"][task] in hash:
|
|
||||||
hash[data["task_names"][task]] = 0
|
|
||||||
hash[data["task_names"][task]] += 1.0
|
|
||||||
total += 1.0
|
|
||||||
|
|
||||||
gathered_hash = allgather_objects(hash)
|
|
||||||
sum_total = sum(allgather_objects(total))
|
|
||||||
|
|
||||||
final_hash = defaultdict(int)
|
|
||||||
for local_hash in gathered_hash:
|
|
||||||
for task, num in local_hash.items():
|
|
||||||
final_hash[task] += num
|
|
||||||
|
|
||||||
# for i in final_hash:
|
|
||||||
# bmt.print_rank(i, final_hash[i] / sum_total)
|
|
||||||
# bmt.print_rank("=========================================")
|
|
||||||
|
|
||||||
train_info = {
|
|
||||||
"time": tim_usage["init"],
|
|
||||||
"iteration": iteration,
|
|
||||||
"loss": global_loss,
|
|
||||||
"lr": lr_scheduler.current_lr,
|
|
||||||
"lr_scale": int(optim_manager.loss_scale),
|
|
||||||
"time_usage": tim_usage,
|
|
||||||
"mem_usage": mem_usage,
|
|
||||||
"avg_time": avg_time,
|
|
||||||
"token_max": local_total_rate,
|
|
||||||
"token_pass": global_token_pass,
|
|
||||||
"throughout": args.max_length * args.batch_size * local_total_rate / avg_time,
|
|
||||||
"grad_norm": grad_norm.item(),
|
|
||||||
"mask_max": ((targets >= 0).sum(-1).float().mean() / args.max_length).item(),
|
|
||||||
"num_gpus": global_world_size,
|
|
||||||
"task_loss": task_loss_map,
|
|
||||||
}
|
|
||||||
# bmt.print_rank(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
|
|
||||||
bmt.print_rank(
|
|
||||||
(
|
|
||||||
"| Iter: {:6d} | loss: {:.4f} | lr: {:.4e}, scale: {:10.4f} | avg_time: {:.4f}; cur_time:{:.4f}={:.4f}+{:.4f} |"
|
|
||||||
+ " token/max: {:.4f} | mask/max: {:.4f} | grad_norm: {:.4f} | mem: {:.2f} |"
|
|
||||||
).format(
|
|
||||||
iteration,
|
|
||||||
global_loss,
|
|
||||||
lr_scheduler.current_lr,
|
|
||||||
int(optim_manager.loss_scale),
|
|
||||||
iter_time,
|
|
||||||
tim_usage["optim"] - tim_usage["init"],
|
|
||||||
tim_usage["backward"] - tim_usage["init"],
|
|
||||||
tim_usage["optim"] - tim_usage["backward"],
|
|
||||||
input_length.float().mean() / args.max_length / (args.batch_size if args.flash == "cuda" else 1),
|
|
||||||
(targets >= 0).sum(-1).float().mean()
|
|
||||||
/ args.max_length
|
|
||||||
/ (args.batch_size if args.flash == "cuda" else 1),
|
|
||||||
grad_norm,
|
|
||||||
max(mem_usage["forward"][1], mem_usage["backward"][1]),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
bmt.print_rank(
|
|
||||||
"| "
|
|
||||||
+ " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
|
|
||||||
+ " |"
|
|
||||||
)
|
|
||||||
|
|
||||||
if iteration % args.inspect_iters == 0:
|
|
||||||
model_inspect = bmt.inspect.inspect_model(model, "*")
|
|
||||||
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
|
||||||
train_info["model_inspect"] = model_inspect
|
|
||||||
|
|
||||||
if args.log_dir is not None and bmt.rank() == 0:
|
|
||||||
log_mgr.write(**train_info)
|
|
||||||
if args.tensorboard is not None and bmt.rank() == 0:
|
|
||||||
writer.add_scalar("Loss/train", global_loss, iteration)
|
|
||||||
writer.add_scalar("Optimizer/lr", lr_scheduler.current_lr, iteration)
|
|
||||||
writer.add_scalar("Optimizer/scale", optim_manager.loss_scale, iteration)
|
|
||||||
writer.add_scalar("Optimizer/grad_norm", grad_norm.item(), iteration)
|
|
||||||
for task_name, loss in task_loss_map.items():
|
|
||||||
writer.add_scalar("Loss/train/{}".format(task_name), loss, iteration)
|
|
||||||
|
|
||||||
# -------- save file. If need to backup by Klara platform, use export.xx_save --------
|
|
||||||
if args.save is not None and iteration % args.save_iters == 0:
|
|
||||||
exporter.export(model, dataloader, optimizer, iteration, args, final_save=False)
|
|
||||||
|
|
||||||
if iteration >= args.train_iters:
|
|
||||||
break
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"train loop err: {e}")
|
|
||||||
raise e
|
|
||||||
finally:
|
|
||||||
dataloader.close()
|
|
||||||
exporter.export(model, dataloader, optimizer, -1, args, final_save=False)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
args = initialize()
|
|
||||||
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
|
|
||||||
bmt.print_rank("finish loading")
|
|
||||||
pretrain(args, tokenizer, model, optimizer, lr_scheduler)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
|
@ -1,60 +0,0 @@
|
||||||
#! /bin/bash
|
|
||||||
#SBATCH --nodes=1
|
|
||||||
#SBATCH --ntasks-per-node=8
|
|
||||||
#SBATCH --gres=gpu:8
|
|
||||||
#SBATCH --cpus-per-task=8
|
|
||||||
#SBATCH --mem=512GB
|
|
||||||
|
|
||||||
# use 8 GPU for example, pretrain may need 32 GPU
|
|
||||||
export MASTER_ADDR=`hostname`
|
|
||||||
export MASTER_PORT=12345
|
|
||||||
|
|
||||||
mkdir -p /home/${USERNAME}/logs/debug
|
|
||||||
mkdir -p /home/${USERNAME}/logs/tensorboard/cpm9g/
|
|
||||||
|
|
||||||
cd apps/cpm9g
|
|
||||||
CONFIG_NAME="config/11b"
|
|
||||||
# --------------- 运行参数 ---------------
|
|
||||||
OPTS=""
|
|
||||||
OPTS+=" --model-config ${CONFIG_NAME}/config.json"
|
|
||||||
OPTS+=" --vocab ${CONFIG_NAME}/vocab.txt"
|
|
||||||
OPTS+=" --batch-size 4"
|
|
||||||
OPTS+=" --train-iters 400000"
|
|
||||||
OPTS+=" --save-iters 250"
|
|
||||||
OPTS+=" --save-name cpm9g_checkpoint"
|
|
||||||
OPTS+=" --max-length 4096"
|
|
||||||
OPTS+=" --lr 1.5e-5"
|
|
||||||
OPTS+=" --inspect-iters 100"
|
|
||||||
OPTS+=" --warmup-iters 2000"
|
|
||||||
OPTS+=" --lr-decay-style noam"
|
|
||||||
OPTS+=" --weight-decay 0.1"
|
|
||||||
OPTS+=" --clip-grad 1.0"
|
|
||||||
OPTS+=" --loss-scale 1048576"
|
|
||||||
OPTS+=" --loss-scale-steps 32"
|
|
||||||
OPTS+=" --offload"
|
|
||||||
OPTS+=" --flash cuda"
|
|
||||||
# OPTS+=" --load-grad"
|
|
||||||
|
|
||||||
# --------------- 写文件路径 ---------------
|
|
||||||
## checkpoint
|
|
||||||
OPTS+=" --save /home/${USERNAME}/checkpoints/cpm9g/"
|
|
||||||
OPTS+=" --save-model /home/${USERNAME}/models/cpm9g/"
|
|
||||||
|
|
||||||
## logs,/local/logs 等价于 /data/logs(软链)
|
|
||||||
OPTS+=" --log-dir /home/${USERNAME}/logs/train/"
|
|
||||||
OPTS+=" --tensorboard /home/${USERNAME}/tensorboard/cpm9g/"`date +"%Y%m%d%H%M%S"`
|
|
||||||
|
|
||||||
# --------------- 读文件路径 ---------------
|
|
||||||
OPTS+=" --dataset config/datasets.json"
|
|
||||||
OPTS+=" --load ${CHECKPOINT}"
|
|
||||||
OPTS+=" --start-step 1"
|
|
||||||
|
|
||||||
# --------------- 透传参数 ---------------
|
|
||||||
OPTS+=" $@"
|
|
||||||
|
|
||||||
# --------------- 最终指令 ---------------
|
|
||||||
|
|
||||||
CMD="torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} pretrain_cpm9g.py ${OPTS}"
|
|
||||||
echo "${CMD}"
|
|
||||||
|
|
||||||
$CMD
|
|
|
@ -1,59 +0,0 @@
|
||||||
#! /bin/bash
|
|
||||||
#SBATCH --nodes=1
|
|
||||||
#SBATCH --ntasks-per-node=8
|
|
||||||
#SBATCH --gres=gpu:8
|
|
||||||
#SBATCH --cpus-per-task=8
|
|
||||||
|
|
||||||
# use 8 GPU for example, pretrain may need 32 GPU
|
|
||||||
export MASTER_ADDR=`hostname`
|
|
||||||
export MASTER_PORT=12345
|
|
||||||
|
|
||||||
mkdir -p /home/${USERNAME}/logs/debug
|
|
||||||
mkdir -p /home/${USERNAME}/logs/tensorboard/cpm9g/
|
|
||||||
|
|
||||||
cd apps/cpm9g
|
|
||||||
CONFIG_NAME="config/7b"
|
|
||||||
# --------------- 运行参数 ---------------
|
|
||||||
OPTS=""
|
|
||||||
OPTS+=" --model-config ${CONFIG_NAME}/config.json"
|
|
||||||
OPTS+=" --vocab ${CONFIG_NAME}/vocab.txt"
|
|
||||||
OPTS+=" --batch-size 4"
|
|
||||||
OPTS+=" --train-iters 400000"
|
|
||||||
OPTS+=" --save-iters 250"
|
|
||||||
OPTS+=" --save-name cpm9g_checkpoint"
|
|
||||||
OPTS+=" --max-length 4096"
|
|
||||||
OPTS+=" --lr 1.5e-5"
|
|
||||||
OPTS+=" --inspect-iters 100"
|
|
||||||
OPTS+=" --warmup-iters 2000"
|
|
||||||
OPTS+=" --lr-decay-style noam"
|
|
||||||
OPTS+=" --weight-decay 0.1"
|
|
||||||
OPTS+=" --clip-grad 1.0"
|
|
||||||
OPTS+=" --loss-scale 1048576"
|
|
||||||
OPTS+=" --loss-scale-steps 32"
|
|
||||||
OPTS+=" --offload"
|
|
||||||
OPTS+=" --flash cuda"
|
|
||||||
# OPTS+=" --load-grad"
|
|
||||||
|
|
||||||
# --------------- 写文件路径 ---------------
|
|
||||||
## checkpoint
|
|
||||||
OPTS+=" --save /home/${USERNAME}/checkpoints/cpm9g/"
|
|
||||||
OPTS+=" --save-model /home/${USERNAME}/models/cpm9g/"
|
|
||||||
|
|
||||||
## logs,/local/logs 等价于 /data/logs(软链)
|
|
||||||
OPTS+=" --log-dir /home/${USERNAME}/logs/train/"
|
|
||||||
OPTS+=" --tensorboard /home/${USERNAME}/tensorboard/cpm9g/"`date +"%Y%m%d%H%M%S"`
|
|
||||||
|
|
||||||
# --------------- 读文件路径 ---------------
|
|
||||||
OPTS+=" --dataset config/datasets.json"
|
|
||||||
OPTS+=" --load ${CHECKPOINT}"
|
|
||||||
OPTS+=" --start-step 1"
|
|
||||||
|
|
||||||
# --------------- 透传参数 ---------------
|
|
||||||
OPTS+=" $@"
|
|
||||||
|
|
||||||
# --------------- 最终指令 ---------------
|
|
||||||
|
|
||||||
CMD="torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} pretrain_cpm9g.py ${OPTS}"
|
|
||||||
echo "${CMD}"
|
|
||||||
|
|
||||||
$CMD
|
|
|
@ -1,484 +0,0 @@
|
||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 The OpenBMB team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import inspect
|
|
||||||
import json
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from typing import Any
|
|
||||||
from typing import Dict
|
|
||||||
from typing import List
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import bmtrain as bmt
|
|
||||||
import torch
|
|
||||||
|
|
||||||
sys.path.insert(0, "/home/wangshuo1/code/9G-Train")
|
|
||||||
from cpm.arguments import get_args
|
|
||||||
from cpm.cpm9g.models import CPM9G
|
|
||||||
from cpm.cpm9g.models import CPM9GConfig
|
|
||||||
from cpm.cpm9g.tokenizers import CPM9GTokenizer
|
|
||||||
from cpm.cpm9g.training_tasks import FinetuneDataset
|
|
||||||
from cpm.utils import allgather_objects
|
|
||||||
from cpm.utils import logger
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
def get_tokenizer(args):
|
|
||||||
tokenizer = CPM9GTokenizer(path=args.vocab)
|
|
||||||
return tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def get_model(args):
|
|
||||||
config = CPM9GConfig.from_json_file(args.model_config)
|
|
||||||
if args.flash == "none":
|
|
||||||
config.use_flash_attn = False
|
|
||||||
else:
|
|
||||||
config.use_flash_attn = True
|
|
||||||
if args.flash == "1d":
|
|
||||||
config.flash_attn_mask_shape = "1d"
|
|
||||||
else:
|
|
||||||
config.flash_attn_mask_shape = "2d"
|
|
||||||
if args.flash == "triton":
|
|
||||||
config.flash_impl = "triton"
|
|
||||||
elif args.flash == "cuda":
|
|
||||||
config.flash_impl = "cuda"
|
|
||||||
model = CPM9G(config)
|
|
||||||
if args.load is not None:
|
|
||||||
bmt.init_parameters(model)
|
|
||||||
bmt.synchronize()
|
|
||||||
bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
|
|
||||||
bmt.load(model, args.load, strict=False)
|
|
||||||
|
|
||||||
model_inspect = bmt.inspect.inspect_model(model, "*")
|
|
||||||
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
|
||||||
else:
|
|
||||||
bmt.print_rank("args.load is None, start to initialize parameters")
|
|
||||||
bmt.init_parameters(model)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def get_optimizer(args, model):
|
|
||||||
if args.offload:
|
|
||||||
optimizer = bmt.optim.AdamOffloadOptimizer(
|
|
||||||
model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
optimizer = bmt.optim.AdamOptimizer(model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
|
||||||
if args.load is not None and args.load_grad:
|
|
||||||
start = time.time()
|
|
||||||
print(
|
|
||||||
sum(
|
|
||||||
[
|
|
||||||
1
|
|
||||||
if i.find(".opt") != -1 and i.find("-{}.rank".format(args.start_step % (args.save_iters * 5))) != -1
|
|
||||||
else 0
|
|
||||||
for i in os.listdir(args.save)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
sum(
|
|
||||||
[
|
|
||||||
1
|
|
||||||
if i.find(".opt") != -1 and i.find("-{}.rank".format(args.start_step % (args.save_iters * 5))) != -1
|
|
||||||
else 0
|
|
||||||
for i in os.listdir(args.save)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
== bmt.world_size()
|
|
||||||
):
|
|
||||||
file_name = os.path.join(
|
|
||||||
args.save,
|
|
||||||
args.save_name + "-{}.rank-{}.opt".format(args.start_step % (args.save_iters * 5), bmt.rank()),
|
|
||||||
)
|
|
||||||
print(file_name)
|
|
||||||
if os.path.exists(file_name):
|
|
||||||
print("start to load grad ckpt {}".format(file_name))
|
|
||||||
states = torch.load(file_name)
|
|
||||||
optimizer.load_state_dict(states)
|
|
||||||
logger.info("load grad in {:.2f}s".format(time.time() - start))
|
|
||||||
return optimizer
|
|
||||||
|
|
||||||
|
|
||||||
class Cosine(bmt.lr_scheduler.WarmupLRScheduler):
|
|
||||||
r"""
|
|
||||||
After a warmup period during which learning rate increases linearly between 0 and the start_lr,
|
|
||||||
The decay period performs :math:`\text{lr}=\text{start_lr}\times \dfrac{1+\cos \left( \pi \cdot \dfrac{\text{num_iter}-\text{warmup_iter}}{\text{end_iter}-\text{warmup_iter}}\right)}{2}`
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_lr_warmup(self, num_iter) -> float:
|
|
||||||
return self.start_lr * num_iter / self.warmup_iter
|
|
||||||
|
|
||||||
def get_lr_decay(self, num_iter) -> float:
|
|
||||||
progress = (num_iter - self.warmup_iter) / max(1, (self.end_iter - self.warmup_iter))
|
|
||||||
return max(self.start_lr * 0.1, self.start_lr * (0.1 + 0.45 * (1.0 + math.cos(progress * math.pi))))
|
|
||||||
|
|
||||||
|
|
||||||
def get_learning_rate_scheduler(args, optimizer):
|
|
||||||
if args.lr_decay_iters is None:
|
|
||||||
args.lr_decay_iters = args.train_iters
|
|
||||||
# lr_scheduler = bmt.lr_scheduler.Noam(
|
|
||||||
lr_scheduler = Cosine(
|
|
||||||
optimizer,
|
|
||||||
start_lr=args.lr,
|
|
||||||
warmup_iter=args.warmup_iters,
|
|
||||||
end_iter=args.lr_decay_iters,
|
|
||||||
num_iter=args.start_step,
|
|
||||||
)
|
|
||||||
return lr_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
def setup_model_and_optimizer(args):
|
|
||||||
start = time.time()
|
|
||||||
model = get_model(args)
|
|
||||||
logger.info("load model in {:.2f}s".format(time.time() - start))
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
tokenizer = get_tokenizer(args)
|
|
||||||
bmt.synchronize()
|
|
||||||
logger.info("load tokenizer in {:.2f}s".format(time.time() - start))
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
optimizer = get_optimizer(args, model)
|
|
||||||
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
|
|
||||||
bmt.synchronize()
|
|
||||||
logger.info("load lr_scheduler in {:.2f}s".format(time.time() - start))
|
|
||||||
|
|
||||||
return tokenizer, model, optimizer, lr_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
def initialize():
|
|
||||||
args = get_args(finetune=True)
|
|
||||||
|
|
||||||
# hack
|
|
||||||
if "checkpointing" in inspect.signature(bmt.init_distributed).parameters:
|
|
||||||
bmt.init_distributed(checkpointing=False, seed=args.seed)
|
|
||||||
else:
|
|
||||||
bmt.init_distributed(seed=args.seed)
|
|
||||||
if args.save is not None:
|
|
||||||
os.makedirs(args.save, exist_ok=True)
|
|
||||||
# if args.load is not None:
|
|
||||||
# if args.start_step == 0:
|
|
||||||
# args.start_step = (int)(re.search("(\d+).pt", args.load)[1])
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
def see_memory(detail=False):
|
|
||||||
if detail:
|
|
||||||
res = torch.cuda.memory_summary()
|
|
||||||
else:
|
|
||||||
res = (
|
|
||||||
round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024), 2),
|
|
||||||
round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2),
|
|
||||||
)
|
|
||||||
torch.cuda.reset_peak_memory_stats()
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def add_mem_time(info, mem_usage, tim_usage):
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
mem_usage[info] = see_memory()
|
|
||||||
tim_usage[info] = time.time()
|
|
||||||
return mem_usage, tim_usage
|
|
||||||
|
|
||||||
|
|
||||||
class LossSpikeDetector:
|
|
||||||
def __init__(self, log_path: str) -> None:
|
|
||||||
self._last_loss: Dict[str, float] = {}
|
|
||||||
self._last_data: List[Any] = [None]
|
|
||||||
self._log_path = log_path
|
|
||||||
|
|
||||||
def update_data(self, data: Any):
|
|
||||||
self._last_data.append(data)
|
|
||||||
if len(self._last_data) > 2:
|
|
||||||
self._last_data = self._last_data[-2:]
|
|
||||||
|
|
||||||
def update_loss(self, iteration: int, loss_map: Dict[str, float]):
|
|
||||||
loss_spike_result = []
|
|
||||||
for task, loss in loss_map.items():
|
|
||||||
if task in self._last_loss:
|
|
||||||
if loss > self._last_loss[task] * 3:
|
|
||||||
# loss spike!
|
|
||||||
loss_spike_result.append(
|
|
||||||
{
|
|
||||||
"prev": self._last_loss[task],
|
|
||||||
"curr": loss,
|
|
||||||
"task": task,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
self._last_loss[task] = float(loss)
|
|
||||||
if len(loss_spike_result) > 0:
|
|
||||||
self._write_log(iteration, self._last_data[-1], loss_spike_result)
|
|
||||||
|
|
||||||
def _write_log(self, iteration: int, data: Any, result: List[Dict[str, Any]]):
|
|
||||||
return
|
|
||||||
with open(self._log_path, "a", encoding="utf-8") as fp:
|
|
||||||
fp.write("=" * 20)
|
|
||||||
fp.write("\nloss spike at {}\n".format(iteration))
|
|
||||||
fp.write("{}\n".format(json.dumps(result, indent=4, ensure_ascii=False)))
|
|
||||||
fp.write("data: \n")
|
|
||||||
for d in data:
|
|
||||||
fp.write("{}\n".format(json.dumps(d, indent=4, ensure_ascii=False)))
|
|
||||||
fp.write("\n\n")
|
|
||||||
|
|
||||||
|
|
||||||
def finetune(
|
|
||||||
args,
|
|
||||||
bin_file: str,
|
|
||||||
tokenizer: CPM9GTokenizer,
|
|
||||||
model: CPM9G,
|
|
||||||
optimizer,
|
|
||||||
lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
|
|
||||||
):
|
|
||||||
average_time = bmt.utils.AverageRecorder()
|
|
||||||
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100)
|
|
||||||
optim_manager = bmt.optim.OptimManager(loss_scale=args.loss_scale, loss_scale_steps=args.loss_scale_steps, loss_scale_factor=2, max_loss_scale=args.max_loss_scale, min_loss_scale=args.min_loss_scale,)
|
|
||||||
optim_manager.add_optimizer(optimizer, lr_scheduler)
|
|
||||||
|
|
||||||
if args.tensorboard is not None and bmt.rank() == 0:
|
|
||||||
import distutils.version # noqa: F401
|
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
if not os.path.exists(args.tensorboard):
|
|
||||||
os.makedirs(args.tensorboard)
|
|
||||||
writer = SummaryWriter(log_dir=args.tensorboard)
|
|
||||||
|
|
||||||
global_token_pass = 0.0
|
|
||||||
global_world_size = bmt.world_size()
|
|
||||||
|
|
||||||
for epoch in range(args.epoch):
|
|
||||||
epoch = epoch + 1
|
|
||||||
last_data = None
|
|
||||||
dataloader = FinetuneDataset(
|
|
||||||
bin_file, args.batch_size, args.max_length, tokenizer, unpad=(args.flash == "cuda"), task_name="task", drop_last=True
|
|
||||||
)
|
|
||||||
optim_manager.zero_grad()
|
|
||||||
for iteration, data in enumerate(dataloader):
|
|
||||||
iteration = iteration + 1
|
|
||||||
skip_this_batch = False
|
|
||||||
if data is None:
|
|
||||||
if last_data is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Dataset is too small, please use a smaller batch size or sequence length!"
|
|
||||||
)
|
|
||||||
data = last_data # use last data
|
|
||||||
skip_this_batch = True
|
|
||||||
else:
|
|
||||||
last_data = data
|
|
||||||
|
|
||||||
assert data["inputs"].shape[0] == args.batch_size
|
|
||||||
input_ids = torch.from_numpy(data["inputs"]).cuda().to(torch.int32)
|
|
||||||
input_length = torch.from_numpy(data["length"]).cuda().to(torch.int32)
|
|
||||||
targets = torch.from_numpy(data["target"]).cuda().long()
|
|
||||||
# bmt.print_rank(input_ids[0].tolist())
|
|
||||||
# bmt.print_rank(targets[0].tolist())
|
|
||||||
# bmt.print_rank(data["spans"].tolist())
|
|
||||||
# bmt.print_rank(tokenizer.decode(input_ids[0]))
|
|
||||||
# bmt.print_rank(tokenizer.path)
|
|
||||||
# bmt.synchronize()
|
|
||||||
# exit()
|
|
||||||
task_ids = torch.from_numpy(data["task_ids"]).cuda().to(torch.int32)
|
|
||||||
task_names = data["task_names"]
|
|
||||||
if args.flash == "cuda":
|
|
||||||
cu_seqlens = torch.from_numpy(data["cu_seqlens"]).cuda().to(torch.int32)
|
|
||||||
max_seqlen = data["max_seqlen"]
|
|
||||||
position_ids = torch.from_numpy(data["position_ids"]).cuda().to(torch.int32)
|
|
||||||
else:
|
|
||||||
input_context = torch.zeros_like(input_ids).cuda().bool()
|
|
||||||
input_span = torch.from_numpy(data["spans"]).cuda().to(torch.int32)
|
|
||||||
|
|
||||||
# ===========
|
|
||||||
# optim_manager.zero_grad()
|
|
||||||
# torch.cuda.empty_cache()
|
|
||||||
mem_usage = {}
|
|
||||||
tim_usage = {}
|
|
||||||
mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)
|
|
||||||
|
|
||||||
# ===========
|
|
||||||
if args.flash == "cuda":
|
|
||||||
logits, _ = model(
|
|
||||||
input_ids,
|
|
||||||
cu_seqlens=cu_seqlens,
|
|
||||||
max_seqlen=max_seqlen,
|
|
||||||
position_ids=position_ids,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits, _ = model(
|
|
||||||
input_ids,
|
|
||||||
input_length,
|
|
||||||
input_context,
|
|
||||||
input_span,
|
|
||||||
)
|
|
||||||
mem_usage, tim_usage = add_mem_time("forward_1", mem_usage, tim_usage)
|
|
||||||
loss = loss_func(logits.view(-1, logits.size(-1)), targets.view(-1))
|
|
||||||
if skip_this_batch:
|
|
||||||
loss = loss * 0
|
|
||||||
mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)
|
|
||||||
|
|
||||||
# ===========
|
|
||||||
optim_manager.backward(loss)
|
|
||||||
mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)
|
|
||||||
|
|
||||||
# ===========
|
|
||||||
if iteration % args.gradient_accumulation_steps == 0:
|
|
||||||
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
|
|
||||||
optim_manager.step()
|
|
||||||
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
|
|
||||||
optim_manager.zero_grad()
|
|
||||||
else:
|
|
||||||
grad_norm = None
|
|
||||||
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
|
|
||||||
|
|
||||||
# ==========
|
|
||||||
iter_time = tim_usage["optim"] - tim_usage["init"]
|
|
||||||
average_time.record(iter_time)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
task_num = len(task_names)
|
|
||||||
targets_tmp = targets.expand(task_num, -1, -1)
|
|
||||||
task = torch.arange(task_num, dtype=torch.long, device="cuda")[:, None, None]
|
|
||||||
targets_tmp = torch.where(
|
|
||||||
task_ids == task,
|
|
||||||
targets_tmp,
|
|
||||||
torch.scalar_tensor(-100, dtype=torch.long, device="cuda"),
|
|
||||||
)
|
|
||||||
|
|
||||||
task_loss_map: Dict[str, float] = {}
|
|
||||||
if not skip_this_batch:
|
|
||||||
for i in range(task_num):
|
|
||||||
task_loss = loss_func(logits.view(-1, logits.size(-1)), targets_tmp[i, :].view(-1))
|
|
||||||
task_loss_map[task_names[i]] = task_loss.item()
|
|
||||||
gatherd_task_loss_map: List[Dict[str, float]] = allgather_objects(task_loss_map)
|
|
||||||
|
|
||||||
global_task_loss_map: Dict[str, Union[List[float], float]] = {}
|
|
||||||
for local_task_loss_map in gatherd_task_loss_map:
|
|
||||||
for task_name, task_loss in local_task_loss_map.items():
|
|
||||||
if task_name not in global_task_loss_map:
|
|
||||||
global_task_loss_map[task_name] = []
|
|
||||||
global_task_loss_map[task_name].append(task_loss)
|
|
||||||
|
|
||||||
task_loss_map = {}
|
|
||||||
for task_name in sorted(list(global_task_loss_map.keys())):
|
|
||||||
avg_loss = sum(global_task_loss_map[task_name]) / len(global_task_loss_map[task_name])
|
|
||||||
task_loss_map[task_name] = avg_loss
|
|
||||||
|
|
||||||
local_total_rate = torch.Tensor([input_length.float().mean() / args.max_length]).cuda()
|
|
||||||
local_total_rate = bmt.sum_loss(local_total_rate).item()
|
|
||||||
global_token_pass += global_world_size * local_total_rate * args.max_length * args.batch_size
|
|
||||||
avg_time = average_time.value
|
|
||||||
|
|
||||||
train_info = {
|
|
||||||
"time": tim_usage["init"],
|
|
||||||
"epoch": epoch,
|
|
||||||
"iteration": iteration,
|
|
||||||
"loss": task_loss_map[args.task_name],
|
|
||||||
"lr": lr_scheduler.current_lr,
|
|
||||||
"lr_scale": int(optim_manager.loss_scale),
|
|
||||||
"time_usage": tim_usage,
|
|
||||||
"mem_usage": mem_usage,
|
|
||||||
"avg_time": avg_time,
|
|
||||||
"token_max": local_total_rate,
|
|
||||||
"token_pass": global_token_pass,
|
|
||||||
"throughout": args.max_length * args.batch_size * local_total_rate / avg_time,
|
|
||||||
"grad_norm": 0 if grad_norm == None else grad_norm.item(),
|
|
||||||
"mask_max": ((targets >= 0).sum(-1).float().mean() / args.max_length).item(),
|
|
||||||
"num_gpus": global_world_size,
|
|
||||||
"task_loss": task_loss_map,
|
|
||||||
}
|
|
||||||
# bmt.print_rank(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
|
|
||||||
bmt.print_rank(
|
|
||||||
(
|
|
||||||
"| Epoch: {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f} | lr: {:.6e} | scale: {:10.0f} | time: {:.1f} |"
|
|
||||||
+ " token/max: {:.3f} | mask/max: {:.3f} | grad_norm: {:.3f}"
|
|
||||||
).format(
|
|
||||||
epoch,
|
|
||||||
iteration,
|
|
||||||
args.train_iters,
|
|
||||||
task_loss_map[args.task_name],
|
|
||||||
lr_scheduler.current_lr,
|
|
||||||
int(optim_manager.loss_scale),
|
|
||||||
avg_time,
|
|
||||||
input_length.float().mean() / args.max_length,
|
|
||||||
(targets >= 0).sum(-1).float().mean() / args.max_length,
|
|
||||||
0 if grad_norm == None else grad_norm,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
bmt.print_rank(
|
|
||||||
"| "
|
|
||||||
+ " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
|
|
||||||
+ " |"
|
|
||||||
)
|
|
||||||
if iteration % args.inspect_iters == 0:
|
|
||||||
model_inspect = bmt.inspect.inspect_model(model, "*")
|
|
||||||
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
|
||||||
train_info["model_inspect"] = model_inspect
|
|
||||||
|
|
||||||
# save_folder_name = f"{args.save}{epoch}{iteration}"
|
|
||||||
# model_fname = os.path.join(save_folder_name, f"{args.save_name}-iter-{iteration}.pt")
|
|
||||||
# os.makedirs(os.path.dirname(model_fname), exist_ok=True)
|
|
||||||
# if bmt.rank() == 0:
|
|
||||||
# shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
|
|
||||||
# shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))
|
|
||||||
|
|
||||||
if args.tensorboard is not None and bmt.rank() == 0:
|
|
||||||
writer.add_scalar(f"Loss/train/{epoch}", task_loss_map[args.task_name], iteration)
|
|
||||||
writer.add_scalar(f"Optimizer/lr/{epoch}", lr_scheduler.current_lr, iteration)
|
|
||||||
writer.add_scalar(f"Optimizer/scale/{epoch}", optim_manager.loss_scale, iteration)
|
|
||||||
writer.add_scalar(f"Optimizer/grad_norm/{epoch}", 0 if grad_norm == None else grad_norm.item(), iteration)
|
|
||||||
|
|
||||||
# if iteration % 10 == 0:
|
|
||||||
# save_folder_name = f"{args.save}{epoch}"
|
|
||||||
# model_fname = os.path.join(save_folder_name, f"{args.save_name}-epoch-{epoch}.pt")
|
|
||||||
# os.makedirs(os.path.dirname(model_fname), exist_ok=True)
|
|
||||||
# bmt.save(model, model_fname)
|
|
||||||
# if bmt.rank() == 0:
|
|
||||||
# shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
|
|
||||||
# shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))
|
|
||||||
|
|
||||||
save_folder_name = f"{args.save}-epoch-{epoch}"
|
|
||||||
model_fname = os.path.join(save_folder_name, f"{args.save_name}-epoch-{epoch}.pt")
|
|
||||||
os.makedirs(os.path.dirname(model_fname), exist_ok=True)
|
|
||||||
bmt.save(model, model_fname)
|
|
||||||
if bmt.rank() == 0:
|
|
||||||
shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
|
|
||||||
shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
args = initialize()
|
|
||||||
# To Be Specified
|
|
||||||
bin_file = args.dataset
|
|
||||||
bmt.print_rank(f"dataset: {bin_file}")
|
|
||||||
|
|
||||||
tokenizer = get_tokenizer(args)
|
|
||||||
dataloader = FinetuneDataset(
|
|
||||||
bin_file, args.batch_size, args.max_length, tokenizer, unpad=(args.flash == "cuda"), task_name="task", drop_last=True
|
|
||||||
)
|
|
||||||
bmt.print_rank(f"#batch: {len(dataloader)}")
|
|
||||||
total_steps = len(dataloader) * args.epoch
|
|
||||||
|
|
||||||
setattr(args, 'train_iters', int(total_steps))
|
|
||||||
setattr(args, 'warmup_iters', max(int(total_steps*0.02), args.warmup_iters))
|
|
||||||
bmt.print_rank(json.dumps(vars(args), indent=2, sort_keys=True))
|
|
||||||
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
|
|
||||||
|
|
||||||
bmt.print_rank("finish loading")
|
|
||||||
finetune(args, bin_file, tokenizer, model, optimizer, lr_scheduler)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
|
@ -1,55 +0,0 @@
|
||||||
#! /bin/bash
|
|
||||||
#SBATCH --partition=gpu3-1
|
|
||||||
#SBATCH --nodes=1
|
|
||||||
#SBATCH --ntasks-per-node=8
|
|
||||||
#SBATCH --gres=gpu:8
|
|
||||||
#SBATCH --cpus-per-task=8
|
|
||||||
|
|
||||||
export MASTER_ADDR=g3002
|
|
||||||
export MASTER_PORT=12345
|
|
||||||
|
|
||||||
CPM_PATH="/data/groups/QY_LLM_Core/projects/202311-release/Models/11B-Chat/9G-Train"
|
|
||||||
CONFIG_NAME="${CPM_PATH}/apps/cpm9g/config/11b"
|
|
||||||
EXP_PATH=./exp
|
|
||||||
mkdir -p $EXP_PATH
|
|
||||||
MODEL_NAME="cpm9g-11b-sft"
|
|
||||||
|
|
||||||
OPTS=""
|
|
||||||
OPTS+=" --model-config ${CONFIG_NAME}/config.json"
|
|
||||||
OPTS+=" --vocab ${CONFIG_NAME}/vocab.txt"
|
|
||||||
|
|
||||||
OPTS+=" --train-iters 10000"
|
|
||||||
OPTS+=" --inspect-iters 200"
|
|
||||||
OPTS+=" --warmup-iters 500"
|
|
||||||
|
|
||||||
OPTS+=" --lr-decay-style cosine"
|
|
||||||
OPTS+=" --weight-decay 0.1"
|
|
||||||
OPTS+=" --clip-grad 1.0"
|
|
||||||
OPTS+=" --loss-scale 1048576"
|
|
||||||
OPTS+=" --max-loss-scale 33554432"
|
|
||||||
OPTS+=" --min-loss-scale 1"
|
|
||||||
OPTS+=" --loss-scale-steps 32"
|
|
||||||
|
|
||||||
OPTS+=" --offload"
|
|
||||||
OPTS+=" --batch-size 4"
|
|
||||||
OPTS+=" --max-length 4096"
|
|
||||||
OPTS+=" --lr 2e-5"
|
|
||||||
OPTS+=" --start-step 0"
|
|
||||||
OPTS+=" --epoch 8"
|
|
||||||
OPTS+=" --load /data/groups/QY_LLM_Core/models/20231010/11b-base/11b.pt"
|
|
||||||
OPTS+=" --dataset /data/groups/QY_LLM_Core/datasets/sft/20231025/merge_qy_sft_bin"
|
|
||||||
# TODO 这些 /data 在启元机器上需要改成 /home 下的路径
|
|
||||||
OPTS+=" --save ${EXP_PATH}/checkpoints"
|
|
||||||
OPTS+=" --save-name ${MODEL_NAME}"
|
|
||||||
# OPTS+=" --tensorboard /data/logs/tensorboard/${MODEL_NAME}/${CUR_DATE}/"
|
|
||||||
# OPTS+=" --flash triton"
|
|
||||||
# OPTS+=" --flash cuda"
|
|
||||||
# OPTS+=" --load-grad"
|
|
||||||
|
|
||||||
OPTS+=" $@"
|
|
||||||
|
|
||||||
|
|
||||||
CMD="torchrun --nnodes=4 --nproc_per_node=8 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} ${CPM_PATH}/apps/cpm9g/sft_cpm9g.py ${OPTS}"
|
|
||||||
|
|
||||||
echo "${CMD}"
|
|
||||||
$CMD
|
|
|
@ -1,515 +0,0 @@
|
||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 The OpenBMB team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import inspect
|
|
||||||
import json
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from typing import Any
|
|
||||||
from typing import Dict
|
|
||||||
from typing import List
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import bmtrain as bmt
|
|
||||||
import torch
|
|
||||||
|
|
||||||
sys.path.insert(0, "/home/wangshuo1/code/9G-Train")
|
|
||||||
from cpm.arguments import get_args
|
|
||||||
from cpm.cpm9g.models import CPM9G
|
|
||||||
from cpm.cpm9g.models import CPM9GConfig
|
|
||||||
from cpm.cpm9g.tokenizers import CPM9GTokenizer
|
|
||||||
from cpm.cpm9g.training_tasks import FinetuneDataset
|
|
||||||
from cpm.utils import allgather_objects
|
|
||||||
from cpm.utils import logger
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
import opendelta as od
|
|
||||||
from opendelta import LoraModel, AdapterModel, CompacterModel, LowRankAdapterModel, BitFitModel, ParallelAdapterModel
|
|
||||||
from opendelta.utils.inspect import inspect_optimizer_statistics
|
|
||||||
from bigmodelvis import Visualization
|
|
||||||
|
|
||||||
|
|
||||||
def get_tokenizer(args):
|
|
||||||
tokenizer = CPM9GTokenizer(path=args.vocab)
|
|
||||||
return tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def get_model(args):
|
|
||||||
config = CPM9GConfig.from_json_file(args.model_config)
|
|
||||||
if args.flash == "none":
|
|
||||||
config.use_flash_attn = False
|
|
||||||
else:
|
|
||||||
config.use_flash_attn = True
|
|
||||||
if args.flash == "1d":
|
|
||||||
config.flash_attn_mask_shape = "1d"
|
|
||||||
else:
|
|
||||||
config.flash_attn_mask_shape = "2d"
|
|
||||||
if args.flash == "triton":
|
|
||||||
config.flash_impl = "triton"
|
|
||||||
elif args.flash == "cuda":
|
|
||||||
config.flash_impl = "cuda"
|
|
||||||
model = CPM9G(config)
|
|
||||||
if args.load is not None:
|
|
||||||
bmt.init_parameters(model)
|
|
||||||
bmt.synchronize()
|
|
||||||
bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
|
|
||||||
bmt.load(model, args.load, strict=False)
|
|
||||||
|
|
||||||
model_inspect = bmt.inspect.inspect_model(model, "*")
|
|
||||||
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
|
||||||
else:
|
|
||||||
bmt.print_rank("args.load is None, start to initialize parameters")
|
|
||||||
bmt.init_parameters(model)
|
|
||||||
|
|
||||||
if args.delta_type != None:
|
|
||||||
from bigmodelvis import Visualization #0813
|
|
||||||
if bmt.rank() == 0:
|
|
||||||
Visualization(model).structure_graph()
|
|
||||||
print("finetuned layers: ")
|
|
||||||
print(args.lora_layer)
|
|
||||||
|
|
||||||
if args.delta_type == "lora":
|
|
||||||
delta_model = LoraModel(backbone_model=model, modified_modules=args.lora_layer, backend='bmt', lora_r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout)
|
|
||||||
|
|
||||||
if bmt.rank() == 0:
|
|
||||||
print("Before freeze: ")
|
|
||||||
delta_model.log()
|
|
||||||
|
|
||||||
delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
|
|
||||||
|
|
||||||
if bmt.rank() == 0:
|
|
||||||
print("After freeze: ")
|
|
||||||
delta_model.log()
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def get_optimizer(args, model):
|
|
||||||
if args.offload:
|
|
||||||
optimizer = bmt.optim.AdamOffloadOptimizer(
|
|
||||||
model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
optimizer = bmt.optim.AdamOptimizer(model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
|
||||||
if args.load is not None and args.load_grad:
|
|
||||||
start = time.time()
|
|
||||||
print(
|
|
||||||
sum(
|
|
||||||
[
|
|
||||||
1
|
|
||||||
if i.find(".opt") != -1 and i.find("-{}.rank".format(args.start_step % (args.save_iters * 5))) != -1
|
|
||||||
else 0
|
|
||||||
for i in os.listdir(args.save)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
sum(
|
|
||||||
[
|
|
||||||
1
|
|
||||||
if i.find(".opt") != -1 and i.find("-{}.rank".format(args.start_step % (args.save_iters * 5))) != -1
|
|
||||||
else 0
|
|
||||||
for i in os.listdir(args.save)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
== bmt.world_size()
|
|
||||||
):
|
|
||||||
file_name = os.path.join(
|
|
||||||
args.save,
|
|
||||||
args.save_name + "-{}.rank-{}.opt".format(args.start_step % (args.save_iters * 5), bmt.rank()),
|
|
||||||
)
|
|
||||||
print(file_name)
|
|
||||||
if os.path.exists(file_name):
|
|
||||||
print("start to load grad ckpt {}".format(file_name))
|
|
||||||
states = torch.load(file_name)
|
|
||||||
optimizer.load_state_dict(states)
|
|
||||||
logger.info("load grad in {:.2f}s".format(time.time() - start))
|
|
||||||
return optimizer
|
|
||||||
|
|
||||||
|
|
||||||
class Cosine(bmt.lr_scheduler.WarmupLRScheduler):
|
|
||||||
r"""
|
|
||||||
After a warmup period during which learning rate increases linearly between 0 and the start_lr,
|
|
||||||
The decay period performs :math:`\text{lr}=\text{start_lr}\times \dfrac{1+\cos \left( \pi \cdot \dfrac{\text{num_iter}-\text{warmup_iter}}{\text{end_iter}-\text{warmup_iter}}\right)}{2}`
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_lr_warmup(self, num_iter) -> float:
|
|
||||||
return self.start_lr * num_iter / self.warmup_iter
|
|
||||||
|
|
||||||
def get_lr_decay(self, num_iter) -> float:
|
|
||||||
progress = (num_iter - self.warmup_iter) / max(1, (self.end_iter - self.warmup_iter))
|
|
||||||
return max(self.start_lr * 0.1, self.start_lr * (0.1 + 0.45 * (1.0 + math.cos(progress * math.pi))))
|
|
||||||
|
|
||||||
|
|
||||||
def get_learning_rate_scheduler(args, optimizer):
|
|
||||||
if args.lr_decay_iters is None:
|
|
||||||
args.lr_decay_iters = args.train_iters
|
|
||||||
# lr_scheduler = bmt.lr_scheduler.Noam(
|
|
||||||
lr_scheduler = Cosine(
|
|
||||||
optimizer,
|
|
||||||
start_lr=args.lr,
|
|
||||||
warmup_iter=args.warmup_iters,
|
|
||||||
end_iter=args.lr_decay_iters,
|
|
||||||
num_iter=args.start_step,
|
|
||||||
)
|
|
||||||
return lr_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
def setup_model_and_optimizer(args):
|
|
||||||
start = time.time()
|
|
||||||
model = get_model(args)
|
|
||||||
logger.info("load model in {:.2f}s".format(time.time() - start))
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
tokenizer = get_tokenizer(args)
|
|
||||||
bmt.synchronize()
|
|
||||||
logger.info("load tokenizer in {:.2f}s".format(time.time() - start))
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
optimizer = get_optimizer(args, model)
|
|
||||||
|
|
||||||
if args.delta_type != None:
|
|
||||||
inspect_optimizer_statistics(optimizer)
|
|
||||||
|
|
||||||
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
|
|
||||||
bmt.synchronize()
|
|
||||||
logger.info("load lr_scheduler in {:.2f}s".format(time.time() - start))
|
|
||||||
|
|
||||||
return tokenizer, model, optimizer, lr_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
def initialize():
|
|
||||||
args = get_args(finetune=True)
|
|
||||||
|
|
||||||
# hack
|
|
||||||
if "checkpointing" in inspect.signature(bmt.init_distributed).parameters:
|
|
||||||
bmt.init_distributed(checkpointing=False, seed=args.seed)
|
|
||||||
else:
|
|
||||||
bmt.init_distributed(seed=args.seed)
|
|
||||||
if args.save is not None:
|
|
||||||
os.makedirs(args.save, exist_ok=True)
|
|
||||||
# if args.load is not None:
|
|
||||||
# if args.start_step == 0:
|
|
||||||
# args.start_step = (int)(re.search("(\d+).pt", args.load)[1])
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
def see_memory(detail=False):
|
|
||||||
if detail:
|
|
||||||
res = torch.cuda.memory_summary()
|
|
||||||
else:
|
|
||||||
res = (
|
|
||||||
round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024), 2),
|
|
||||||
round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2),
|
|
||||||
)
|
|
||||||
torch.cuda.reset_peak_memory_stats()
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def add_mem_time(info, mem_usage, tim_usage):
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
mem_usage[info] = see_memory()
|
|
||||||
tim_usage[info] = time.time()
|
|
||||||
return mem_usage, tim_usage
|
|
||||||
|
|
||||||
|
|
||||||
class LossSpikeDetector:
|
|
||||||
def __init__(self, log_path: str) -> None:
|
|
||||||
self._last_loss: Dict[str, float] = {}
|
|
||||||
self._last_data: List[Any] = [None]
|
|
||||||
self._log_path = log_path
|
|
||||||
|
|
||||||
def update_data(self, data: Any):
|
|
||||||
self._last_data.append(data)
|
|
||||||
if len(self._last_data) > 2:
|
|
||||||
self._last_data = self._last_data[-2:]
|
|
||||||
|
|
||||||
def update_loss(self, iteration: int, loss_map: Dict[str, float]):
|
|
||||||
loss_spike_result = []
|
|
||||||
for task, loss in loss_map.items():
|
|
||||||
if task in self._last_loss:
|
|
||||||
if loss > self._last_loss[task] * 3:
|
|
||||||
# loss spike!
|
|
||||||
loss_spike_result.append(
|
|
||||||
{
|
|
||||||
"prev": self._last_loss[task],
|
|
||||||
"curr": loss,
|
|
||||||
"task": task,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
self._last_loss[task] = float(loss)
|
|
||||||
if len(loss_spike_result) > 0:
|
|
||||||
self._write_log(iteration, self._last_data[-1], loss_spike_result)
|
|
||||||
|
|
||||||
def _write_log(self, iteration: int, data: Any, result: List[Dict[str, Any]]):
|
|
||||||
return
|
|
||||||
with open(self._log_path, "a", encoding="utf-8") as fp:
|
|
||||||
fp.write("=" * 20)
|
|
||||||
fp.write("\nloss spike at {}\n".format(iteration))
|
|
||||||
fp.write("{}\n".format(json.dumps(result, indent=4, ensure_ascii=False)))
|
|
||||||
fp.write("data: \n")
|
|
||||||
for d in data:
|
|
||||||
fp.write("{}\n".format(json.dumps(d, indent=4, ensure_ascii=False)))
|
|
||||||
fp.write("\n\n")
|
|
||||||
|
|
||||||
|
|
||||||
def finetune(
|
|
||||||
args,
|
|
||||||
bin_file: str,
|
|
||||||
tokenizer: CPM9GTokenizer,
|
|
||||||
model: CPM9G,
|
|
||||||
optimizer,
|
|
||||||
lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
|
|
||||||
):
|
|
||||||
average_time = bmt.utils.AverageRecorder()
|
|
||||||
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100)
|
|
||||||
optim_manager = bmt.optim.OptimManager(loss_scale=args.loss_scale, loss_scale_steps=args.loss_scale_steps, loss_scale_factor=2, max_loss_scale=args.max_loss_scale, min_loss_scale=args.min_loss_scale,)
|
|
||||||
optim_manager.add_optimizer(optimizer, lr_scheduler)
|
|
||||||
|
|
||||||
if args.tensorboard is not None and bmt.rank() == 0:
|
|
||||||
import distutils.version # noqa: F401
|
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
if not os.path.exists(args.tensorboard):
|
|
||||||
os.makedirs(args.tensorboard)
|
|
||||||
writer = SummaryWriter(log_dir=args.tensorboard)
|
|
||||||
|
|
||||||
global_token_pass = 0.0
|
|
||||||
global_world_size = bmt.world_size()
|
|
||||||
|
|
||||||
for epoch in range(args.epoch):
|
|
||||||
epoch = epoch + 1
|
|
||||||
last_data = None
|
|
||||||
dataloader = FinetuneDataset(
|
|
||||||
bin_file, args.batch_size, args.max_length, tokenizer, unpad=(args.flash == "cuda"), task_name="task", drop_last=True
|
|
||||||
)
|
|
||||||
for iteration, data in enumerate(dataloader):
|
|
||||||
iteration = iteration + 1
|
|
||||||
skip_this_batch = False
|
|
||||||
if data is None:
|
|
||||||
if last_data is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Dataset is too small, please use a smaller batch size or sequence length!"
|
|
||||||
)
|
|
||||||
data = last_data # use last data
|
|
||||||
skip_this_batch = True
|
|
||||||
else:
|
|
||||||
last_data = data
|
|
||||||
|
|
||||||
assert data["inputs"].shape[0] == args.batch_size
|
|
||||||
input_ids = torch.from_numpy(data["inputs"]).cuda().to(torch.int32)
|
|
||||||
input_length = torch.from_numpy(data["length"]).cuda().to(torch.int32)
|
|
||||||
targets = torch.from_numpy(data["target"]).cuda().long()
|
|
||||||
# bmt.print_rank(input_ids[0].tolist())
|
|
||||||
# bmt.print_rank(targets[0].tolist())
|
|
||||||
# bmt.print_rank(data["spans"].tolist())
|
|
||||||
# bmt.print_rank(tokenizer.decode(input_ids[0]))
|
|
||||||
# bmt.print_rank(tokenizer.path)
|
|
||||||
# bmt.synchronize()
|
|
||||||
# exit()
|
|
||||||
task_ids = torch.from_numpy(data["task_ids"]).cuda().to(torch.int32)
|
|
||||||
task_names = data["task_names"]
|
|
||||||
if args.flash == "cuda":
|
|
||||||
cu_seqlens = torch.from_numpy(data["cu_seqlens"]).cuda().to(torch.int32)
|
|
||||||
max_seqlen = data["max_seqlen"]
|
|
||||||
position_ids = torch.from_numpy(data["position_ids"]).cuda().to(torch.int32)
|
|
||||||
else:
|
|
||||||
input_context = torch.zeros_like(input_ids).cuda().bool()
|
|
||||||
input_span = torch.from_numpy(data["spans"]).cuda().to(torch.int32)
|
|
||||||
|
|
||||||
# ===========
|
|
||||||
optim_manager.zero_grad()
|
|
||||||
# torch.cuda.empty_cache()
|
|
||||||
mem_usage = {}
|
|
||||||
tim_usage = {}
|
|
||||||
mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)
|
|
||||||
|
|
||||||
# ===========
|
|
||||||
if args.flash == "cuda":
|
|
||||||
logits, _ = model(
|
|
||||||
input_ids,
|
|
||||||
cu_seqlens=cu_seqlens,
|
|
||||||
max_seqlen=max_seqlen,
|
|
||||||
position_ids=position_ids,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits, _ = model(
|
|
||||||
input_ids,
|
|
||||||
input_length,
|
|
||||||
input_context,
|
|
||||||
input_span,
|
|
||||||
)
|
|
||||||
mem_usage, tim_usage = add_mem_time("forward_1", mem_usage, tim_usage)
|
|
||||||
loss = loss_func(logits.view(-1, logits.size(-1)), targets.view(-1))
|
|
||||||
if skip_this_batch:
|
|
||||||
loss = loss * 0
|
|
||||||
mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)
|
|
||||||
|
|
||||||
# ===========
|
|
||||||
optim_manager.backward(loss)
|
|
||||||
mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)
|
|
||||||
|
|
||||||
# ===========
|
|
||||||
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
|
|
||||||
optim_manager.step()
|
|
||||||
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
|
|
||||||
|
|
||||||
# ==========
|
|
||||||
iter_time = tim_usage["optim"] - tim_usage["init"]
|
|
||||||
average_time.record(iter_time)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
task_num = len(task_names)
|
|
||||||
targets_tmp = targets.expand(task_num, -1, -1)
|
|
||||||
task = torch.arange(task_num, dtype=torch.long, device="cuda")[:, None, None]
|
|
||||||
targets_tmp = torch.where(
|
|
||||||
task_ids == task,
|
|
||||||
targets_tmp,
|
|
||||||
torch.scalar_tensor(-100, dtype=torch.long, device="cuda"),
|
|
||||||
)
|
|
||||||
|
|
||||||
task_loss_map: Dict[str, float] = {}
|
|
||||||
if not skip_this_batch:
|
|
||||||
for i in range(task_num):
|
|
||||||
task_loss = loss_func(logits.view(-1, logits.size(-1)), targets_tmp[i, :].view(-1))
|
|
||||||
task_loss_map[task_names[i]] = task_loss.item()
|
|
||||||
gatherd_task_loss_map: List[Dict[str, float]] = allgather_objects(task_loss_map)
|
|
||||||
|
|
||||||
global_task_loss_map: Dict[str, Union[List[float], float]] = {}
|
|
||||||
for local_task_loss_map in gatherd_task_loss_map:
|
|
||||||
for task_name, task_loss in local_task_loss_map.items():
|
|
||||||
if task_name not in global_task_loss_map:
|
|
||||||
global_task_loss_map[task_name] = []
|
|
||||||
global_task_loss_map[task_name].append(task_loss)
|
|
||||||
|
|
||||||
task_loss_map = {}
|
|
||||||
for task_name in sorted(list(global_task_loss_map.keys())):
|
|
||||||
avg_loss = sum(global_task_loss_map[task_name]) / len(global_task_loss_map[task_name])
|
|
||||||
task_loss_map[task_name] = avg_loss
|
|
||||||
|
|
||||||
local_total_rate = torch.Tensor([input_length.float().mean() / args.max_length]).cuda()
|
|
||||||
local_total_rate = bmt.sum_loss(local_total_rate).item()
|
|
||||||
global_token_pass += global_world_size * local_total_rate * args.max_length * args.batch_size
|
|
||||||
avg_time = average_time.value
|
|
||||||
|
|
||||||
train_info = {
|
|
||||||
"time": tim_usage["init"],
|
|
||||||
"epoch": epoch,
|
|
||||||
"iteration": iteration,
|
|
||||||
"loss": task_loss_map[args.task_name],
|
|
||||||
"lr": lr_scheduler.current_lr,
|
|
||||||
"lr_scale": int(optim_manager.loss_scale),
|
|
||||||
"time_usage": tim_usage,
|
|
||||||
"mem_usage": mem_usage,
|
|
||||||
"avg_time": avg_time,
|
|
||||||
"token_max": local_total_rate,
|
|
||||||
"token_pass": global_token_pass,
|
|
||||||
"throughout": args.max_length * args.batch_size * local_total_rate / avg_time,
|
|
||||||
"grad_norm": grad_norm.item(),
|
|
||||||
"mask_max": ((targets >= 0).sum(-1).float().mean() / args.max_length).item(),
|
|
||||||
"num_gpus": global_world_size,
|
|
||||||
"task_loss": task_loss_map,
|
|
||||||
}
|
|
||||||
# bmt.print_rank(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
|
|
||||||
bmt.print_rank(
|
|
||||||
(
|
|
||||||
"| Epoch: {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f} | lr: {:.6e} | scale: {:10.0f} | time: {:.1f} |"
|
|
||||||
+ " token/max: {:.3f} | mask/max: {:.3f} | grad_norm: {:.3f}"
|
|
||||||
).format(
|
|
||||||
epoch,
|
|
||||||
iteration,
|
|
||||||
args.train_iters,
|
|
||||||
task_loss_map[args.task_name],
|
|
||||||
lr_scheduler.current_lr,
|
|
||||||
int(optim_manager.loss_scale),
|
|
||||||
avg_time,
|
|
||||||
input_length.float().mean() / args.max_length,
|
|
||||||
(targets >= 0).sum(-1).float().mean() / args.max_length,
|
|
||||||
grad_norm,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
bmt.print_rank(
|
|
||||||
"| "
|
|
||||||
+ " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
|
|
||||||
+ " |"
|
|
||||||
)
|
|
||||||
if iteration % args.inspect_iters == 0:
|
|
||||||
model_inspect = bmt.inspect.inspect_model(model, "*")
|
|
||||||
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
|
||||||
train_info["model_inspect"] = model_inspect
|
|
||||||
|
|
||||||
# save_folder_name = f"{args.save}{epoch}{iteration}"
|
|
||||||
# model_fname = os.path.join(save_folder_name, f"{args.save_name}-iter-{iteration}.pt")
|
|
||||||
# os.makedirs(os.path.dirname(model_fname), exist_ok=True)
|
|
||||||
# if bmt.rank() == 0:
|
|
||||||
# shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
|
|
||||||
# shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))
|
|
||||||
|
|
||||||
if args.tensorboard is not None and bmt.rank() == 0:
|
|
||||||
writer.add_scalar(f"Loss/train/{epoch}", task_loss_map[args.task_name], iteration)
|
|
||||||
writer.add_scalar(f"Optimizer/lr/{epoch}", lr_scheduler.current_lr, iteration)
|
|
||||||
writer.add_scalar(f"Optimizer/scale/{epoch}", optim_manager.loss_scale, iteration)
|
|
||||||
writer.add_scalar(f"Optimizer/grad_norm/{epoch}", grad_norm.item(), iteration)
|
|
||||||
|
|
||||||
# if iteration % 10 == 0:
|
|
||||||
# save_folder_name = f"{args.save}{epoch}"
|
|
||||||
# model_fname = os.path.join(save_folder_name, f"{args.save_name}-epoch-{epoch}.pt")
|
|
||||||
# os.makedirs(os.path.dirname(model_fname), exist_ok=True)
|
|
||||||
# bmt.save(model, model_fname)
|
|
||||||
# if bmt.rank() == 0:
|
|
||||||
# shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
|
|
||||||
# shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))
|
|
||||||
|
|
||||||
save_folder_name = f"{args.save}/epoch-{epoch}"
|
|
||||||
model_fname = os.path.join(save_folder_name, f"{args.save_name}.pt")
|
|
||||||
os.makedirs(os.path.dirname(model_fname), exist_ok=True)
|
|
||||||
state_dict = model.state_dict()
|
|
||||||
if args.delta_type == None or args.save_origin_model == True :
|
|
||||||
print("saving base model...")
|
|
||||||
bmt.save(model, model_fname)
|
|
||||||
if args.delta_type != None and bmt.rank() == 0:
|
|
||||||
print("saving delta model...")
|
|
||||||
torch.save(state_dict, os.path.join(save_folder_name, f"{args.save_name}-delta.pt"))
|
|
||||||
if bmt.rank() == 0:
|
|
||||||
shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
|
|
||||||
shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
args = initialize()
|
|
||||||
# To Be Specified
|
|
||||||
bin_file = args.dataset
|
|
||||||
bmt.print_rank(f"dataset: {bin_file}")
|
|
||||||
|
|
||||||
tokenizer = get_tokenizer(args)
|
|
||||||
dataloader = FinetuneDataset(
|
|
||||||
bin_file, args.batch_size, args.max_length, tokenizer, unpad=(args.flash == "cuda"), task_name="task", drop_last=True
|
|
||||||
)
|
|
||||||
bmt.print_rank(f"#batch: {len(dataloader)}")
|
|
||||||
total_steps = len(dataloader) * args.epoch
|
|
||||||
|
|
||||||
setattr(args, 'train_iters', int(total_steps))
|
|
||||||
setattr(args, 'warmup_iters', max(int(total_steps*0.02), args.warmup_iters))
|
|
||||||
bmt.print_rank(json.dumps(vars(args), indent=2, sort_keys=True))
|
|
||||||
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
|
|
||||||
|
|
||||||
bmt.print_rank("finish loading")
|
|
||||||
finetune(args, bin_file, tokenizer, model, optimizer, lr_scheduler)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
|
@ -1,63 +0,0 @@
|
||||||
#! /bin/bash
|
|
||||||
#SBATCH --partition=gpu3-1
|
|
||||||
#SBATCH --nodes=1
|
|
||||||
#SBATCH --ntasks-per-node=8
|
|
||||||
#SBATCH --gres=gpu:8
|
|
||||||
#SBATCH --cpus-per-task=8
|
|
||||||
|
|
||||||
export MASTER_ADDR=`hostname`
|
|
||||||
export MASTER_PORT=12345
|
|
||||||
echo $MASTER_ADDR
|
|
||||||
|
|
||||||
CPM_PATH="./9G-Train"
|
|
||||||
CONFIG_NAME="${CPM_PATH}/apps/cpm9g/config/11b"
|
|
||||||
EXP_PATH=./exp
|
|
||||||
mkdir -p $EXP_PATH
|
|
||||||
MODEL_NAME="cpm9g-11b-sft"
|
|
||||||
|
|
||||||
OPTS=""
|
|
||||||
OPTS+=" --model-config ${CONFIG_NAME}/config.json"
|
|
||||||
OPTS+=" --vocab ${CONFIG_NAME}/vocab.txt"
|
|
||||||
|
|
||||||
OPTS+=" --train-iters 10000"
|
|
||||||
OPTS+=" --inspect-iters 200"
|
|
||||||
OPTS+=" --warmup-iters 500"
|
|
||||||
|
|
||||||
OPTS+=" --lr-decay-style cosine"
|
|
||||||
OPTS+=" --weight-decay 0.1"
|
|
||||||
OPTS+=" --clip-grad 1.0"
|
|
||||||
OPTS+=" --loss-scale 1048576"
|
|
||||||
OPTS+=" --max-loss-scale 33554432"
|
|
||||||
OPTS+=" --min-loss-scale 1"
|
|
||||||
OPTS+=" --loss-scale-steps 32"
|
|
||||||
|
|
||||||
OPTS+=" --offload"
|
|
||||||
OPTS+=" --batch-size 4"
|
|
||||||
OPTS+=" --max-length 4096"
|
|
||||||
OPTS+=" --lr 2e-5"
|
|
||||||
OPTS+=" --start-step 0"
|
|
||||||
OPTS+=" --epoch 8"
|
|
||||||
OPTS+=" --load /data/groups/QY_LLM_Core/models/20231010/11b-base/11b.pt"
|
|
||||||
OPTS+=" --dataset /data/groups/QY_LLM_Core/datasets/sft/20231025/merge_qy_sft_bin"
|
|
||||||
# TODO 这些 /data 在启元机器上需要改成 /home 下的路径
|
|
||||||
OPTS+=" --save ${EXP_PATH}/checkpoints"
|
|
||||||
OPTS+=" --save-name ${MODEL_NAME}"
|
|
||||||
# OPTS+=" --tensorboard /data/logs/tensorboard/${MODEL_NAME}/${CUR_DATE}/"
|
|
||||||
# OPTS+=" --flash triton"
|
|
||||||
# OPTS+=" --flash cuda"
|
|
||||||
# OPTS+=" --load-grad"
|
|
||||||
|
|
||||||
OPTS+=" --delta-tuning" #开启delta-tuning
|
|
||||||
OPTS+=" --delta-type lora" #目前仅支持lora
|
|
||||||
OPTS+=" --lora-r 64" #lora矩阵的维度,默认为8
|
|
||||||
OPTS+=" --lora-dropout 0.05" #默认为0
|
|
||||||
OPTS+=" --lora-alpha 64" #lora对原模型的影响比例,默认为8
|
|
||||||
OPTS+=" --lora-layer project_q project_v project_k w_0 w_1 w_out" #参与lora的线性层,默认为project_q,project_k
|
|
||||||
OPTS+=" --save-origin-model" #是否在每个epoch储存基座模型
|
|
||||||
|
|
||||||
OPTS+=" $@"
|
|
||||||
|
|
||||||
CMD="torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} ${CPM_PATH}/apps/cpm9g/sft_cpm9g_delta.py ${OPTS}"
|
|
||||||
|
|
||||||
echo "${CMD}"
|
|
||||||
$CMD
|
|
|
@ -1,5 +0,0 @@
|
||||||
from .models import CPM9G
|
|
||||||
from .models import CPM9GConfig
|
|
||||||
|
|
||||||
from .tokenizers import CPM9GTokenizer
|
|
||||||
from .training_tasks import MixedDataset
|
|
|
@ -1,16 +0,0 @@
|
||||||
{
|
|
||||||
"vocab_size": 119696,
|
|
||||||
"dropout_p": 0.0,
|
|
||||||
"eps": 1e-05,
|
|
||||||
"half": true,
|
|
||||||
"use_flash_attn": false,
|
|
||||||
"flash_attn_mask_shape": "1d",
|
|
||||||
"dim_model": 4096,
|
|
||||||
"dim_ff": 12288,
|
|
||||||
"dim_head": 128,
|
|
||||||
"num_heads": 32,
|
|
||||||
"num_kv_heads": 32,
|
|
||||||
"num_layers": 32,
|
|
||||||
"activate_fn": "silu",
|
|
||||||
"scale": false
|
|
||||||
}
|
|
|
@ -1,658 +0,0 @@
|
||||||
from typing import Any
|
|
||||||
from typing import Dict
|
|
||||||
from typing import List
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from ...generation import apply_repetition_penalty
|
|
||||||
from ...generation import BeamHypotheses
|
|
||||||
from ...generation import top_k_top_p_filtering
|
|
||||||
|
|
||||||
from ...utils import pad
|
|
||||||
from ..models import CPM9GTorch
|
|
||||||
from ..tokenizers.cpm9g import CPM9GTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
class CPM9GGeneration:
|
|
||||||
def __init__(
|
|
||||||
self, model: CPM9GTorch, tokenizer: CPM9GTokenizer, max_in_len=1024, use_nbce: bool = False
|
|
||||||
):
|
|
||||||
model.eval()
|
|
||||||
self.model = model
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.max_in_len = max_in_len
|
|
||||||
|
|
||||||
def _convert_to_tensors(self, data: Any):
|
|
||||||
input_ids = self.tokenizer.encode(
|
|
||||||
data["input"]
|
|
||||||
) # [self.tokenizer.bos_token_id] + self.tokenizer.encode(data["input"])
|
|
||||||
model_input = {}
|
|
||||||
|
|
||||||
model_input["input_ids"] = torch.tensor(input_ids[: self.max_in_len], dtype=torch.int32).unsqueeze(0)
|
|
||||||
model_input["context"] = torch.zeros(
|
|
||||||
(model_input["input_ids"].shape[0], model_input["input_ids"].shape[1]), dtype=torch.int16
|
|
||||||
)
|
|
||||||
model_input["span"] = torch.ones((model_input["input_ids"].shape[1],), dtype=torch.int16).unsqueeze(0)
|
|
||||||
model_input["length"] = torch.tensor([model_input["input_ids"].shape[1]], dtype=torch.int16).unsqueeze(0)
|
|
||||||
return model_input
|
|
||||||
|
|
||||||
def _process_list(self, data_list: List[Any]):
|
|
||||||
input_tensors = list(map(self._convert_to_tensors, data_list))
|
|
||||||
keys = set(input_tensors[0].keys())
|
|
||||||
padded = {}
|
|
||||||
for key in keys:
|
|
||||||
padded[key] = pad(input_tensors, key, padding_side="left").cuda()
|
|
||||||
return padded
|
|
||||||
|
|
||||||
def generate(self, data_list, **kwargs):
|
|
||||||
origin_data_list = data_list.copy()
|
|
||||||
model_inputs = self._process_list(data_list)
|
|
||||||
with torch.inference_mode():
|
|
||||||
result_ids = self._decode(model_inputs, **kwargs)
|
|
||||||
|
|
||||||
return result_ids
|
|
||||||
|
|
||||||
def _decode(self, model_inputs, **kwargs):
|
|
||||||
raise NotImplementedError("_decode is not implemented.")
|
|
||||||
|
|
||||||
|
|
||||||
class CPM9GBeamSearch(CPM9GGeneration):
|
|
||||||
def _decode(
|
|
||||||
self,
|
|
||||||
model_inputs,
|
|
||||||
beam_size=4,
|
|
||||||
max_length=100,
|
|
||||||
repetition_penalty=1.2,
|
|
||||||
repetition_window=None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Beam search
|
|
||||||
Args:
|
|
||||||
model_inputs (dict): input ids.
|
|
||||||
beam_size (int, optional, defaults to 3): beam size of beam search.
|
|
||||||
generate_length (int, optional, defaults to 100): maximum generation length.
|
|
||||||
repetition_penalty (float, optional, defaults to 1.0): repetition penalty coefficient, 1.0 means no penalty.
|
|
||||||
repetition_window (int, optional, defaults to None): window size of repetition penalty, None means that all output tokens are penalized.
|
|
||||||
""" # noqa: E501
|
|
||||||
# generate_length + 1 for EOS token
|
|
||||||
max_length += 1
|
|
||||||
# expand dimmension
|
|
||||||
batch_size = model_inputs["input_ids"].size(0)
|
|
||||||
input: torch.Tensor = (
|
|
||||||
model_inputs["input_ids"]
|
|
||||||
.unsqueeze(1)
|
|
||||||
.expand(batch_size, beam_size, -1)
|
|
||||||
.contiguous()
|
|
||||||
.view(batch_size * beam_size, -1)
|
|
||||||
)
|
|
||||||
length = (
|
|
||||||
model_inputs["length"]
|
|
||||||
.squeeze(1)
|
|
||||||
.unsqueeze(1)
|
|
||||||
.expand(batch_size, beam_size)
|
|
||||||
.contiguous()
|
|
||||||
.view(
|
|
||||||
batch_size * beam_size,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
span: torch.Tensor = (
|
|
||||||
model_inputs["span"]
|
|
||||||
.unsqueeze(1)
|
|
||||||
.expand(batch_size, beam_size, -1)
|
|
||||||
.contiguous()
|
|
||||||
.view(batch_size * beam_size, -1)
|
|
||||||
)
|
|
||||||
context: torch.Tensor = (
|
|
||||||
model_inputs["context"]
|
|
||||||
.unsqueeze(1)
|
|
||||||
.expand(batch_size, beam_size, -1)
|
|
||||||
.contiguous()
|
|
||||||
.view(batch_size * beam_size, -1)
|
|
||||||
)
|
|
||||||
|
|
||||||
done = [False for _ in range(batch_size)]
|
|
||||||
|
|
||||||
beam_scores = torch.zeros((batch_size, beam_size), dtype=torch.float, device=input.device)
|
|
||||||
beam_scores[:, 1:] = -1e9
|
|
||||||
beam_scores = beam_scores.view(-1)
|
|
||||||
|
|
||||||
# generated hypotheses
|
|
||||||
generated_hyps = [
|
|
||||||
BeamHypotheses(beam_size, max_length, length_penalty=1, early_stopping=False) for _ in range(batch_size)
|
|
||||||
]
|
|
||||||
|
|
||||||
pred_start_index = input.size(-1)
|
|
||||||
past_key_values = None
|
|
||||||
|
|
||||||
for i in range(max_length + 1):
|
|
||||||
if i == 0:
|
|
||||||
logits, _, past_key_values = self.model.inference(
|
|
||||||
input=input,
|
|
||||||
context=context,
|
|
||||||
span=span,
|
|
||||||
length=length,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits, _, past_key_values = self.model.inference(
|
|
||||||
input=input[:, -1:],
|
|
||||||
context=context,
|
|
||||||
span=span,
|
|
||||||
length=length,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
)
|
|
||||||
# skip all steps when we are done with each sentence
|
|
||||||
if all(done):
|
|
||||||
break
|
|
||||||
|
|
||||||
# (batch * beam, seqlen, model_dim)
|
|
||||||
logits = logits[:, -1, :]
|
|
||||||
if i == 0:
|
|
||||||
logits[:, self.tokenizer.bos_token_id] = -float("inf")
|
|
||||||
# logits[:, self.tokenizer.newline_id] = -float("inf")
|
|
||||||
apply_repetition_penalty(
|
|
||||||
logits,
|
|
||||||
batch_size,
|
|
||||||
beam_size,
|
|
||||||
input,
|
|
||||||
repetition_penalty,
|
|
||||||
pred_start_index,
|
|
||||||
input.size(-1) - 1,
|
|
||||||
repetition_window,
|
|
||||||
)
|
|
||||||
|
|
||||||
scores = F.log_softmax(logits, dim=-1)
|
|
||||||
|
|
||||||
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * beam_size, vocab_size)
|
|
||||||
|
|
||||||
# re-organize to group the beam together (we are keeping top hypothesis accross beams)
|
|
||||||
next_scores = next_scores.view(batch_size, -1) # (batch_size, beam_size * vocab_size)
|
|
||||||
next_scores, next_words = torch.topk(next_scores, 2 * beam_size, dim=1, largest=True, sorted=True)
|
|
||||||
|
|
||||||
assert next_scores.size() == next_words.size() == (batch_size, 2 * beam_size)
|
|
||||||
next_batch_beam = []
|
|
||||||
|
|
||||||
for sent_id in range(batch_size):
|
|
||||||
# if we are done with this sentence
|
|
||||||
done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item(), i)
|
|
||||||
if done[sent_id]:
|
|
||||||
next_batch_beam.extend([(0, 0, 0)] * beam_size) # pad the batch
|
|
||||||
continue
|
|
||||||
|
|
||||||
# next sentence beam content
|
|
||||||
next_sent_beam = []
|
|
||||||
|
|
||||||
# next words for this sentence
|
|
||||||
for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
|
|
||||||
# get beam and word IDs
|
|
||||||
beam_id = torch.div(idx, scores.size(-1), rounding_mode="floor")
|
|
||||||
word_id = idx % scores.size(-1)
|
|
||||||
|
|
||||||
# end of sentence, or next word
|
|
||||||
if word_id == self.tokenizer.bos_token_id or i == max_length:
|
|
||||||
generated_hyps[sent_id].add(
|
|
||||||
input[sent_id * beam_size + beam_id, pred_start_index:].clone().cpu().tolist(),
|
|
||||||
value.item(),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))
|
|
||||||
|
|
||||||
# the beam for next step is full
|
|
||||||
if len(next_sent_beam) == beam_size:
|
|
||||||
break
|
|
||||||
|
|
||||||
# update next beam content
|
|
||||||
assert len(next_sent_beam) == 0 if i == max_length else beam_size
|
|
||||||
if len(next_sent_beam) == 0:
|
|
||||||
next_sent_beam = [(0, 0, 0)] * beam_size # pad the batch
|
|
||||||
next_batch_beam.extend(next_sent_beam)
|
|
||||||
assert len(next_batch_beam) == beam_size * (sent_id + 1)
|
|
||||||
|
|
||||||
# we have reached the last step
|
|
||||||
if i == max_length:
|
|
||||||
break
|
|
||||||
|
|
||||||
# sanity check / prepare next batch
|
|
||||||
assert len(next_batch_beam) == batch_size * beam_size
|
|
||||||
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
|
|
||||||
beam_words = input.new([x[1] for x in next_batch_beam])
|
|
||||||
beam_idx = torch.tensor([x[2] for x in next_batch_beam], device=input.device).long()
|
|
||||||
# re-order batch and internal states
|
|
||||||
input = input[beam_idx, :]
|
|
||||||
|
|
||||||
past_key_values["buffer"] = [list(each) if each is not None else each for each in past_key_values["buffer"]] # type: ignore # noqa: E501
|
|
||||||
for key_value_layer in past_key_values["buffer"]:
|
|
||||||
if key_value_layer is not None:
|
|
||||||
key_value_layer[0] = key_value_layer[0][beam_idx]
|
|
||||||
key_value_layer[1] = key_value_layer[1][beam_idx]
|
|
||||||
|
|
||||||
input = torch.cat([input, beam_words.unsqueeze(1)], dim=-1)
|
|
||||||
context = torch.cat(
|
|
||||||
[context, context[:, -1:]],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
length += 1
|
|
||||||
|
|
||||||
span = torch.cat([span, span[:, -1:]], dim=-1)
|
|
||||||
|
|
||||||
# select the best hypotheses
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for i, hypotheses in enumerate(generated_hyps):
|
|
||||||
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
|
|
||||||
results.append(best_hyp)
|
|
||||||
|
|
||||||
result_text = list(map(self.tokenizer.decode, results))
|
|
||||||
return result_text
|
|
||||||
|
|
||||||
|
|
||||||
class CPM9GRandomSampling(CPM9GGeneration):
|
|
||||||
def _decode(
|
|
||||||
self,
|
|
||||||
model_inputs,
|
|
||||||
max_length=100,
|
|
||||||
top_k=0,
|
|
||||||
top_p=0.9,
|
|
||||||
temperature=0.9,
|
|
||||||
repetition_penalty=1.0,
|
|
||||||
repetition_window=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Top-k and top-p sampling.
|
|
||||||
Args:
|
|
||||||
model_inputs (dict): input ids
|
|
||||||
generate_length (int, optional, defaults to 100): maximum generation length
|
|
||||||
top_k (int, optional, defaults to 0): keep only top k tokens with highest probability. 0 means keeping all tokens.
|
|
||||||
top_p (int, optional, defaults to 0.9): keep the top tokens with cumulative probability >= top_p.
|
|
||||||
temperature (int, optional, defaults to 0.9): the value that can cool down the logits distribution.
|
|
||||||
repetition_penalty (float, optional, defaults to 1.0): repetition penalty coefficient, 1.0 means no penalty.
|
|
||||||
repetition_window (int, optional, defaults to None): window size of repetition penalty, None means that all output tokens are penalized.
|
|
||||||
""" # noqa: E501
|
|
||||||
# generate_length + 1 for EOS token
|
|
||||||
max_length += 1
|
|
||||||
|
|
||||||
input = model_inputs["input_ids"]
|
|
||||||
context = model_inputs["context"]
|
|
||||||
|
|
||||||
length = model_inputs["length"].squeeze(1)
|
|
||||||
span = model_inputs["span"]
|
|
||||||
batch_size = input.size(0)
|
|
||||||
|
|
||||||
pred_start_index = input.size(-1)
|
|
||||||
past_key_values = None
|
|
||||||
done = [False for _ in range(batch_size)]
|
|
||||||
results = [None for _ in range(batch_size)]
|
|
||||||
for i in range(max_length):
|
|
||||||
if i == 0:
|
|
||||||
logits, _, past_key_values = self.model.inference(
|
|
||||||
input=input,
|
|
||||||
context=context,
|
|
||||||
length=length,
|
|
||||||
span=span,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits, _, past_key_values = self.model.inference(
|
|
||||||
input=input[:, -1:],
|
|
||||||
context=context,
|
|
||||||
length=length,
|
|
||||||
span=span,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
)
|
|
||||||
|
|
||||||
logits = logits[:, -1, :]
|
|
||||||
|
|
||||||
if i == 0:
|
|
||||||
logits[:, self.tokenizer.bos_token_id] = -float("inf")
|
|
||||||
# logits[:, self.tokenizer.newline_id] = -float("inf")
|
|
||||||
|
|
||||||
apply_repetition_penalty(
|
|
||||||
logits,
|
|
||||||
batch_size,
|
|
||||||
1,
|
|
||||||
input,
|
|
||||||
repetition_penalty,
|
|
||||||
pred_start_index,
|
|
||||||
input.size(-1) - 1,
|
|
||||||
repetition_window,
|
|
||||||
)
|
|
||||||
|
|
||||||
logits = logits / temperature
|
|
||||||
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
|
||||||
|
|
||||||
probs = F.softmax(logits, dim=-1)
|
|
||||||
next_token = torch.multinomial(probs, num_samples=1)
|
|
||||||
|
|
||||||
for idx in range(batch_size):
|
|
||||||
if not done[idx] and (next_token[idx].item() == self.tokenizer.bos_token_id or i == max_length - 1):
|
|
||||||
done[idx] = True
|
|
||||||
results[idx] = input[idx, pred_start_index:].clone().cpu().tolist() # type: ignore # noqa: E501
|
|
||||||
|
|
||||||
if sum(done) == batch_size:
|
|
||||||
break
|
|
||||||
|
|
||||||
# update input ids
|
|
||||||
input = torch.cat([input, next_token], dim=-1)
|
|
||||||
length += 1
|
|
||||||
|
|
||||||
context = torch.cat(
|
|
||||||
[context, context[:, -1:]],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
span = torch.cat(
|
|
||||||
[span, span[:, -1:]],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
result_text = list(map(self.tokenizer.decode, results))
|
|
||||||
return result_text
|
|
||||||
|
|
||||||
|
|
||||||
class CPM9GBeamSearchNBCE(CPM9GGeneration):
|
|
||||||
def _decode(
|
|
||||||
self,
|
|
||||||
model_inputs,
|
|
||||||
beam_size=5,
|
|
||||||
max_length=100,
|
|
||||||
repetition_penalty=1.0,
|
|
||||||
repetition_window=None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Beam search
|
|
||||||
Args:
|
|
||||||
model_inputs (dict): input ids.
|
|
||||||
beam_size (int, optional, defaults to 3): beam size of beam search.
|
|
||||||
generate_length (int, optional, defaults to 100): maximum generation length.
|
|
||||||
repetition_penalty (float, optional, defaults to 1.0): repetition penalty coefficient, 1.0 means no penalty.
|
|
||||||
repetition_window (int, optional, defaults to None): window size of repetition penalty, None means that all output tokens are penalized.
|
|
||||||
""" # noqa: E501
|
|
||||||
# generate_length + 1 for EOS token
|
|
||||||
max_length += 1
|
|
||||||
|
|
||||||
# expand dimmension
|
|
||||||
batch_size = model_inputs["input_ids"].size(0)
|
|
||||||
input: torch.Tensor = (
|
|
||||||
model_inputs["input_ids"]
|
|
||||||
.unsqueeze(1)
|
|
||||||
.expand(batch_size, beam_size, -1)
|
|
||||||
.contiguous()
|
|
||||||
.view(batch_size * beam_size, -1)
|
|
||||||
)
|
|
||||||
length = (
|
|
||||||
model_inputs["length"]
|
|
||||||
.squeeze(1)
|
|
||||||
.unsqueeze(1)
|
|
||||||
.expand(batch_size, beam_size)
|
|
||||||
.contiguous()
|
|
||||||
.view(
|
|
||||||
batch_size * beam_size,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
span: torch.Tensor = (
|
|
||||||
model_inputs["span"]
|
|
||||||
.unsqueeze(1)
|
|
||||||
.expand(batch_size, beam_size, -1)
|
|
||||||
.contiguous()
|
|
||||||
.view(batch_size * beam_size, -1)
|
|
||||||
)
|
|
||||||
context: torch.Tensor = (
|
|
||||||
model_inputs["context"]
|
|
||||||
.unsqueeze(1)
|
|
||||||
.expand(batch_size, beam_size, -1)
|
|
||||||
.contiguous()
|
|
||||||
.view(batch_size * beam_size, -1)
|
|
||||||
)
|
|
||||||
|
|
||||||
done = [False]
|
|
||||||
|
|
||||||
beam_scores = torch.zeros((1, beam_size), dtype=torch.float, device=input.device)
|
|
||||||
beam_scores[:, 1:] = -1e9
|
|
||||||
beam_scores = beam_scores.view(-1)
|
|
||||||
|
|
||||||
# generated hypotheses
|
|
||||||
generated_hyps = [
|
|
||||||
BeamHypotheses(beam_size, max_length, length_penalty=1, early_stopping=False) for _ in range(1)
|
|
||||||
]
|
|
||||||
|
|
||||||
pred_start_index = input.size(-1)
|
|
||||||
past_key_values = None
|
|
||||||
|
|
||||||
for i in range(max_length + 1):
|
|
||||||
if i == 0:
|
|
||||||
logits, _, past_key_values = self.model.inference(
|
|
||||||
input=input,
|
|
||||||
context=context,
|
|
||||||
span=span,
|
|
||||||
length=length,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits, _, past_key_values = self.model.inference(
|
|
||||||
input=input[:, -1:],
|
|
||||||
context=context,
|
|
||||||
span=span,
|
|
||||||
length=length,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
)
|
|
||||||
# skip all steps when we are done with each sentence
|
|
||||||
|
|
||||||
if all(done):
|
|
||||||
break
|
|
||||||
|
|
||||||
# (batch * beam, seqlen, model_dim)
|
|
||||||
assert logits.size(0) > 1, "nbce needs to ensure that the length of logits 0 is greater than 1"
|
|
||||||
logits = NBCE(logits) # [vocab_size]
|
|
||||||
logits = logits.tile(beam_size, 1)
|
|
||||||
|
|
||||||
if i == 0:
|
|
||||||
logits[:, self.tokenizer.bos_token_id] = -float("inf")
|
|
||||||
apply_repetition_penalty(
|
|
||||||
logits,
|
|
||||||
1,
|
|
||||||
beam_size,
|
|
||||||
input,
|
|
||||||
repetition_penalty,
|
|
||||||
pred_start_index,
|
|
||||||
input.size(-1) - 1,
|
|
||||||
repetition_window,
|
|
||||||
)
|
|
||||||
|
|
||||||
scores = F.log_softmax(logits, dim=-1)
|
|
||||||
|
|
||||||
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * beam_size, vocab_size)
|
|
||||||
|
|
||||||
# re-organize to group the beam together (we are keeping top hypothesis accross beams)
|
|
||||||
next_scores = next_scores.view(1, -1) # (batch_size, beam_size * vocab_size)
|
|
||||||
next_scores, next_words = torch.topk(next_scores, 2 * beam_size, dim=1, largest=True, sorted=True)
|
|
||||||
|
|
||||||
assert next_scores.size() == next_words.size() == (1, 2 * beam_size)
|
|
||||||
next_batch_beam = []
|
|
||||||
|
|
||||||
for sent_id in range(1):
|
|
||||||
# if we are done with this sentence
|
|
||||||
done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item(), i)
|
|
||||||
if done[sent_id]:
|
|
||||||
next_batch_beam.extend([(0, 0, 0)] * beam_size) # pad the batch
|
|
||||||
continue
|
|
||||||
|
|
||||||
# next sentence beam content
|
|
||||||
next_sent_beam = []
|
|
||||||
|
|
||||||
# next words for this sentence
|
|
||||||
for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
|
|
||||||
# get beam and word IDs
|
|
||||||
beam_id = torch.div(idx, scores.size(-1), rounding_mode="floor")
|
|
||||||
word_id = idx % scores.size(-1)
|
|
||||||
|
|
||||||
# end of sentence, or next word
|
|
||||||
if word_id == self.tokenizer.bos_token_id or i == max_length:
|
|
||||||
generated_hyps[sent_id].add(
|
|
||||||
input[sent_id * beam_size + beam_id, pred_start_index:].clone().cpu().tolist(),
|
|
||||||
value.item(),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))
|
|
||||||
|
|
||||||
# the beam for next step is full
|
|
||||||
if len(next_sent_beam) == beam_size:
|
|
||||||
break
|
|
||||||
|
|
||||||
# update next beam content
|
|
||||||
assert len(next_sent_beam) == 0 if i == max_length else beam_size
|
|
||||||
if len(next_sent_beam) == 0:
|
|
||||||
next_sent_beam = [(0, 0, 0)] * beam_size # pad the batch
|
|
||||||
next_batch_beam.extend(next_sent_beam)
|
|
||||||
assert len(next_batch_beam) == beam_size * (sent_id + 1)
|
|
||||||
|
|
||||||
# we have reached the last step
|
|
||||||
if i == max_length:
|
|
||||||
break
|
|
||||||
|
|
||||||
# sanity check / prepare next batch
|
|
||||||
assert len(next_batch_beam) == batch_size * beam_size
|
|
||||||
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
|
|
||||||
beam_words = input.new([x[1] for x in next_batch_beam])
|
|
||||||
beam_idx = torch.tensor([x[2] for x in next_batch_beam], device=input.device).long()
|
|
||||||
beam_idx *= batch_size
|
|
||||||
# re-order batch and internal states
|
|
||||||
input = input[beam_idx, :]
|
|
||||||
|
|
||||||
past_key_values["buffer"] = [list(each) if each is not None else each for each in past_key_values["buffer"]] # type: ignore # noqa: E501
|
|
||||||
for key_value_layer in past_key_values["buffer"]:
|
|
||||||
if key_value_layer is not None:
|
|
||||||
key_value_layer[0] = key_value_layer[0][beam_idx]
|
|
||||||
key_value_layer[1] = key_value_layer[1][beam_idx]
|
|
||||||
|
|
||||||
input = torch.cat([input, beam_words.unsqueeze(1)], dim=-1)
|
|
||||||
context = torch.cat(
|
|
||||||
[context, torch.zeros((context.size(0), 1), dtype=torch.int16, device=context.device)],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
length = past_key_values["buffer_length"]
|
|
||||||
length = torch.cat(
|
|
||||||
[length, torch.ones((length.size(0), 1), dtype=torch.int16, device=length.device)],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
span = torch.cat([span, span[:, -1:]], dim=-1)
|
|
||||||
|
|
||||||
# select the best hypotheses
|
|
||||||
results = []
|
|
||||||
for i, hypotheses in enumerate(generated_hyps):
|
|
||||||
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
|
|
||||||
results.append(best_hyp)
|
|
||||||
|
|
||||||
result_text = list(map(self.tokenizer.decode, results))
|
|
||||||
return result_text
|
|
||||||
|
|
||||||
|
|
||||||
class CPM9GRandomSamplingNBCE(CPM9GGeneration):
|
|
||||||
def _decode(
|
|
||||||
self,
|
|
||||||
model_inputs,
|
|
||||||
max_length=100,
|
|
||||||
top_k=0,
|
|
||||||
top_p=0.9,
|
|
||||||
temperature=0.9,
|
|
||||||
repetition_penalty=1.0,
|
|
||||||
repetition_window=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Top-k and top-p sampling.
|
|
||||||
Args:
|
|
||||||
model_inputs (dict): input ids
|
|
||||||
generate_length (int, optional, defaults to 100): maximum generation length
|
|
||||||
top_k (int, optional, defaults to 0): keep only top k tokens with highest probability. 0 means keeping all tokens.
|
|
||||||
top_p (int, optional, defaults to 0.9): keep the top tokens with cumulative probability >= top_p.
|
|
||||||
temperature (int, optional, defaults to 0.9): the value that can cool down the logits distribution.
|
|
||||||
repetition_penalty (float, optional, defaults to 1.0): repetition penalty coefficient, 1.0 means no penalty.
|
|
||||||
repetition_window (int, optional, defaults to None): window size of repetition penalty, None means that all output tokens are penalized.
|
|
||||||
""" # noqa: E501
|
|
||||||
# generate_length + 1 for EOS token
|
|
||||||
max_length += 1
|
|
||||||
|
|
||||||
input = model_inputs["input_ids"]
|
|
||||||
context = model_inputs["context"]
|
|
||||||
|
|
||||||
length = model_inputs["length"].squeeze(1)
|
|
||||||
span = model_inputs["span"]
|
|
||||||
batch_size = input.size(0)
|
|
||||||
|
|
||||||
pred_start_index = input.size(-1)
|
|
||||||
past_key_values = None
|
|
||||||
done = [False]
|
|
||||||
results = [None]
|
|
||||||
for i in range(max_length):
|
|
||||||
if i == 0:
|
|
||||||
logits, _, past_key_values = self.model.inference(
|
|
||||||
input=input,
|
|
||||||
context=context,
|
|
||||||
length=length,
|
|
||||||
span=span,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits, _, past_key_values = self.model.inference(
|
|
||||||
input=input[:, -1:],
|
|
||||||
context=context,
|
|
||||||
length=length,
|
|
||||||
span=span,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert logits.size(0) > 1, "nbce needs to ensure that the length of logits 0 is greater than 1"
|
|
||||||
logits = NBCE(logits) # [vocab_size]
|
|
||||||
logits = logits[None]
|
|
||||||
|
|
||||||
if i == 0:
|
|
||||||
logits[:, self.tokenizer.bos_token_id] = -float("inf")
|
|
||||||
# logits[:, self.tokenizer.newline_id] = -float("inf")
|
|
||||||
|
|
||||||
apply_repetition_penalty(
|
|
||||||
logits,
|
|
||||||
1,
|
|
||||||
1,
|
|
||||||
input,
|
|
||||||
repetition_penalty,
|
|
||||||
pred_start_index,
|
|
||||||
input.size(-1) - 1,
|
|
||||||
repetition_window,
|
|
||||||
)
|
|
||||||
|
|
||||||
logits = logits / temperature
|
|
||||||
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
|
||||||
|
|
||||||
probs = F.softmax(logits, dim=-1)
|
|
||||||
next_token = torch.multinomial(probs, num_samples=1)
|
|
||||||
|
|
||||||
for idx in range(1):
|
|
||||||
if not done[idx] and (next_token[idx].item() == self.tokenizer.bos_token_id or i == max_length - 1):
|
|
||||||
done[idx] = True
|
|
||||||
results[idx] = input[idx, pred_start_index:].clone().cpu().tolist() # type: ignore # noqa: E501
|
|
||||||
|
|
||||||
if sum(done) == 1:
|
|
||||||
break
|
|
||||||
next_token = next_token.tile(batch_size, 1)
|
|
||||||
# update input ids
|
|
||||||
input = torch.cat([input, next_token], dim=-1)
|
|
||||||
length = past_key_values["buffer_length"]
|
|
||||||
length = torch.cat(
|
|
||||||
[length, torch.ones((length.size(0), 1), dtype=torch.int32, device=length.device)],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
# length += 1
|
|
||||||
context = torch.cat(
|
|
||||||
[context, torch.zeros((context.size(0), 1), dtype=torch.int16, device=context.device)],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
span = torch.cat(
|
|
||||||
[span, torch.zeros((span.size(0), 1), dtype=torch.int32, device=span.device)],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
result_text = list(map(self.tokenizer.decode, results))
|
|
||||||
return result_text
|
|
|
@ -1,3 +0,0 @@
|
||||||
from .cpm9g import CPM9G
|
|
||||||
from .cpm9g import CPM9GConfig
|
|
||||||
from .cpm9g_torch import CPM9GTorch
|
|
|
@ -1,272 +0,0 @@
|
||||||
from typing import List
|
|
||||||
from typing import Optional
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import bmtrain as bmt
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from ...layers import Embedding
|
|
||||||
from ...layers import Encoder
|
|
||||||
from ...layers import RotaryEmbeddingESM
|
|
||||||
from ...utils import Config
|
|
||||||
from ...utils import gradient_shrink
|
|
||||||
|
|
||||||
|
|
||||||
class CPM9GInferenceState(TypedDict):
|
|
||||||
buffer_context: torch.Tensor
|
|
||||||
buffer_sample_ids: torch.Tensor
|
|
||||||
buffer: List[Tuple[torch.Tensor, torch.Tensor]]
|
|
||||||
|
|
||||||
|
|
||||||
class CPM9GConfig(Config):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vocab_size=32000,
|
|
||||||
dim_model=4096,
|
|
||||||
num_heads=32,
|
|
||||||
num_kv_heads=32,
|
|
||||||
dim_head=128,
|
|
||||||
dim_ff=11008,
|
|
||||||
num_layers=32,
|
|
||||||
dropout_p=0.0,
|
|
||||||
activate_fn="silu",
|
|
||||||
scale=True,
|
|
||||||
eps=1e-5,
|
|
||||||
half: bool = True,
|
|
||||||
bf16: bool = False,
|
|
||||||
mask_modules: Optional[List[Tuple[bool, bool]]] = None,
|
|
||||||
use_flash_attn: bool = True,
|
|
||||||
flash_attn_mask_shape="1d",
|
|
||||||
flash_impl="cuda",
|
|
||||||
base=10000,
|
|
||||||
tp=0,
|
|
||||||
disabled_checkpoint=None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.vocab_size = vocab_size
|
|
||||||
self.dim_model = dim_model
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.num_kv_heads = num_kv_heads
|
|
||||||
self.dim_head = dim_head
|
|
||||||
self.dim_ff = dim_ff
|
|
||||||
self.num_layers = num_layers
|
|
||||||
self.dropout_p = dropout_p
|
|
||||||
self.activate_fn = activate_fn
|
|
||||||
self.scale = scale
|
|
||||||
self.eps = eps
|
|
||||||
if half:
|
|
||||||
if bf16:
|
|
||||||
self.dtype = torch.bfloat16
|
|
||||||
else:
|
|
||||||
self.dtype = torch.half
|
|
||||||
else:
|
|
||||||
self.dtype = torch.float
|
|
||||||
self.flash_impl = flash_impl
|
|
||||||
self.mask_modules = mask_modules
|
|
||||||
self.use_flash_attn = use_flash_attn
|
|
||||||
self.flash_attn_mask_shape = flash_attn_mask_shape
|
|
||||||
self.base = base
|
|
||||||
self.tp = tp
|
|
||||||
self.disabled_checkpoint = disabled_checkpoint
|
|
||||||
|
|
||||||
|
|
||||||
class CPM9G(bmt.DistributedModule):
|
|
||||||
def __init__(self, config: CPM9GConfig):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.encoder = Encoder(
|
|
||||||
num_layers=config.num_layers,
|
|
||||||
dim_model=config.dim_model,
|
|
||||||
dim_ff=config.dim_ff,
|
|
||||||
num_heads=config.num_heads,
|
|
||||||
num_kv_heads=config.num_kv_heads,
|
|
||||||
dim_head=config.dim_head,
|
|
||||||
activate_fn=config.activate_fn,
|
|
||||||
dtype=config.dtype,
|
|
||||||
eps=config.eps,
|
|
||||||
dropout_p=config.dropout_p,
|
|
||||||
scale=config.scale,
|
|
||||||
mask_modules=config.mask_modules,
|
|
||||||
use_flash_attn=config.use_flash_attn,
|
|
||||||
tp=config.tp,
|
|
||||||
disabled_checkpoint=config.disabled_checkpoint,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.input_embedding = Embedding(
|
|
||||||
vocab_size=config.vocab_size,
|
|
||||||
embedding_size=config.dim_model,
|
|
||||||
scale=config.scale,
|
|
||||||
dtype=config.dtype,
|
|
||||||
init_std=0.02,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.position_bias = RotaryEmbeddingESM(
|
|
||||||
dim=config.dim_head, dtype=config.dtype, base=config.base, persistent=False, mixed_precision=True
|
|
||||||
)
|
|
||||||
|
|
||||||
self.lm_head = Embedding(
|
|
||||||
vocab_size=config.vocab_size,
|
|
||||||
embedding_size=config.dim_model,
|
|
||||||
scale=config.scale,
|
|
||||||
dtype=config.dtype,
|
|
||||||
init_std=0.02,
|
|
||||||
)
|
|
||||||
self.flash_impl = config.flash_impl
|
|
||||||
self.use_flash_attn = config.use_flash_attn
|
|
||||||
self.flash_attn_mask_shape = config.flash_attn_mask_shape
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input: torch.Tensor, # (batch, seqlen) int32
|
|
||||||
length: torch.Tensor = None, # (batch) int32
|
|
||||||
context: torch.Tensor = None, # (batch, seqlen) bool
|
|
||||||
span: torch.Tensor = None, # (batch, seqlen) int32
|
|
||||||
cu_seqlens: torch.Tensor = None, # (real_batch+2) int32
|
|
||||||
max_seqlen: int = None,
|
|
||||||
position_ids: torch.Tensor = None, # (batch, seqlen) int32
|
|
||||||
):
|
|
||||||
batch = input.size(0)
|
|
||||||
seqlen = input.size(1)
|
|
||||||
device = input.device
|
|
||||||
|
|
||||||
if length is not None and length.dim() == 1:
|
|
||||||
length = torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) < length[:, None]
|
|
||||||
|
|
||||||
# processing masks and position bias bucket
|
|
||||||
if not self.use_flash_attn or (self.flash_attn_mask_shape == "2d" and self.flash_impl == "triton"):
|
|
||||||
with torch.no_grad():
|
|
||||||
# directional mask
|
|
||||||
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(
|
|
||||||
-1, 1
|
|
||||||
)
|
|
||||||
# context mask
|
|
||||||
attention_mask = context[:, None, :] | (
|
|
||||||
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
|
|
||||||
)
|
|
||||||
# span mask
|
|
||||||
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
|
|
||||||
# length mask
|
|
||||||
attention_mask = length.view(batch, seqlen, 1) & length.view(batch, 1, seqlen) & attention_mask
|
|
||||||
else:
|
|
||||||
attention_mask = None
|
|
||||||
|
|
||||||
hidden_states = self.input_embedding(input)
|
|
||||||
|
|
||||||
if self.config.tp:
|
|
||||||
with torch.no_grad():
|
|
||||||
if length is not None:
|
|
||||||
length = bmt.distributed.all_gather(length, comm=bmt.config["tp_comm"]).flatten(0, 1)
|
|
||||||
if attention_mask is not None:
|
|
||||||
attention_mask = bmt.distributed.all_gather(attention_mask, comm=bmt.config["tp_comm"]).flatten(
|
|
||||||
0, 1
|
|
||||||
)
|
|
||||||
if cu_seqlens is not None:
|
|
||||||
lens = bmt.distributed.all_gather(
|
|
||||||
torch.tensor([cu_seqlens.numel()]).cuda(), comm=bmt.config["tp_comm"]
|
|
||||||
)
|
|
||||||
mx = lens.max().item()
|
|
||||||
cu_seq = torch.zeros(mx).int().cuda()
|
|
||||||
cu_seq[: cu_seqlens.numel()] = cu_seqlens
|
|
||||||
all_seq = bmt.distributed.all_gather(cu_seq, comm=bmt.config["tp_comm"])
|
|
||||||
cu_seqlens = [0]
|
|
||||||
for i, l in enumerate(lens):
|
|
||||||
for j in range(1, l - 1):
|
|
||||||
cu_seqlens.append(all_seq[i][j] + i * seqlen)
|
|
||||||
cu_seqlens.append((i + 1) * seqlen)
|
|
||||||
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).cuda()
|
|
||||||
if max_seqlen is not None:
|
|
||||||
max_seqlen = bmt.distributed.all_reduce(
|
|
||||||
torch.tensor([max_seqlen]).cuda(), "max", comm=bmt.config["tp_comm"]
|
|
||||||
)[0].item()
|
|
||||||
if position_ids is not None:
|
|
||||||
position_ids = bmt.distributed.all_gather(position_ids, comm=bmt.config["tp_comm"]).flatten(0, 1)
|
|
||||||
|
|
||||||
if self.use_flash_attn:
|
|
||||||
if self.flash_attn_mask_shape == "1d":
|
|
||||||
hidden_states = self.encoder(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=None,
|
|
||||||
position_bias=self.position_bias,
|
|
||||||
pos_bias_type="rotary",
|
|
||||||
length_mask=length,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if self.flash_impl == "triton":
|
|
||||||
mask = attention_mask.unsqueeze(dim=1).contiguous()
|
|
||||||
attention_mask_bias = torch.zeros_like(mask, device="cuda", dtype=torch.float16)
|
|
||||||
attention_mask_bias[mask == False] -= torch.inf
|
|
||||||
else:
|
|
||||||
attention_mask_bias = None
|
|
||||||
assert cu_seqlens is not None, "cu_seqlens are needed in Flash Attention cuda impl"
|
|
||||||
hidden_states = self.encoder(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=None,
|
|
||||||
position_bias=self.position_bias,
|
|
||||||
pos_bias_type="rotary",
|
|
||||||
length_mask=None,
|
|
||||||
attention_mask_bias=attention_mask_bias,
|
|
||||||
cu_seqlens=cu_seqlens,
|
|
||||||
max_seqlen=max_seqlen,
|
|
||||||
position_ids=position_ids,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
hidden_states = self.encoder(
|
|
||||||
hidden_states, attention_mask=attention_mask, position_bias=self.position_bias, pos_bias_type="rotary"
|
|
||||||
)
|
|
||||||
|
|
||||||
logits = self.lm_head.projection(hidden_states)
|
|
||||||
|
|
||||||
return logits, hidden_states
|
|
||||||
|
|
||||||
def inference(
|
|
||||||
self,
|
|
||||||
input: torch.Tensor, # (batch, len_q) int32
|
|
||||||
length: torch.Tensor, # (batch) int32
|
|
||||||
context: torch.Tensor, # (batch, seqlen) int16
|
|
||||||
span: torch.Tensor, # (batch, seqlen) int32
|
|
||||||
past_key_values: Optional[CPM9GInferenceState] = None,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, CPM9GInferenceState]:
|
|
||||||
batch = input.size(0)
|
|
||||||
len_q = input.size(1)
|
|
||||||
len_buffer = 0
|
|
||||||
if past_key_values is None:
|
|
||||||
present_buffer = None
|
|
||||||
else:
|
|
||||||
present_buffer = past_key_values["buffer"]
|
|
||||||
len_buffer = present_buffer[0][0].shape[-2]
|
|
||||||
seqlen = len_buffer + len_q
|
|
||||||
with torch.no_grad():
|
|
||||||
device = input.device
|
|
||||||
if length.dim() == 1:
|
|
||||||
length = torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) < length[:, None]
|
|
||||||
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(-1, 1)
|
|
||||||
# context mask
|
|
||||||
attention_mask = context[:, None, :] | (
|
|
||||||
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
|
|
||||||
)
|
|
||||||
# span mask
|
|
||||||
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
|
|
||||||
# length mask
|
|
||||||
attention_mask = length.view(batch, seqlen, 1) & length.view(batch, 1, seqlen) & attention_mask
|
|
||||||
|
|
||||||
hidden_states = self.input_embedding(input)
|
|
||||||
|
|
||||||
hidden_states, present_key_values, _ = self.encoder(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=attention_mask[:, len_buffer:],
|
|
||||||
position_bias=self.position_bias,
|
|
||||||
use_cache=True,
|
|
||||||
past_key_values=present_buffer,
|
|
||||||
pos_bias_type="rotary",
|
|
||||||
)
|
|
||||||
|
|
||||||
logits = self.lm_head.projection(hidden_states)
|
|
||||||
|
|
||||||
return (
|
|
||||||
logits,
|
|
||||||
hidden_states,
|
|
||||||
{"buffer": present_key_values},
|
|
||||||
)
|
|
|
@ -1,186 +0,0 @@
|
||||||
from typing import List
|
|
||||||
from typing import Optional
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from ...native_layers import Embedding
|
|
||||||
from ...native_layers import Encoder
|
|
||||||
from ...native_layers import RotaryEmbeddingESM
|
|
||||||
from ...utils import Config
|
|
||||||
from ...utils import gradient_shrink
|
|
||||||
from .cpm9g import CPM9GConfig
|
|
||||||
|
|
||||||
|
|
||||||
class CPM9GInferenceState(TypedDict):
|
|
||||||
buffer_context: torch.Tensor
|
|
||||||
buffer_sample_ids: torch.Tensor
|
|
||||||
buffer: List[Tuple[torch.Tensor, torch.Tensor]]
|
|
||||||
|
|
||||||
|
|
||||||
class CPM9GTorch(torch.nn.Module):
|
|
||||||
def __init__(self, config: CPM9GConfig):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.encoder = Encoder(
|
|
||||||
num_layers=config.num_layers,
|
|
||||||
dim_model=config.dim_model,
|
|
||||||
dim_ff=config.dim_ff,
|
|
||||||
num_heads=config.num_heads,
|
|
||||||
num_kv_heads=config.num_kv_heads,
|
|
||||||
dim_head=config.dim_head,
|
|
||||||
activate_fn=config.activate_fn,
|
|
||||||
dtype=config.dtype,
|
|
||||||
eps=config.eps,
|
|
||||||
dropout_p=config.dropout_p,
|
|
||||||
scale=config.scale,
|
|
||||||
mask_modules=config.mask_modules,
|
|
||||||
use_flash_attn=config.use_flash_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.input_embedding = Embedding(
|
|
||||||
vocab_size=config.vocab_size,
|
|
||||||
embedding_size=config.dim_model,
|
|
||||||
scale=config.scale,
|
|
||||||
dtype=config.dtype,
|
|
||||||
init_std=0.02,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.position_bias = RotaryEmbeddingESM(
|
|
||||||
dim=config.dim_head, dtype=config.dtype, base=config.base, persistent=False, mixed_precision=True
|
|
||||||
)
|
|
||||||
|
|
||||||
self.lm_head = Embedding(
|
|
||||||
vocab_size=config.vocab_size,
|
|
||||||
embedding_size=config.dim_model,
|
|
||||||
scale=config.scale,
|
|
||||||
dtype=config.dtype,
|
|
||||||
init_std=0.02,
|
|
||||||
)
|
|
||||||
self.flash_impl = False
|
|
||||||
self.use_flash_attn = False
|
|
||||||
self.flash_attn_mask_shape = "1d"
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input: torch.Tensor, # (batch, seqlen) int32
|
|
||||||
length: torch.Tensor = None, # (batch) int32
|
|
||||||
context: torch.Tensor = None, # (batch, seqlen) bool
|
|
||||||
span: torch.Tensor = None, # (batch, seqlen) int32
|
|
||||||
cu_seqlens: torch.Tensor = None, # (real_batch+2) int32
|
|
||||||
max_seqlen: int = None,
|
|
||||||
position_ids: torch.Tensor = None, # (batch, seqlen) int32
|
|
||||||
):
|
|
||||||
batch = input.size(0)
|
|
||||||
seqlen = input.size(1)
|
|
||||||
device = input.device
|
|
||||||
|
|
||||||
if length is not None and length.dim() == 1:
|
|
||||||
length = torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) < length[:, None]
|
|
||||||
|
|
||||||
# processing masks and position bias bucket
|
|
||||||
if not self.use_flash_attn or (self.flash_attn_mask_shape == "2d" and self.flash_impl == "triton"):
|
|
||||||
with torch.no_grad():
|
|
||||||
# directional mask
|
|
||||||
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(
|
|
||||||
-1, 1
|
|
||||||
)
|
|
||||||
# context mask
|
|
||||||
attention_mask = context[:, None, :] | (
|
|
||||||
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
|
|
||||||
)
|
|
||||||
# span mask
|
|
||||||
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
|
|
||||||
# length mask
|
|
||||||
attention_mask = length.view(batch, seqlen, 1) & length.view(batch, 1, seqlen) & attention_mask
|
|
||||||
|
|
||||||
hidden_states = self.input_embedding(input)
|
|
||||||
|
|
||||||
if self.use_flash_attn:
|
|
||||||
if self.flash_attn_mask_shape == "1d":
|
|
||||||
hidden_states = self.encoder(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=None,
|
|
||||||
position_bias=self.position_bias,
|
|
||||||
pos_bias_type="rotary",
|
|
||||||
length_mask=length,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if self.flash_impl == "triton":
|
|
||||||
mask = attention_mask.unsqueeze(dim=1).contiguous()
|
|
||||||
attention_mask_bias = torch.zeros_like(mask, device="cuda", dtype=torch.float16)
|
|
||||||
attention_mask_bias[mask == False] -= torch.inf
|
|
||||||
else:
|
|
||||||
attention_mask_bias = None
|
|
||||||
assert cu_seqlens is not None, "cu_seqlens are needed in Flash Attention cuda impl"
|
|
||||||
hidden_states = self.encoder(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=None,
|
|
||||||
position_bias=self.position_bias,
|
|
||||||
pos_bias_type="rotary",
|
|
||||||
length_mask=None,
|
|
||||||
attention_mask_bias=attention_mask_bias,
|
|
||||||
cu_seqlens=cu_seqlens,
|
|
||||||
max_seqlen=max_seqlen,
|
|
||||||
position_ids=position_ids,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
hidden_states = self.encoder(
|
|
||||||
hidden_states, attention_mask=attention_mask, position_bias=self.position_bias, pos_bias_type="rotary"
|
|
||||||
)
|
|
||||||
|
|
||||||
logits = self.lm_head.projection(hidden_states)
|
|
||||||
|
|
||||||
return logits, hidden_states
|
|
||||||
|
|
||||||
def inference(
|
|
||||||
self,
|
|
||||||
input: torch.Tensor, # (batch, len_q) int32
|
|
||||||
length: torch.Tensor, # (batch) int32
|
|
||||||
context: torch.Tensor, # (batch, seqlen) int16
|
|
||||||
span: torch.Tensor, # (batch, seqlen) int32
|
|
||||||
past_key_values: Optional[CPM9GInferenceState] = None,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, CPM9GInferenceState]:
|
|
||||||
batch = input.size(0)
|
|
||||||
len_q = input.size(1)
|
|
||||||
len_buffer = 0
|
|
||||||
if past_key_values is None:
|
|
||||||
present_buffer = None
|
|
||||||
else:
|
|
||||||
present_buffer = past_key_values["buffer"]
|
|
||||||
len_buffer = present_buffer[0][0].shape[-2]
|
|
||||||
seqlen = len_buffer + len_q
|
|
||||||
with torch.no_grad():
|
|
||||||
device = input.device
|
|
||||||
if length.dim() == 1:
|
|
||||||
length = (torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) + length[:, None]) >= seqlen
|
|
||||||
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(-1, 1)
|
|
||||||
# context mask
|
|
||||||
attention_mask = context[:, None, :] | (
|
|
||||||
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
|
|
||||||
)
|
|
||||||
# span mask
|
|
||||||
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
|
|
||||||
# length mask
|
|
||||||
attention_mask = length.view(batch, seqlen, 1) & length.view(batch, 1, seqlen) & attention_mask
|
|
||||||
|
|
||||||
hidden_states = self.input_embedding(input)
|
|
||||||
|
|
||||||
hidden_states, present_key_values, _ = self.encoder(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=attention_mask[:, len_buffer:],
|
|
||||||
position_bias=self.position_bias,
|
|
||||||
use_cache=True,
|
|
||||||
past_key_values=present_buffer,
|
|
||||||
pos_bias_type="rotary",
|
|
||||||
)
|
|
||||||
|
|
||||||
logits = self.lm_head.projection(hidden_states)
|
|
||||||
|
|
||||||
return (
|
|
||||||
logits,
|
|
||||||
hidden_states,
|
|
||||||
{"buffer": present_key_values},
|
|
||||||
)
|
|
|
@ -1 +0,0 @@
|
||||||
from .cpm9g import CPM9GTokenizer
|
|
|
@ -1,2 +0,0 @@
|
||||||
from .pretrain import MixedDataset
|
|
||||||
from .finetune import FinetuneDataset
|
|
|
@ -1,87 +0,0 @@
|
||||||
import bmtrain as bmt
|
|
||||||
|
|
||||||
from ...dataset import SimpleDataset
|
|
||||||
from ..tokenizers import CPM9GTokenizer
|
|
||||||
from .pretrain import _MixedDatasetBatchPacker
|
|
||||||
from .pretrain import _MixedDatasetConfig
|
|
||||||
from .pretrain import CPM9GBatch
|
|
||||||
|
|
||||||
|
|
||||||
class FinetuneDataset:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dataset_path: str,
|
|
||||||
batch_size: int,
|
|
||||||
max_length: int,
|
|
||||||
tokenizer: CPM9GTokenizer,
|
|
||||||
unpad: bool = False,
|
|
||||||
task_name: str = "task",
|
|
||||||
drop_last: bool = False,
|
|
||||||
) -> None:
|
|
||||||
self._world_size = bmt.world_size()
|
|
||||||
self._rank = bmt.rank()
|
|
||||||
self._batch_size = batch_size
|
|
||||||
self._unpad = unpad
|
|
||||||
self._max_length = max_length
|
|
||||||
|
|
||||||
self._packer = _MixedDatasetBatchPacker(batch_size * self._world_size, max_length, tokenizer, unpad)
|
|
||||||
self._drop_last = drop_last
|
|
||||||
|
|
||||||
ds = SimpleDataset(dataset_path, shuffle=False)
|
|
||||||
self._ds_cfg: _MixedDatasetConfig = {
|
|
||||||
"weight": 1.0,
|
|
||||||
"path": dataset_path,
|
|
||||||
"transforms": [],
|
|
||||||
"task_name": task_name,
|
|
||||||
"dataset_name": "finetune",
|
|
||||||
"incontext_weight": [1.0],
|
|
||||||
"lines": len(ds),
|
|
||||||
"dataset": ds,
|
|
||||||
}
|
|
||||||
self._nlines = len(ds)
|
|
||||||
self._nbytes = ds.get_bytes()
|
|
||||||
|
|
||||||
def __batch_iter(self):
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
batch = self._packer.add_data(self._ds_cfg)
|
|
||||||
except EOFError:
|
|
||||||
break
|
|
||||||
if batch is None:
|
|
||||||
continue
|
|
||||||
yield batch
|
|
||||||
if len(self._packer) > 0:
|
|
||||||
batch = self._packer.pack_batch(force=True)
|
|
||||||
if not self._drop_last:
|
|
||||||
yield batch
|
|
||||||
self._ds_cfg["dataset"]._repeat_times = 0
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
if not self._unpad:
|
|
||||||
batch_st = self._batch_size * self._rank
|
|
||||||
batch_end = self._batch_size * (self._rank + 1)
|
|
||||||
for batch in self.__batch_iter():
|
|
||||||
batch_size = batch["inputs"].shape[0]
|
|
||||||
if batch_size <= batch_st:
|
|
||||||
yield None
|
|
||||||
else:
|
|
||||||
ret: CPM9GBatch = {
|
|
||||||
kw: val[batch_st:batch_end] # type: ignore
|
|
||||||
for kw, val in batch.items()
|
|
||||||
if kw not in ["task_names", "raw_data", "cu_seqlens", "max_seqlen"]
|
|
||||||
} # type: ignore
|
|
||||||
ret["task_names"] = batch["task_names"]
|
|
||||||
ret["raw_data"] = batch["raw_data"]
|
|
||||||
ret["cu_seqlens"] = batch["cu_seqlens"]
|
|
||||||
ret["max_seqlen"] = batch["max_seqlen"]
|
|
||||||
yield ret
|
|
||||||
else:
|
|
||||||
for batch in self.__batch_iter():
|
|
||||||
assert batch["inputs"].shape[0] == 1
|
|
||||||
yield batch
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
iter_count = int(self._nbytes // (3.6 * self._batch_size * self._world_size * self._max_length))
|
|
||||||
if not self._drop_last:
|
|
||||||
iter_count += 1
|
|
||||||
return iter_count
|
|
|
@ -1,736 +0,0 @@
|
||||||
import importlib.machinery
|
|
||||||
import importlib.util
|
|
||||||
import json
|
|
||||||
import multiprocessing
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
import time
|
|
||||||
import types
|
|
||||||
from collections import OrderedDict
|
|
||||||
from queue import Empty
|
|
||||||
from typing import Any
|
|
||||||
from typing import Callable
|
|
||||||
from typing import Dict
|
|
||||||
from typing import List
|
|
||||||
from typing import Optional
|
|
||||||
from typing import Set
|
|
||||||
from typing import Tuple
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import bmtrain as bmt
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from numpy.typing import NDArray
|
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from ...dataset import DistributedDataset
|
|
||||||
from ..tokenizers import CPM9GTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
class _MixedDatasetConfig(TypedDict):
|
|
||||||
weight: float
|
|
||||||
path: str
|
|
||||||
transforms: Union[List[Dict[str, Any]], str]
|
|
||||||
task_name: str
|
|
||||||
dataset_name: str
|
|
||||||
incontext_weight: List[float]
|
|
||||||
lines: int
|
|
||||||
dataset: DistributedDataset
|
|
||||||
|
|
||||||
|
|
||||||
CPM9GInputType = Union[str, Dict[str, "CPM9GInputType"]]
|
|
||||||
|
|
||||||
|
|
||||||
class _TransformFuncDict(TypedDict):
|
|
||||||
loader: importlib.machinery.SourceFileLoader
|
|
||||||
module: types.ModuleType
|
|
||||||
last_m: float
|
|
||||||
|
|
||||||
|
|
||||||
_TransformFunction = Callable[[CPM9GInputType, int, random.Random], CPM9GInputType]
|
|
||||||
|
|
||||||
|
|
||||||
class CPM9GBatch(TypedDict):
|
|
||||||
inputs: NDArray[np.int32]
|
|
||||||
length: NDArray[np.int32]
|
|
||||||
context: NDArray[np.bool_]
|
|
||||||
sample_ids: NDArray[np.int32]
|
|
||||||
spans: NDArray[np.int32]
|
|
||||||
target: NDArray[np.int32]
|
|
||||||
task_ids: NDArray[np.int32]
|
|
||||||
task_names: List[str]
|
|
||||||
raw_data: List[Any]
|
|
||||||
|
|
||||||
|
|
||||||
def convert_data_to_id(tokenizer: CPM9GTokenizer, data: Any):
|
|
||||||
input_ids = tokenizer.encode(data["input"])
|
|
||||||
output_ids = tokenizer.encode(data["output"])
|
|
||||||
ids = [tokenizer.bos_id] + input_ids + output_ids + [tokenizer.eos_id]
|
|
||||||
ids = np.array(ids, dtype=np.int32)
|
|
||||||
context = np.zeros((ids.shape[0],), dtype=np.int8)
|
|
||||||
context[: len(input_ids) + 1] = 1
|
|
||||||
return ids, context
|
|
||||||
|
|
||||||
|
|
||||||
def _dataset_identity(c: _MixedDatasetConfig):
|
|
||||||
return "{}.{}".format(c["task_name"], c["dataset_name"])
|
|
||||||
|
|
||||||
|
|
||||||
class _MixedDatasetBatchPacker:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
batch_size: int,
|
|
||||||
max_length: int,
|
|
||||||
tokenizer: CPM9GTokenizer,
|
|
||||||
unpad: bool = False,
|
|
||||||
) -> None:
|
|
||||||
self._batch_size = batch_size
|
|
||||||
self._max_length = max_length
|
|
||||||
self._clip_length = max_length
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self._transform_func_table: Dict[str, _TransformFuncDict] = {}
|
|
||||||
self._unpad = unpad
|
|
||||||
if unpad:
|
|
||||||
self._max_length = max_length * batch_size
|
|
||||||
self._batch_size = 1
|
|
||||||
|
|
||||||
self._inputs: List[NDArray[np.int32]] = []
|
|
||||||
self._context: List[NDArray[np.int8]] = []
|
|
||||||
self._sample_ids: List[NDArray[np.int32]] = []
|
|
||||||
self._spans: List[List[int]] = []
|
|
||||||
self._task_ids: List[List[str]] = []
|
|
||||||
self._raw_data: List[List[Any]] = []
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self._inputs)
|
|
||||||
|
|
||||||
def apply_transform(
|
|
||||||
self,
|
|
||||||
data: CPM9GInputType,
|
|
||||||
transform: Union[Dict[str, Any], Callable[[CPM9GInputType], CPM9GInputType], None],
|
|
||||||
) -> CPM9GInputType:
|
|
||||||
if transform is None:
|
|
||||||
return data
|
|
||||||
if not isinstance(transform, dict):
|
|
||||||
return transform(data) # transform function
|
|
||||||
|
|
||||||
mapping_list: List[Tuple[str, str]] = []
|
|
||||||
|
|
||||||
def _walk_transform_dict(data: Union[Dict[str, Any], str], prefix: str = ""):
|
|
||||||
if isinstance(data, dict):
|
|
||||||
for k, v in data.items():
|
|
||||||
if len(prefix) > 0:
|
|
||||||
_walk_transform_dict(v, prefix + "." + k)
|
|
||||||
else:
|
|
||||||
_walk_transform_dict(v, k)
|
|
||||||
else:
|
|
||||||
assert isinstance(data, str), "Invalid transform {}".format(data)
|
|
||||||
mapping_list.append((prefix, data))
|
|
||||||
|
|
||||||
_walk_transform_dict(transform)
|
|
||||||
|
|
||||||
expanded_mapping_list: List[Tuple[str, Any]] = []
|
|
||||||
|
|
||||||
def _expand_mapping(data: CPM9GInputType, stars: List[str], path: List[str], target: List[str]):
|
|
||||||
if len(path) == 0:
|
|
||||||
num_stars = 0
|
|
||||||
for it in target:
|
|
||||||
if it == "*":
|
|
||||||
num_stars += 1
|
|
||||||
if num_stars != len(stars):
|
|
||||||
raise ValueError("Invalid transform {}".format(".".join(target)))
|
|
||||||
|
|
||||||
nw_tgt = []
|
|
||||||
num_stars = 0
|
|
||||||
for it in target:
|
|
||||||
if it == "*":
|
|
||||||
nw_tgt.append(stars[num_stars])
|
|
||||||
num_stars += 1
|
|
||||||
else:
|
|
||||||
nw_tgt.append(it)
|
|
||||||
expanded_mapping_list.append((".".join(nw_tgt), data))
|
|
||||||
else:
|
|
||||||
if not isinstance(data, dict):
|
|
||||||
raise ValueError("Invalid data {}".format(data))
|
|
||||||
if path[0] == "*":
|
|
||||||
for k, v in data.items():
|
|
||||||
_expand_mapping(v, stars + [k], path[1:], target)
|
|
||||||
else:
|
|
||||||
_expand_mapping(data[path[0]], stars, path[1:], target)
|
|
||||||
|
|
||||||
# expand mapping list
|
|
||||||
for tgt, src in mapping_list:
|
|
||||||
if src.startswith("$"):
|
|
||||||
# copy from src
|
|
||||||
_expand_mapping(data, [], src[1:].split("."), tgt.split("."))
|
|
||||||
else:
|
|
||||||
if "*" in tgt:
|
|
||||||
raise ValueError("Constant value is not allowed to have `*` in prefix")
|
|
||||||
expanded_mapping_list.append((tgt, src))
|
|
||||||
|
|
||||||
ret = {}
|
|
||||||
for tgt, val in expanded_mapping_list:
|
|
||||||
tgt = tgt.split(".")
|
|
||||||
cur = ret
|
|
||||||
while len(tgt) > 1:
|
|
||||||
cur = cur[tgt[0]]
|
|
||||||
tgt = tgt[1:]
|
|
||||||
cur[tgt[0]] = val
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def data_to_id(self, data: Any):
|
|
||||||
return convert_data_to_id(self.tokenizer, data)
|
|
||||||
|
|
||||||
def _ensure_transform_function(self, module_name: str, transform_script_path: str) -> _TransformFunction:
|
|
||||||
module_name = "cpm_live.transforms.{}".format(module_name)
|
|
||||||
if transform_script_path not in self._transform_func_table:
|
|
||||||
loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
|
|
||||||
spec = importlib.util.spec_from_loader(loader.name, loader)
|
|
||||||
if spec is None:
|
|
||||||
raise RuntimeError("spec is none! {}".format(module_name))
|
|
||||||
mod = importlib.util.module_from_spec(spec)
|
|
||||||
self._transform_func_table[transform_script_path] = {
|
|
||||||
"loader": loader,
|
|
||||||
"module": mod,
|
|
||||||
"last_m": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
transform_script_info = self._transform_func_table[transform_script_path]
|
|
||||||
curr_m_time = float(transform_script_info["loader"].path_stats(transform_script_path)["mtime"])
|
|
||||||
if curr_m_time > transform_script_info["last_m"]:
|
|
||||||
transform_script_info["last_m"] = curr_m_time
|
|
||||||
transform_script_info["loader"].exec_module(transform_script_info["module"])
|
|
||||||
transform_func = getattr(transform_script_info["module"], "transform", None)
|
|
||||||
if transform_func is None:
|
|
||||||
|
|
||||||
def _empty_transform_func(data: CPM9GInputType, num_sample: int, r: random.Random):
|
|
||||||
raise NotImplementedError("Transform func for dataset {} not implemented".format(module_name))
|
|
||||||
|
|
||||||
return _empty_transform_func
|
|
||||||
else:
|
|
||||||
return transform_func
|
|
||||||
|
|
||||||
def build_instance(self, config: _MixedDatasetConfig):
|
|
||||||
_sample_weight = np.array(config["incontext_weight"], dtype=np.float32)
|
|
||||||
_sample_weight = _sample_weight / _sample_weight.sum()
|
|
||||||
num_incontext = np.random.choice(_sample_weight.shape[0], p=_sample_weight)
|
|
||||||
ds = config["dataset"]
|
|
||||||
transforms = config["transforms"]
|
|
||||||
if isinstance(transforms, str):
|
|
||||||
if not os.path.exists(transforms):
|
|
||||||
raise RuntimeError("transform script {} file not exists".format(transforms))
|
|
||||||
# load transform script
|
|
||||||
transform_func = self._ensure_transform_function(_dataset_identity(config), transforms)
|
|
||||||
seed = random.random()
|
|
||||||
|
|
||||||
def _transform(data: CPM9GInputType):
|
|
||||||
r = random.Random(seed)
|
|
||||||
return transform_func(data, num_incontext, r)
|
|
||||||
|
|
||||||
transform = _transform
|
|
||||||
elif len(transforms) == 0:
|
|
||||||
transform = None
|
|
||||||
else:
|
|
||||||
transform = transforms[np.random.choice(len(transforms))]
|
|
||||||
|
|
||||||
raw_data = {}
|
|
||||||
while True:
|
|
||||||
inp = ds.read()
|
|
||||||
inp = self.apply_transform(inp, transform)
|
|
||||||
# if None, skip this one
|
|
||||||
if inp is None:
|
|
||||||
continue
|
|
||||||
input_ids, context = self.data_to_id(inp)
|
|
||||||
if input_ids.shape[0] > self._clip_length:
|
|
||||||
continue # too long
|
|
||||||
input_ids = input_ids[: self._clip_length]
|
|
||||||
context = context[: self._clip_length]
|
|
||||||
raw_data["input"] = inp
|
|
||||||
raw_data["samples"] = []
|
|
||||||
break
|
|
||||||
|
|
||||||
sample_ids = np.zeros(input_ids.shape, dtype=np.int32)
|
|
||||||
i = 0
|
|
||||||
while True:
|
|
||||||
if i == num_incontext or input_ids.shape[0] >= self._clip_length:
|
|
||||||
break # early break
|
|
||||||
sample = ds.read()
|
|
||||||
sample = self.apply_transform(sample, transform)
|
|
||||||
# if sample is None, skip this one
|
|
||||||
if sample is None:
|
|
||||||
continue
|
|
||||||
sample_input_ids, _ = self.data_to_id(sample)
|
|
||||||
if input_ids.shape[0] + sample_input_ids.shape[0] > self._clip_length:
|
|
||||||
break # too long, break
|
|
||||||
raw_data["samples"].append(sample)
|
|
||||||
input_ids = np.concatenate([input_ids, sample_input_ids], axis=0)
|
|
||||||
context = np.concatenate([context, np.ones(sample_input_ids.shape, dtype=np.int8)], axis=0)
|
|
||||||
sample_ids = np.concatenate([sample_ids, np.full(sample_input_ids.shape, i + 1, dtype=np.int32)], axis=0)
|
|
||||||
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
return (
|
|
||||||
input_ids,
|
|
||||||
context,
|
|
||||||
sample_ids,
|
|
||||||
raw_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
def pack_batch(self, force: bool = False, unilm=False) -> CPM9GBatch:
|
|
||||||
# pack batch
|
|
||||||
if len(self._inputs) < self._batch_size:
|
|
||||||
if not force:
|
|
||||||
raise RuntimeError("Batch insufficient")
|
|
||||||
batch_size = len(self._inputs)
|
|
||||||
else:
|
|
||||||
batch_size = self._batch_size
|
|
||||||
max_length = self._max_length # self._spans[0][-1] if self._unpad else self._max_length
|
|
||||||
inputs = np.zeros((batch_size, max_length), dtype=np.int32)
|
|
||||||
context = np.zeros((batch_size, max_length), dtype=np.int8)
|
|
||||||
sample_ids = np.zeros((batch_size, max_length), dtype=np.int32)
|
|
||||||
tgt = np.full((batch_size, max_length), -100, dtype=np.int32)
|
|
||||||
|
|
||||||
spans = np.zeros((batch_size, max_length), dtype=np.int32)
|
|
||||||
length = np.zeros((batch_size,), dtype=np.int32)
|
|
||||||
task_ids = np.full((batch_size, max_length), -1, dtype=np.int32)
|
|
||||||
if self._spans[0][-1] != max_length:
|
|
||||||
cu_seqlens = np.array([0] + self._spans[0] + [max_length], dtype=np.int32)
|
|
||||||
else:
|
|
||||||
cu_seqlens = np.array([0] + self._spans[0], dtype=np.int32)
|
|
||||||
max_seqlen = int(np.max(cu_seqlens[1:] - cu_seqlens[:-1]))
|
|
||||||
position_ids = np.zeros((batch_size, max_length), dtype=np.int32)
|
|
||||||
|
|
||||||
all_task_names: Set[str] = set()
|
|
||||||
for i in range(batch_size):
|
|
||||||
for task_name in self._task_ids[i]:
|
|
||||||
all_task_names.add(task_name)
|
|
||||||
task_names: List[str] = list(all_task_names)
|
|
||||||
task_name_to_id = {name: i for i, name in enumerate(task_names)}
|
|
||||||
|
|
||||||
raw_data_list: List[Any] = []
|
|
||||||
for i in range(batch_size):
|
|
||||||
instance_length = self._inputs[i].shape[0]
|
|
||||||
inputs[i, :instance_length] = self._inputs[i]
|
|
||||||
sample_ids[i, :instance_length] = self._sample_ids[i]
|
|
||||||
if unilm:
|
|
||||||
context[i, :instance_length] = self._context[i]
|
|
||||||
|
|
||||||
span_begin = 0
|
|
||||||
for span_id, (span_end, task_name) in enumerate(zip(self._spans[i], self._task_ids[i])):
|
|
||||||
spans[i, span_begin:span_end] = span_id
|
|
||||||
position_ids[i, span_begin:span_end] = np.arange(span_end - span_begin)
|
|
||||||
task_ids[i, span_begin:span_end] = task_name_to_id[task_name]
|
|
||||||
span_begin = span_end
|
|
||||||
length[i] = instance_length
|
|
||||||
raw_data_list.extend(self._raw_data[i])
|
|
||||||
|
|
||||||
for j in range(instance_length):
|
|
||||||
if j > 1 and self._context[i][j] == 0:
|
|
||||||
if self._inputs[i][j] != self.tokenizer.bos_id and self._inputs[i][j - 1] != self.tokenizer.eos_id:
|
|
||||||
tgt[i, j - 1] = self._inputs[i][j]
|
|
||||||
|
|
||||||
self._inputs = self._inputs[batch_size:]
|
|
||||||
self._context = self._context[batch_size:]
|
|
||||||
self._sample_ids = self._sample_ids[batch_size:]
|
|
||||||
self._spans = self._spans[batch_size:]
|
|
||||||
self._task_ids = self._task_ids[batch_size:]
|
|
||||||
self._raw_data = self._raw_data[batch_size:]
|
|
||||||
return {
|
|
||||||
"inputs": inputs,
|
|
||||||
"length": length,
|
|
||||||
"context": context > 0,
|
|
||||||
"sample_ids": sample_ids,
|
|
||||||
"spans": spans,
|
|
||||||
"cu_seqlens": cu_seqlens,
|
|
||||||
"max_seqlen": max_seqlen,
|
|
||||||
"position_ids": position_ids,
|
|
||||||
"target": tgt,
|
|
||||||
"task_ids": task_ids,
|
|
||||||
"task_names": task_names,
|
|
||||||
"raw_data": raw_data_list,
|
|
||||||
}
|
|
||||||
|
|
||||||
def add_data(self, config: _MixedDatasetConfig) -> Optional[CPM9GBatch]:
|
|
||||||
(
|
|
||||||
input_ids,
|
|
||||||
context,
|
|
||||||
sample_ids,
|
|
||||||
raw_data,
|
|
||||||
) = self.build_instance(config)
|
|
||||||
|
|
||||||
# add to batch
|
|
||||||
best_fit: Union[None, int] = None
|
|
||||||
best_fit_space: Union[None, int] = None
|
|
||||||
for i in range(len(self._inputs)):
|
|
||||||
space = self._max_length - self._inputs[i].shape[0]
|
|
||||||
if input_ids.shape[0] <= space:
|
|
||||||
if best_fit_space is None:
|
|
||||||
best_fit = i
|
|
||||||
best_fit_space = space
|
|
||||||
elif best_fit_space > space:
|
|
||||||
best_fit = i
|
|
||||||
best_fit_space = space
|
|
||||||
if best_fit is None:
|
|
||||||
# add a new instance
|
|
||||||
self._inputs.append(input_ids)
|
|
||||||
self._context.append(context)
|
|
||||||
self._sample_ids.append(sample_ids)
|
|
||||||
self._spans.append([input_ids.shape[0]])
|
|
||||||
self._task_ids.append([config["task_name"]])
|
|
||||||
self._raw_data.append([raw_data])
|
|
||||||
else:
|
|
||||||
# add to existing instance
|
|
||||||
self._inputs[best_fit] = np.concatenate([self._inputs[best_fit], input_ids], axis=0)
|
|
||||||
self._context[best_fit] = np.concatenate([self._context[best_fit], context], axis=0)
|
|
||||||
self._sample_ids[best_fit] = np.concatenate([self._sample_ids[best_fit], sample_ids], axis=0)
|
|
||||||
self._spans[best_fit].append(self._inputs[best_fit].shape[0])
|
|
||||||
self._task_ids[best_fit].append(config["task_name"])
|
|
||||||
self._raw_data[best_fit].append(raw_data)
|
|
||||||
|
|
||||||
if len(self._inputs) > self._batch_size:
|
|
||||||
return self.pack_batch()
|
|
||||||
else:
|
|
||||||
return None # not ready
|
|
||||||
|
|
||||||
|
|
||||||
class _MixedDatasetConfigMananger:
|
|
||||||
def __init__(self, config_path: str) -> None:
|
|
||||||
self._config_path: str = config_path
|
|
||||||
self._config: Union[List[_MixedDatasetConfig], None] = None
|
|
||||||
self._last_m = 0
|
|
||||||
|
|
||||||
def changed(self):
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
m_time = os.stat(self._config_path).st_mtime
|
|
||||||
if m_time > self._last_m:
|
|
||||||
# try to load new config
|
|
||||||
try:
|
|
||||||
self._config = json.load(open(self._config_path, "r", encoding="utf-8"))
|
|
||||||
except Exception:
|
|
||||||
# failed to load config
|
|
||||||
return False
|
|
||||||
# new config loaded
|
|
||||||
self._last_m = m_time
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
except Exception:
|
|
||||||
print("Error: reading info list! {}".format(self._config_path))
|
|
||||||
time.sleep(30)
|
|
||||||
|
|
||||||
def get_config(self) -> List[_MixedDatasetConfig]:
|
|
||||||
if self._config is None:
|
|
||||||
if not self.changed():
|
|
||||||
raise RuntimeError("Failed to load config")
|
|
||||||
if self._config is None:
|
|
||||||
raise RuntimeError("Failed to load config")
|
|
||||||
return self._config
|
|
||||||
|
|
||||||
|
|
||||||
def _mixed_dataset_process(
|
|
||||||
config_path: str,
|
|
||||||
q_cmd: multiprocessing.Queue,
|
|
||||||
q_cmd_out: multiprocessing.Queue,
|
|
||||||
q_data: multiprocessing.Queue,
|
|
||||||
rank: int,
|
|
||||||
world_size: int,
|
|
||||||
packer: _MixedDatasetBatchPacker,
|
|
||||||
max_repeat_times: int = None,
|
|
||||||
random_state=None,
|
|
||||||
):
|
|
||||||
for _ in range(rank):
|
|
||||||
np.random.random()
|
|
||||||
random.setstate(random_state)
|
|
||||||
# ignore SIGINT
|
|
||||||
import signal
|
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
|
||||||
config_base_path = os.path.dirname(os.path.abspath(config_path))
|
|
||||||
|
|
||||||
# fix libgomp threading deadlock after fork(), use single-thread in data process.
|
|
||||||
# REF: https://github.com/pytorch/pytorch/issues/17199
|
|
||||||
torch.set_num_threads(1)
|
|
||||||
|
|
||||||
def _convert_to_abs_path(transform_path: str):
|
|
||||||
if transform_path.startswith("/"):
|
|
||||||
return transform_path
|
|
||||||
else:
|
|
||||||
return os.path.join(config_base_path, transform_path)
|
|
||||||
|
|
||||||
def _build_sample_weights(config: List[_MixedDatasetConfig]):
|
|
||||||
if len(config) == 0:
|
|
||||||
return np.array([], dtype=np.float32)
|
|
||||||
weights = [c["weight"] * c["lines"] for c in config]
|
|
||||||
weights = np.array(weights, dtype=np.float32)
|
|
||||||
sm_weight = weights.sum()
|
|
||||||
if sm_weight > 0:
|
|
||||||
weights = weights / sm_weight
|
|
||||||
return weights
|
|
||||||
else:
|
|
||||||
raise RuntimeError("Empty datasets")
|
|
||||||
|
|
||||||
cfg_mgr = _MixedDatasetConfigMananger(config_path)
|
|
||||||
config = cfg_mgr.get_config()
|
|
||||||
|
|
||||||
for c in config:
|
|
||||||
ds = DistributedDataset(
|
|
||||||
_convert_to_abs_path(c["path"]),
|
|
||||||
rank,
|
|
||||||
world_size,
|
|
||||||
max_repeat_times=max_repeat_times,
|
|
||||||
)
|
|
||||||
|
|
||||||
c["lines"] = ds._nlines
|
|
||||||
c["dataset"] = ds
|
|
||||||
if "weight" not in c:
|
|
||||||
c["weight"] = 1.0
|
|
||||||
if "transforms" not in c:
|
|
||||||
c["transforms"] = []
|
|
||||||
elif isinstance(c["transforms"], str):
|
|
||||||
c["transforms"] = _convert_to_abs_path(c["transforms"])
|
|
||||||
if "incontext_weight" not in c:
|
|
||||||
c["incontext_weight"] = [1.0]
|
|
||||||
|
|
||||||
weights = _build_sample_weights(config)
|
|
||||||
|
|
||||||
should_stop = False
|
|
||||||
should_start = False
|
|
||||||
|
|
||||||
while not should_stop:
|
|
||||||
# update config first
|
|
||||||
if cfg_mgr.changed():
|
|
||||||
path_ds_map: Dict[str, _MixedDatasetConfig] = {}
|
|
||||||
nw_path_set: Set[str] = set()
|
|
||||||
|
|
||||||
# load new config
|
|
||||||
nw_config = cfg_mgr.get_config()
|
|
||||||
|
|
||||||
# build path -> dataset map
|
|
||||||
for c in config:
|
|
||||||
path_ds_map[_dataset_identity(c)] = c
|
|
||||||
|
|
||||||
# add new datasets
|
|
||||||
for c in nw_config:
|
|
||||||
if _dataset_identity(c) in path_ds_map:
|
|
||||||
# update values only
|
|
||||||
if "weight" in c:
|
|
||||||
path_ds_map[_dataset_identity(c)]["weight"] = c["weight"]
|
|
||||||
if "transform" in c:
|
|
||||||
if isinstance(c["transforms"], str):
|
|
||||||
path_ds_map[_dataset_identity(c)]["transforms"] = _convert_to_abs_path(c["transforms"])
|
|
||||||
else:
|
|
||||||
path_ds_map[_dataset_identity(c)]["transforms"] = c["transforms"]
|
|
||||||
if "incontext_weight" in c:
|
|
||||||
path_ds_map[_dataset_identity(c)]["incontext_weight"] = c["incontext_weight"]
|
|
||||||
else:
|
|
||||||
# new dataset
|
|
||||||
ds = DistributedDataset(
|
|
||||||
_convert_to_abs_path(c["path"]),
|
|
||||||
rank,
|
|
||||||
world_size,
|
|
||||||
max_repeat_times=max_repeat_times,
|
|
||||||
)
|
|
||||||
c["lines"] = ds._nlines
|
|
||||||
c["dataset"] = ds
|
|
||||||
if "weight" not in c:
|
|
||||||
c["weight"] = 1.0
|
|
||||||
if "transforms" not in c:
|
|
||||||
c["transforms"] = []
|
|
||||||
elif isinstance(c["transforms"], str):
|
|
||||||
c["transforms"] = _convert_to_abs_path(c["transforms"])
|
|
||||||
if "incontext_weight" not in c:
|
|
||||||
c["incontext_weight"] = [1.0]
|
|
||||||
path_ds_map[_dataset_identity(c)] = c
|
|
||||||
nw_path_set.add(_dataset_identity(c))
|
|
||||||
|
|
||||||
# remove unused datasets
|
|
||||||
for c in config:
|
|
||||||
if _dataset_identity(c) not in nw_path_set:
|
|
||||||
del path_ds_map[_dataset_identity(c)]
|
|
||||||
|
|
||||||
config: List[_MixedDatasetConfig] = []
|
|
||||||
for c in nw_config:
|
|
||||||
config.append(path_ds_map[_dataset_identity(c)])
|
|
||||||
del path_ds_map
|
|
||||||
del nw_path_set
|
|
||||||
del nw_config
|
|
||||||
|
|
||||||
weights = _build_sample_weights(config)
|
|
||||||
|
|
||||||
# get cmds
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
cmd = q_cmd.get_nowait()
|
|
||||||
except Empty:
|
|
||||||
break
|
|
||||||
if cmd == "stop":
|
|
||||||
should_stop = True
|
|
||||||
q_cmd_out.put(True)
|
|
||||||
break
|
|
||||||
elif cmd == "state_dict":
|
|
||||||
ret = OrderedDict()
|
|
||||||
for c in config:
|
|
||||||
ds_name = _dataset_identity(c)
|
|
||||||
ret[ds_name] = c["dataset"]._state_dict()
|
|
||||||
q_cmd_out.put(ret)
|
|
||||||
elif cmd == "load_state_dict":
|
|
||||||
state_dict = q_cmd.get()
|
|
||||||
missing = []
|
|
||||||
for idx, c in enumerate(config):
|
|
||||||
ds_name = _dataset_identity(c)
|
|
||||||
print(f"loading {idx}/{len(config)} {ds_name}")
|
|
||||||
if ds_name in state_dict:
|
|
||||||
c["dataset"].load_state_dict(state_dict[ds_name], strict=False)
|
|
||||||
else:
|
|
||||||
# new dataset
|
|
||||||
missing.append(ds_name)
|
|
||||||
q_cmd_out.put(missing)
|
|
||||||
elif cmd == "start":
|
|
||||||
should_start = True
|
|
||||||
q_cmd_out.put(True)
|
|
||||||
else:
|
|
||||||
raise RuntimeError("Unknown command: {}".format(cmd))
|
|
||||||
|
|
||||||
if should_stop:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not should_start:
|
|
||||||
# wait for start cmd
|
|
||||||
time.sleep(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if len(config) == 0:
|
|
||||||
# no dataset available
|
|
||||||
time.sleep(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if q_data.full():
|
|
||||||
# queue full
|
|
||||||
time.sleep(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# sample a dataset
|
|
||||||
ds_id: int = 0
|
|
||||||
|
|
||||||
while True:
|
|
||||||
ds_id = np.random.choice(weights.shape[0], p=weights)
|
|
||||||
if config[ds_id]["dataset"]._nlines != config[ds_id]["lines"]:
|
|
||||||
# dataset size changed
|
|
||||||
for c in config:
|
|
||||||
c["lines"] = c["dataset"]._nlines
|
|
||||||
weights = _build_sample_weights(config)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
batch = packer.add_data(config[ds_id])
|
|
||||||
if batch is not None:
|
|
||||||
# new batch comming
|
|
||||||
q_data.put(batch)
|
|
||||||
|
|
||||||
# clean queue
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
q_data.get_nowait()
|
|
||||||
except Empty:
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
class MixedDataset:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config_path: str,
|
|
||||||
batch_size: int,
|
|
||||||
max_length: int,
|
|
||||||
tokenizer: CPM9GTokenizer,
|
|
||||||
unpad: bool = False,
|
|
||||||
max_repeat_times: int = None,
|
|
||||||
) -> None:
|
|
||||||
self._q_cmd = multiprocessing.Queue()
|
|
||||||
self._q_cmd_out = multiprocessing.Queue()
|
|
||||||
self._q_data = multiprocessing.Queue(maxsize=2)
|
|
||||||
self._packer = _MixedDatasetBatchPacker(batch_size, max_length, tokenizer, unpad)
|
|
||||||
self._p = multiprocessing.Process(
|
|
||||||
target=_mixed_dataset_process,
|
|
||||||
args=(
|
|
||||||
config_path,
|
|
||||||
self._q_cmd,
|
|
||||||
self._q_cmd_out,
|
|
||||||
self._q_data,
|
|
||||||
bmt.rank(),
|
|
||||||
bmt.world_size(),
|
|
||||||
self._packer,
|
|
||||||
max_repeat_times,
|
|
||||||
random.getstate(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self._p.start()
|
|
||||||
self._closed = False
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
if not self._closed:
|
|
||||||
self._closed = True
|
|
||||||
self._q_cmd.put("stop")
|
|
||||||
assert self._q_cmd_out.get(), "Failed to stop process"
|
|
||||||
self._p.join()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def closed(self):
|
|
||||||
return self._closed
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
self._q_cmd.put("start")
|
|
||||||
return self._q_cmd_out.get()
|
|
||||||
|
|
||||||
def state_dict(self):
|
|
||||||
self._q_cmd.put("state_dict")
|
|
||||||
states = self._q_cmd_out.get()
|
|
||||||
if not isinstance(states, OrderedDict):
|
|
||||||
raise RuntimeError("Invalid state dict {}".format(states))
|
|
||||||
if bmt.world_size() == 1:
|
|
||||||
for val in states.values():
|
|
||||||
val["states"].unsqueeze_(0)
|
|
||||||
val["block"].unsqueeze_(0)
|
|
||||||
return states
|
|
||||||
|
|
||||||
ret = OrderedDict()
|
|
||||||
for k, v in states.items():
|
|
||||||
num_unused_block = v["states"].size(0)
|
|
||||||
gpu_num_unused_block = torch.tensor([num_unused_block], dtype=torch.long).cuda()
|
|
||||||
max_unused_blocks = bmt.distributed.all_reduce(gpu_num_unused_block, op="max").cpu().item()
|
|
||||||
if max_unused_blocks == 0:
|
|
||||||
max_unused_blocks = 1
|
|
||||||
gpu_states = torch.full((max_unused_blocks,), -1, dtype=torch.long).cuda()
|
|
||||||
gpu_states[:num_unused_block] = v["states"].cuda()
|
|
||||||
|
|
||||||
gpu_block = v["block"].cuda()
|
|
||||||
global_states = bmt.distributed.all_gather(gpu_states).cpu() # (world_size, max_unused_blocks)
|
|
||||||
global_block = bmt.distributed.all_gather(gpu_block).cpu() # (world_size, 4)
|
|
||||||
ret[k] = {"states": global_states, "block": global_block}
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def load_state_dict(self, data: OrderedDict, strict: bool = False):
|
|
||||||
self._q_cmd.put("load_state_dict")
|
|
||||||
self._q_cmd.put(data)
|
|
||||||
missing = self._q_cmd_out.get()
|
|
||||||
if strict:
|
|
||||||
if len(missing) > 0:
|
|
||||||
raise RuntimeError("Missing dataset state: {}".format(missing))
|
|
||||||
return missing
|
|
||||||
|
|
||||||
def get(self) -> CPM9GBatch:
|
|
||||||
ret: CPM9GBatch = self._q_data.get(timeout=300) # type: ignore
|
|
||||||
if not isinstance(ret, dict):
|
|
||||||
raise RuntimeError("Invalid data {}".format(ret))
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
while True:
|
|
||||||
yield self.get()
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
if not self.closed:
|
|
||||||
try:
|
|
||||||
self.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,57 +0,0 @@
|
||||||
import itertools
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import bmtrain as bmt
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class ListDataset(torch.utils.data.Dataset):
|
|
||||||
"""
|
|
||||||
同时支持 map-style 和 iterable-style
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, data_list: List, distributed: bool = False, shuffle: bool = True, infinite: bool = False
|
|
||||||
) -> None:
|
|
||||||
super(ListDataset, self).__init__()
|
|
||||||
if distributed:
|
|
||||||
rank = bmt.rank()
|
|
||||||
world_size = bmt.world_size()
|
|
||||||
self.data_list = list(itertools.islice(data_list, rank, None, world_size))
|
|
||||||
else:
|
|
||||||
self.data_list = data_list
|
|
||||||
self.shuffle = shuffle
|
|
||||||
self.infinite = infinite
|
|
||||||
self.idx = 0
|
|
||||||
|
|
||||||
if shuffle:
|
|
||||||
self._shuffle()
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __next__(self):
|
|
||||||
if self.idx >= len(self):
|
|
||||||
if self.infinite:
|
|
||||||
if self.shuffle:
|
|
||||||
self._shuffle()
|
|
||||||
self.idx = 0
|
|
||||||
else:
|
|
||||||
raise StopIteration
|
|
||||||
data = self.data_list[self.idx]
|
|
||||||
self.idx += 1
|
|
||||||
return data
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
return self.data_list[idx]
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.data_list)
|
|
||||||
|
|
||||||
def _shuffle(self):
|
|
||||||
random.shuffle(self.data_list)
|
|
||||||
|
|
||||||
def read(self):
|
|
||||||
return self.__next__()
|
|
|
@ -1,4 +0,0 @@
|
||||||
from .generation_utils import apply_repetition_penalty
|
|
||||||
from .generation_utils import BeamHypotheses
|
|
||||||
from .generation_utils import NBCE
|
|
||||||
from .generation_utils import top_k_top_p_filtering
|
|
|
@ -1,127 +0,0 @@
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
|
|
||||||
# This function has been mostly taken from huggingface conversational ai code at
|
|
||||||
# https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
|
|
||||||
|
|
||||||
if top_k > 0:
|
|
||||||
# Remove all tokens with a probability less than the last token of the top-k
|
|
||||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
|
||||||
logits[indices_to_remove] = filter_value
|
|
||||||
|
|
||||||
batch_size = logits.size()[0]
|
|
||||||
if top_p > 0.0:
|
|
||||||
logits = logits.view(batch_size, -1).contiguous()
|
|
||||||
for index in range(len(logits)):
|
|
||||||
|
|
||||||
sorted_logits, sorted_indices = torch.sort(logits[index].view(-1), descending=True)
|
|
||||||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
|
||||||
# Remove tokens with cumulative probability above the threshold
|
|
||||||
sorted_indices_to_remove = cumulative_probs > top_p
|
|
||||||
# Shift the indices to the right to keep also the first token above the threshold
|
|
||||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
|
||||||
sorted_indices_to_remove[..., 0] = 0
|
|
||||||
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
|
||||||
logits[index][indices_to_remove] = filter_value
|
|
||||||
|
|
||||||
logits = logits.view(batch_size, -1).contiguous()
|
|
||||||
|
|
||||||
return logits
|
|
||||||
|
|
||||||
|
|
||||||
def apply_repetition_penalty(
|
|
||||||
logits,
|
|
||||||
batch_size,
|
|
||||||
num_beams,
|
|
||||||
prev_output_tokens,
|
|
||||||
repetition_penalty,
|
|
||||||
start_idx=None,
|
|
||||||
end_idx=None,
|
|
||||||
window_size=None,
|
|
||||||
):
|
|
||||||
# only conduct repetition penalty for the output
|
|
||||||
assert repetition_penalty >= 1, "repetition penalty coefficient should >= 1"
|
|
||||||
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
|
||||||
for i in range(batch_size * num_beams):
|
|
||||||
if start_idx is None or end_idx is None:
|
|
||||||
output_tokens = prev_output_tokens[i].tolist()
|
|
||||||
else:
|
|
||||||
if end_idx >= start_idx:
|
|
||||||
if window_size:
|
|
||||||
output_tokens = prev_output_tokens[i][
|
|
||||||
max(start_idx, end_idx + 1 - window_size) : end_idx + 1
|
|
||||||
].tolist()
|
|
||||||
else:
|
|
||||||
output_tokens = prev_output_tokens[i][start_idx : end_idx + 1].tolist()
|
|
||||||
else:
|
|
||||||
output_tokens = []
|
|
||||||
for previous_token in set(output_tokens):
|
|
||||||
# if score < 0 then repetition penalty has to
|
|
||||||
# multiplied to reduce the previous token probability
|
|
||||||
if logits[i, previous_token] < 0:
|
|
||||||
logits[i, previous_token] *= repetition_penalty
|
|
||||||
else:
|
|
||||||
logits[i, previous_token] /= repetition_penalty
|
|
||||||
|
|
||||||
|
|
||||||
class BeamHypotheses:
|
|
||||||
def __init__(self, n_hyp, max_len, length_penalty, early_stopping):
|
|
||||||
"""
|
|
||||||
Initialize n-best list of hypotheses.
|
|
||||||
"""
|
|
||||||
self.max_len = max_len
|
|
||||||
self.length_penalty = length_penalty
|
|
||||||
self.early_stopping = early_stopping
|
|
||||||
self.n_hyp = n_hyp
|
|
||||||
self.hyp = []
|
|
||||||
self.worst_score = 1e9
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
"""
|
|
||||||
Number of hypotheses in the list.
|
|
||||||
"""
|
|
||||||
return len(self.hyp)
|
|
||||||
|
|
||||||
def add(self, hyp, sum_logprobs):
|
|
||||||
"""
|
|
||||||
Add a new hypothesis to the list.
|
|
||||||
"""
|
|
||||||
score = sum_logprobs / len(hyp) ** self.length_penalty
|
|
||||||
|
|
||||||
if len(self) < self.n_hyp or score > self.worst_score:
|
|
||||||
self.hyp.append((score, hyp))
|
|
||||||
if len(self) > self.n_hyp:
|
|
||||||
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
|
|
||||||
del self.hyp[sorted_scores[0][1]]
|
|
||||||
self.worst_score = sorted_scores[1][0]
|
|
||||||
else:
|
|
||||||
self.worst_score = min(score, self.worst_score)
|
|
||||||
|
|
||||||
def is_done(self, best_sum_logprobs, cur_len):
|
|
||||||
"""
|
|
||||||
If there are enough hypotheses and that none of the hypotheses being generated
|
|
||||||
can become better than the worst one in the heap, then we are done with this sentence.
|
|
||||||
"""
|
|
||||||
if len(self) < self.n_hyp:
|
|
||||||
return False
|
|
||||||
elif self.early_stopping:
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return self.worst_score >= best_sum_logprobs / cur_len**self.length_penalty
|
|
||||||
|
|
||||||
|
|
||||||
def NBCE(logits):
|
|
||||||
"""
|
|
||||||
Naive Bayes-based Context Extension
|
|
||||||
blog: https://www.kexue.fm/archives/9617
|
|
||||||
"""
|
|
||||||
beta = 0.25
|
|
||||||
logits = logits[:, -1] # bsh -> bh
|
|
||||||
logits = logits - logits.logsumexp(dim=-1, keepdims=True)
|
|
||||||
k = (logits.exp() * logits).sum(dim=-1)[1:].argmax() + 1
|
|
||||||
logits_max = logits[k]
|
|
||||||
logits_uncond = logits[0]
|
|
||||||
logits = (1 + beta) * logits_max - beta * logits_uncond
|
|
||||||
return logits
|
|
|
@ -1,12 +0,0 @@
|
||||||
from .attention import Attention
|
|
||||||
from .blocks import TransformerBlock
|
|
||||||
from .embedding import Embedding
|
|
||||||
from .embedding import EmbeddingExt
|
|
||||||
from .feedforward import FeedForward
|
|
||||||
from .layernorm import LayerNorm
|
|
||||||
from .linear import Linear
|
|
||||||
from .position_embedding import BucketPositionBias
|
|
||||||
from .position_embedding import RotaryEmbedding
|
|
||||||
from .position_embedding import RotaryEmbeddingESM
|
|
||||||
from .position_embedding import SegmentPositionEmbedding
|
|
||||||
from .transformer import Encoder
|
|
|
@ -1,134 +0,0 @@
|
||||||
import math
|
|
||||||
from typing import Optional
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from einops import rearrange
|
|
||||||
|
|
||||||
from .linear import Linear
|
|
||||||
|
|
||||||
|
|
||||||
class Attention(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim_model: int,
|
|
||||||
num_heads: int,
|
|
||||||
num_kv_heads: int,
|
|
||||||
dim_head: int,
|
|
||||||
dtype: torch.dtype = torch.half,
|
|
||||||
dropout_p: Optional[float] = None,
|
|
||||||
use_flash_attn: bool = False,
|
|
||||||
scale: bool = True,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.dim_model = dim_model
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.dim_head = dim_head
|
|
||||||
self.num_kv_heads = num_kv_heads
|
|
||||||
self.head_groups = num_heads // num_kv_heads
|
|
||||||
|
|
||||||
self.project_q = Linear(self.dim_model, self.num_heads * self.dim_head, dtype=dtype, scale=scale)
|
|
||||||
self.project_k = Linear(self.dim_model, self.num_kv_heads * self.dim_head, dtype=dtype, scale=scale)
|
|
||||||
self.project_v = Linear(self.dim_model, self.num_kv_heads * self.dim_head, dtype=dtype, scale=scale)
|
|
||||||
|
|
||||||
self.attention_out = Linear(self.num_heads * self.dim_head, self.dim_model, dtype=dtype, scale=scale)
|
|
||||||
|
|
||||||
self.softmax = torch.nn.Softmax(dim=-1)
|
|
||||||
|
|
||||||
if dropout_p is not None:
|
|
||||||
self.dropout = torch.nn.Dropout(p=dropout_p)
|
|
||||||
else:
|
|
||||||
self.dropout = None
|
|
||||||
|
|
||||||
# if use_flash_attn:
|
|
||||||
# self.core_attention_flash = FlashSelfAttention(causal=False, attention_dropout=0.0)
|
|
||||||
# self.use_flash_attn = use_flash_attn
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_q: torch.Tensor,
|
|
||||||
hidden_kv: torch.Tensor,
|
|
||||||
attention_mask: torch.BoolTensor,
|
|
||||||
position_bias: torch.Tensor,
|
|
||||||
use_cache: bool = False,
|
|
||||||
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
|
||||||
pos_bias_type: Optional[str] = "relative",
|
|
||||||
length_mask: Optional[torch.Tensor] = None,
|
|
||||||
context_mask: Optional[torch.Tensor] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
hidden_q (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): Indices of input sequence tokens. It will be embedded by model's internal embedding lookup matrix.
|
|
||||||
hidden_kv (:obj:`torch.Tensor` of shape ``(batch, len_k, dim_model)``): Length of input sequence before padding.
|
|
||||||
attention_mask (:obj:`torch.Tensor` of shape ``(batch, len_q, len_k)``): Used to avoid performing attention on padding token indices.
|
|
||||||
position_bias(:obj:`torch.Tensor` of shape ``(num_heads, len_q, len_k)`` or ``(1, num_heads, len_k, len_q)``): Provide positional information about tensor `key_value` and `query`.
|
|
||||||
Return:
|
|
||||||
out (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): The attention output.
|
|
||||||
""" # noqa: E501
|
|
||||||
|
|
||||||
batch_size = hidden_q.size(0)
|
|
||||||
len_q = hidden_q.size(1)
|
|
||||||
len_k = hidden_kv.size(1)
|
|
||||||
|
|
||||||
h_q = self.project_q(hidden_q)
|
|
||||||
h_k = self.project_k(hidden_kv)
|
|
||||||
h_v = self.project_v(hidden_kv)
|
|
||||||
|
|
||||||
h_q = h_q / math.sqrt(math.sqrt(self.dim_head))
|
|
||||||
h_k = h_k / math.sqrt(math.sqrt(self.dim_head))
|
|
||||||
|
|
||||||
h_q = h_q.view(batch_size, len_q, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
|
|
||||||
h_k = h_k.view(batch_size, len_k, self.num_kv_heads, self.dim_head).permute(0, 2, 1, 3)
|
|
||||||
h_v = h_v.view(batch_size, len_k, self.num_kv_heads, self.dim_head).permute(0, 2, 1, 3)
|
|
||||||
|
|
||||||
if pos_bias_type == "rotary":
|
|
||||||
# b h s d
|
|
||||||
h_q, h_k = position_bias(h_q, h_k, -2, offset=past_kv[0].size(-2) if past_kv is not None else 0)
|
|
||||||
|
|
||||||
if past_kv is not None:
|
|
||||||
h_k = torch.cat([past_kv[0], h_k], dim=-2)
|
|
||||||
h_v = torch.cat([past_kv[1], h_v], dim=-2)
|
|
||||||
len_k = h_k.size(-2)
|
|
||||||
|
|
||||||
# (b, n_h, len_q, d_h) @ (b, n_h, d_h, len_k) -> (b, n_h, len_q, len_k)
|
|
||||||
if self.head_groups == 1:
|
|
||||||
score = torch.matmul(h_q, h_k.transpose(-1, -2)) # / math.sqrt(self.dim_head) moved to line 75~76
|
|
||||||
else:
|
|
||||||
score = torch.matmul(
|
|
||||||
h_q.reshape(batch_size, self.num_kv_heads, self.head_groups * len_q, self.dim_head),
|
|
||||||
h_k.transpose(-1, -2),
|
|
||||||
).view(batch_size, self.num_heads, len_q, len_k)
|
|
||||||
|
|
||||||
if pos_bias_type == "relative":
|
|
||||||
score = score + position_bias
|
|
||||||
score = torch.masked_fill(
|
|
||||||
score,
|
|
||||||
attention_mask.view(batch_size, 1, len_q, len_k) == False,
|
|
||||||
torch.scalar_tensor(float("-inf"), device=score.device, dtype=score.dtype),
|
|
||||||
)
|
|
||||||
|
|
||||||
score = self.softmax(score)
|
|
||||||
|
|
||||||
score = torch.masked_fill(
|
|
||||||
score,
|
|
||||||
attention_mask.view(batch_size, 1, len_q, len_k) == False,
|
|
||||||
torch.scalar_tensor(0, device=score.device, dtype=score.dtype),
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.dropout is not None:
|
|
||||||
score = self.dropout(score)
|
|
||||||
|
|
||||||
# (b, n_kv_h, n_h_groups*len_q, len_k) @ (b, n_kv_h, len_k, d_h) -> (b, n_kv_h, n_h_groups*len_q, d_h) -> (b, n_h, len_q, d_h)
|
|
||||||
score = torch.matmul(score.view(batch_size, self.num_kv_heads, self.head_groups * len_q, len_k), h_v).view(
|
|
||||||
batch_size, self.num_heads, len_q, self.dim_head
|
|
||||||
)
|
|
||||||
|
|
||||||
score = score.view(batch_size, self.num_heads, len_q, self.dim_head).permute(0, 2, 1, 3)
|
|
||||||
score = score.contiguous().view(batch_size, len_q, self.num_heads * self.dim_head)
|
|
||||||
|
|
||||||
score = self.attention_out(score)
|
|
||||||
if use_cache:
|
|
||||||
return score, (h_k, h_v)
|
|
||||||
else:
|
|
||||||
return score
|
|
|
@ -1,279 +0,0 @@
|
||||||
from typing import Optional
|
|
||||||
from typing import Tuple
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from .attention import Attention
|
|
||||||
from .feedforward import FeedForward
|
|
||||||
from .layernorm import LayerNorm
|
|
||||||
from .position_embedding import RotaryEmbedding
|
|
||||||
from .position_embedding import RotaryEmbeddingESM
|
|
||||||
|
|
||||||
|
|
||||||
class SelfAttentionBlock(torch.nn.Module):
|
|
||||||
"""The whole cross-attention block. A sequence of operation. Consists of layernorm, self-attention and residual connection.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dim_model (int): main dimension of modules in transformer blocks.
|
|
||||||
num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
|
|
||||||
dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
|
|
||||||
dtype (optional): Defaults to torch.half.
|
|
||||||
eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5.
|
|
||||||
dropout_p (float, optional): Defaults to 0.
|
|
||||||
""" # noqa: E501
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim_model: int,
|
|
||||||
num_heads: int,
|
|
||||||
dim_head: int,
|
|
||||||
num_kv_heads: int,
|
|
||||||
dtype=torch.half,
|
|
||||||
eps: float = 1e-6,
|
|
||||||
dropout_p: Optional[float] = None,
|
|
||||||
scale: bool = True,
|
|
||||||
use_flash_attn: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.layernorm_before_attention = LayerNorm(
|
|
||||||
dim_model,
|
|
||||||
dtype=dtype,
|
|
||||||
eps=eps,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.self_attention = Attention(
|
|
||||||
dim_model=dim_model,
|
|
||||||
num_heads=num_heads,
|
|
||||||
num_kv_heads=num_kv_heads,
|
|
||||||
dim_head=dim_head,
|
|
||||||
dtype=dtype,
|
|
||||||
dropout_p=dropout_p,
|
|
||||||
scale=scale,
|
|
||||||
use_flash_attn=use_flash_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
if dropout_p:
|
|
||||||
self.dropout = torch.nn.Dropout(dropout_p)
|
|
||||||
else:
|
|
||||||
self.dropout = None
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: torch.Tensor,
|
|
||||||
position_bias: Union[torch.Tensor, RotaryEmbedding, RotaryEmbeddingESM] = None,
|
|
||||||
use_cache: bool = False,
|
|
||||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
|
||||||
pos_bias_type: Optional[str] = "relative",
|
|
||||||
length_mask: Optional[torch.Tensor] = None,
|
|
||||||
context_mask: Optional[torch.Tensor] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Input of self-attention block. It can be the embedding of a batch of sequences.
|
|
||||||
attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_self, seq_self)``): Avoid invalid areas to participate in the calculation.
|
|
||||||
position_bias (:obj:`torch.Tensor` of shape ``(num_heads, seq_self, seq_self)``): Provide positional information to self-attention block.
|
|
||||||
|
|
||||||
Return:
|
|
||||||
:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of attention block.
|
|
||||||
|
|
||||||
""" # noqa: E501
|
|
||||||
x = self.layernorm_before_attention(hidden_states)
|
|
||||||
x = self.self_attention(
|
|
||||||
x,
|
|
||||||
x,
|
|
||||||
attention_mask,
|
|
||||||
position_bias,
|
|
||||||
use_cache,
|
|
||||||
past_key_value,
|
|
||||||
pos_bias_type=pos_bias_type,
|
|
||||||
length_mask=length_mask,
|
|
||||||
context_mask=context_mask,
|
|
||||||
)
|
|
||||||
if use_cache:
|
|
||||||
x, current_key_value = x
|
|
||||||
else:
|
|
||||||
current_key_value = None
|
|
||||||
|
|
||||||
if self.dropout is not None:
|
|
||||||
x = self.dropout(x)
|
|
||||||
|
|
||||||
hidden_states = hidden_states + x
|
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
return hidden_states, current_key_value
|
|
||||||
else:
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class FFNBlock(torch.nn.Module):
|
|
||||||
"""The whole feed-forward block. A sequence of operation. Consists of layernorm, feed-forward and residual connection.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dim_model (int): main dimension of modules in transformer blocks.
|
|
||||||
dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
|
|
||||||
dtype (optional): Defaults to torch.half.
|
|
||||||
eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5.
|
|
||||||
dropout_p (float, optional): Defaults to 0.
|
|
||||||
""" # noqa: E501
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim_model: int,
|
|
||||||
dim_ff: int,
|
|
||||||
activate_fn: str,
|
|
||||||
dtype=torch.half,
|
|
||||||
eps: float = 1e-6,
|
|
||||||
dropout_p: Optional[float] = 0,
|
|
||||||
scale: bool = True,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.layernorm_before_ffn = LayerNorm(
|
|
||||||
dim_model,
|
|
||||||
dtype=dtype,
|
|
||||||
eps=eps,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.ffn = FeedForward(
|
|
||||||
dim_model,
|
|
||||||
dim_ff,
|
|
||||||
activate_fn=activate_fn,
|
|
||||||
dtype=dtype,
|
|
||||||
dropout_p=dropout_p,
|
|
||||||
scale=scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
if dropout_p:
|
|
||||||
self.dropout = torch.nn.Dropout(dropout_p)
|
|
||||||
else:
|
|
||||||
self.dropout = None
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Hidden states before feed forward layer.
|
|
||||||
|
|
||||||
Return:
|
|
||||||
:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of feed-forward block
|
|
||||||
|
|
||||||
""" # noqa: E501
|
|
||||||
x = self.layernorm_before_ffn(hidden_states)
|
|
||||||
x = self.ffn(x)
|
|
||||||
if self.dropout is not None:
|
|
||||||
x = self.dropout(x)
|
|
||||||
|
|
||||||
hidden_states = hidden_states + x
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerBlock(torch.nn.Module):
|
|
||||||
"""The whole transformer block. A sequence of operation. Consists of self-attention block[, cross-attention block] and feed-forward block.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dim_model (int): main dimension of modules in transformer blocks.
|
|
||||||
dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
|
|
||||||
num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
|
|
||||||
dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
|
|
||||||
dtype (optional): Defaults to torch.half.
|
|
||||||
eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5.
|
|
||||||
dropout_p (float, optional): Defaults to 0.
|
|
||||||
""" # noqa: E501
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim_model: int,
|
|
||||||
dim_ff: int,
|
|
||||||
num_heads: int,
|
|
||||||
num_kv_heads: int,
|
|
||||||
dim_head: int,
|
|
||||||
activate_fn: str = "gelu",
|
|
||||||
dtype=torch.half,
|
|
||||||
eps: float = 1e-6,
|
|
||||||
dropout_p: Optional[float] = None,
|
|
||||||
scale: bool = True,
|
|
||||||
mask_att: bool = False,
|
|
||||||
mask_ffn: bool = False,
|
|
||||||
use_flash_attn: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.mask_att = mask_att
|
|
||||||
self.mask_ffn = mask_ffn
|
|
||||||
|
|
||||||
if not self.mask_att:
|
|
||||||
self.self_att = SelfAttentionBlock(
|
|
||||||
dim_model=dim_model,
|
|
||||||
num_heads=num_heads,
|
|
||||||
num_kv_heads=num_kv_heads,
|
|
||||||
dim_head=dim_head,
|
|
||||||
dtype=dtype,
|
|
||||||
eps=eps,
|
|
||||||
dropout_p=dropout_p,
|
|
||||||
scale=scale,
|
|
||||||
use_flash_attn=use_flash_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.mask_ffn:
|
|
||||||
self.ffn = FFNBlock(
|
|
||||||
dim_model=dim_model,
|
|
||||||
dim_ff=dim_ff,
|
|
||||||
activate_fn=activate_fn,
|
|
||||||
dtype=dtype,
|
|
||||||
eps=eps,
|
|
||||||
dropout_p=dropout_p,
|
|
||||||
scale=scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
self_hidden_states: torch.Tensor,
|
|
||||||
self_attention_mask: torch.Tensor,
|
|
||||||
self_position_bias: Optional[torch.Tensor] = None,
|
|
||||||
use_cache: bool = False,
|
|
||||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
|
||||||
pos_bias_type: Optional[str] = "relative",
|
|
||||||
length_mask: Optional[torch.Tensor] = None,
|
|
||||||
context_mask: Optional[torch.Tensor] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
self_hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Input of transformer block(self-attention block). It can be the raw embedding of a batch of sequences.
|
|
||||||
self_attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_self, seq_self)``): Avoid invalid areas to participate in the calculation of self-attention.
|
|
||||||
self_position_bias (:obj:`torch.Tensor` of shape ``(num_heads, seq_self, seq_self)``): Provide positional information to self-attention block.
|
|
||||||
|
|
||||||
Return:
|
|
||||||
:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of transformer block.
|
|
||||||
|
|
||||||
""" # noqa: E501
|
|
||||||
# (batch, dim_model, seq_self)
|
|
||||||
current_key_value = None
|
|
||||||
if not self.mask_att:
|
|
||||||
hidden_states = self.self_att(
|
|
||||||
self_hidden_states,
|
|
||||||
attention_mask=self_attention_mask,
|
|
||||||
position_bias=self_position_bias,
|
|
||||||
use_cache=use_cache,
|
|
||||||
past_key_value=past_key_value,
|
|
||||||
pos_bias_type=pos_bias_type,
|
|
||||||
length_mask=length_mask,
|
|
||||||
context_mask=context_mask,
|
|
||||||
)
|
|
||||||
if use_cache:
|
|
||||||
hidden_states, current_key_value = hidden_states
|
|
||||||
else:
|
|
||||||
hidden_states = self_hidden_states
|
|
||||||
|
|
||||||
# (batch, dim_model, seq_self)
|
|
||||||
if not self.mask_ffn:
|
|
||||||
hidden_states = self.ffn(hidden_states)
|
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
return hidden_states, current_key_value
|
|
||||||
else:
|
|
||||||
return hidden_states
|
|
|
@ -1,100 +0,0 @@
|
||||||
import math
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from .position_embedding import RotaryEmbedding
|
|
||||||
|
|
||||||
|
|
||||||
class Embedding(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vocab_size: int,
|
|
||||||
embedding_size: int,
|
|
||||||
dtype: torch.dtype = torch.half,
|
|
||||||
scale: bool = True,
|
|
||||||
init_mean: float = 0.0,
|
|
||||||
init_std: float = 1,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.dim_model = embedding_size
|
|
||||||
self.weight = torch.nn.parameter.Parameter(torch.empty(vocab_size, embedding_size, dtype=dtype))
|
|
||||||
self.scale = scale
|
|
||||||
|
|
||||||
def forward(self, ids: torch.Tensor):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens.
|
|
||||||
Return:
|
|
||||||
:obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output.
|
|
||||||
""" # noqa: E501
|
|
||||||
|
|
||||||
if self.scale:
|
|
||||||
embeds = F.embedding(ids, self.weight) / math.sqrt(self.dim_model)
|
|
||||||
else:
|
|
||||||
embeds = F.embedding(ids, self.weight)
|
|
||||||
return embeds.clone()
|
|
||||||
|
|
||||||
def projection(self, x: torch.Tensor):
|
|
||||||
"""
|
|
||||||
Projection based on embedding's weight. For example, embedding map vocab_size to embed_size, than projection map embed_size back to vocab_size.
|
|
||||||
Args:
|
|
||||||
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection
|
|
||||||
Returns:
|
|
||||||
:obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output.
|
|
||||||
""" # noqa: E501
|
|
||||||
if self.scale:
|
|
||||||
logits = F.linear(x / math.sqrt(self.dim_model), self.weight)
|
|
||||||
else:
|
|
||||||
logits = F.linear(x, self.weight)
|
|
||||||
|
|
||||||
return logits
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingExt(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vocab_size: int,
|
|
||||||
embedding_size: int,
|
|
||||||
dtype: torch.dtype = torch.half,
|
|
||||||
init_mean: float = 0.0,
|
|
||||||
init_std: float = 1,
|
|
||||||
distance_scale: int = 16,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.dim_model = embedding_size
|
|
||||||
self.rotary_emb = RotaryEmbedding(dim=embedding_size, distance_scale=distance_scale, dtype=dtype)
|
|
||||||
|
|
||||||
self.weight = torch.nn.parameter.Parameter(
|
|
||||||
torch.empty(vocab_size, embedding_size, dtype=dtype),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, ids: torch.Tensor, ids_sub: torch.Tensor):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens.
|
|
||||||
ids (:obj:`torch.Tensor` of shape ``(batch_size)``): Subscript of input sequence tokens.
|
|
||||||
Return:
|
|
||||||
:obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output.
|
|
||||||
""" # noqa: E501
|
|
||||||
|
|
||||||
embeds = F.embedding(ids, self.weight) / math.sqrt(self.dim_model)
|
|
||||||
return self.rotary_emb(embeds, ids_sub)
|
|
||||||
|
|
||||||
def projection(self, x: torch.Tensor, ext_table: Optional[torch.Tensor] = None):
|
|
||||||
"""
|
|
||||||
Projection based on embedding's weight. For example, embedding map vocab_size to embed_size, than projection map embed_size back to vocab_size.
|
|
||||||
Args:
|
|
||||||
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection
|
|
||||||
ext_table (:obj:`torch.Tensor` of shape ``(ext_table_size, dim_model)``): Ext vocab table.
|
|
||||||
Returns:
|
|
||||||
:obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_size + ext_table_size)``: The projection output.
|
|
||||||
""" # noqa: E501
|
|
||||||
logits = F.linear(x / math.sqrt(self.dim_model), self.weight)
|
|
||||||
if ext_table is not None:
|
|
||||||
logits_ext = F.linear(x, ext_table)
|
|
||||||
logits = torch.cat([logits, logits_ext], dim=-1)
|
|
||||||
return logits
|
|
|
@ -1,120 +0,0 @@
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from .linear import Linear
|
|
||||||
|
|
||||||
|
|
||||||
class DenseGatedACT(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim_in: int,
|
|
||||||
dim_ff: int,
|
|
||||||
dtype=torch.half,
|
|
||||||
activate_fn: str = "gelu",
|
|
||||||
scale: bool = True,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.w_0 = Linear(
|
|
||||||
dim_in=dim_in,
|
|
||||||
dim_out=dim_ff,
|
|
||||||
dtype=dtype,
|
|
||||||
scale=scale,
|
|
||||||
scale_before=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.w_1 = Linear(
|
|
||||||
dim_in=dim_in,
|
|
||||||
dim_out=dim_ff,
|
|
||||||
dtype=dtype,
|
|
||||||
scale=scale,
|
|
||||||
scale_before=False,
|
|
||||||
)
|
|
||||||
if activate_fn == "gelu":
|
|
||||||
self.act = torch.nn.GELU()
|
|
||||||
elif activate_fn == "silu":
|
|
||||||
self.act = torch.nn.functional.silu
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"{activate_fn} is not supported")
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
"""Transform an input tensor from one feature space to another via a nonlinear operation
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): Tensor that will be subject to nonlinear operations.
|
|
||||||
|
|
||||||
Return:
|
|
||||||
out (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_ff)``)
|
|
||||||
|
|
||||||
""" # noqa: E501
|
|
||||||
gate_score = self.act(self.w_0(x))
|
|
||||||
x = self.w_1(x)
|
|
||||||
|
|
||||||
x = gate_score * x
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class FeedForward(torch.nn.Module):
|
|
||||||
r"""FeedForward module
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dim_in (int): input dimension.
|
|
||||||
dim_ff (int): middle dimension.
|
|
||||||
dim_out (int, optional): output dimension. Defaults to None, which means dim_in = dim_out.
|
|
||||||
dtype (optional): Defaults to torch.half.
|
|
||||||
init_mean (float, optional): mean of :math:`\mathbf{W}\sim\mathcal{N}(\text{mean}, \text{std}^2)` for fully-connected module used in feed-forward layer. Defaults to 0.
|
|
||||||
init_std (float, optional): std of :math:`\mathbf{W}\sim\mathcal{N}(\text{mean}, \text{std}^2)` for fully-connected module used in feed-forward layer. Defaults to 0.02.
|
|
||||||
bias (bool, optional): whether to use bias term in fully-connected layers used in feed-forward module. Defaults to False.
|
|
||||||
activate_fn (str, optional): Defaults to `gated_gelu`.
|
|
||||||
dropout_p (int, optional): Defaults to 0.
|
|
||||||
""" # noqa: E501
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim_model: int,
|
|
||||||
dim_ff: int,
|
|
||||||
activate_fn: str = "gelu",
|
|
||||||
dtype=torch.half,
|
|
||||||
dropout_p: Optional[float] = None,
|
|
||||||
scale: bool = True,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.w_in = DenseGatedACT(
|
|
||||||
dim_in=dim_model,
|
|
||||||
dim_ff=dim_ff,
|
|
||||||
activate_fn=activate_fn,
|
|
||||||
dtype=dtype,
|
|
||||||
scale=scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
if dropout_p is not None:
|
|
||||||
self.dropout = torch.nn.Dropout(dropout_p)
|
|
||||||
else:
|
|
||||||
self.dropout = None
|
|
||||||
|
|
||||||
self.w_out = Linear(
|
|
||||||
dim_in=dim_ff,
|
|
||||||
dim_out=dim_model,
|
|
||||||
dtype=dtype,
|
|
||||||
scale=scale,
|
|
||||||
scale_before=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of feed-forward module.
|
|
||||||
|
|
||||||
Return:
|
|
||||||
:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of feed-forward module.
|
|
||||||
""" # noqa: E501
|
|
||||||
x = self.w_in(x)
|
|
||||||
|
|
||||||
if self.dropout is not None:
|
|
||||||
x = self.dropout(x)
|
|
||||||
|
|
||||||
x = self.w_out(x)
|
|
||||||
|
|
||||||
return x
|
|
|
@ -1,37 +0,0 @@
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
|
|
||||||
old_dtype = hidden.dtype
|
|
||||||
variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
|
|
||||||
hidden = (hidden * torch.rsqrt(variance + eps)).to(old_dtype)
|
|
||||||
return hidden * weight
|
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(torch.nn.Module):
|
|
||||||
"""RMS LayerNorm"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim_norm: int,
|
|
||||||
dtype: torch.dtype = torch.half,
|
|
||||||
eps: float = 1e-6,
|
|
||||||
init_var: float = 1.0,
|
|
||||||
):
|
|
||||||
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.eps = eps
|
|
||||||
self.dim_norm = dim_norm
|
|
||||||
self.weight = torch.nn.parameter.Parameter(torch.full((dim_norm,), init_var, dtype=dtype))
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x (:obj:`torch.Tensor` of shape ``(batch_size, seq_len, dim_norm)``): Input tensor that need to be normalized.
|
|
||||||
Return:
|
|
||||||
:obj:`torch.Tensor` of shape ``(batch_size, seq_len, dim_norm)``: The layernorm output.
|
|
||||||
"""
|
|
||||||
assert x.size(-1) == self.dim_norm
|
|
||||||
return rms_layernorm(x, self.weight, self.eps)
|
|
|
@ -1,44 +0,0 @@
|
||||||
import math
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
class Linear(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim_in: int,
|
|
||||||
dim_out: int,
|
|
||||||
dtype: torch.dtype = torch.half,
|
|
||||||
init_mean: float = 0.0,
|
|
||||||
init_std: float = 1,
|
|
||||||
scale: bool = True,
|
|
||||||
scale_before: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.dim_in = self.in_features = dim_in
|
|
||||||
self.dim_out = self.out_features = dim_out
|
|
||||||
self.scale = scale
|
|
||||||
self.scale_before = scale_before
|
|
||||||
self.weight = torch.nn.parameter.Parameter(torch.empty((dim_out, dim_in), dtype=dtype))
|
|
||||||
torch.nn.init.normal_(self.weight, mean=init_mean, std=init_std)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer
|
|
||||||
Returns:
|
|
||||||
:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y.
|
|
||||||
""" # noqa: E501
|
|
||||||
if self.scale:
|
|
||||||
if self.scale_before:
|
|
||||||
x = x / math.sqrt(self.dim_in)
|
|
||||||
|
|
||||||
x = F.linear(x, self.weight)
|
|
||||||
else:
|
|
||||||
x = F.linear(x, self.weight)
|
|
||||||
x = x / math.sqrt(self.dim_in)
|
|
||||||
|
|
||||||
else:
|
|
||||||
x = F.linear(x, self.weight)
|
|
||||||
return x
|
|
|
@ -1,286 +0,0 @@
|
||||||
import math
|
|
||||||
from typing import Tuple
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
class SegmentPositionEmbedding(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_heads: int,
|
|
||||||
num_segments: int = 1,
|
|
||||||
num_buckets: int = 32,
|
|
||||||
max_distance: int = 128,
|
|
||||||
bidirectional: bool = False,
|
|
||||||
dtype: torch.dtype = torch.half,
|
|
||||||
init_mean: float = 0.0,
|
|
||||||
init_std: float = 1,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.num_buckets = num_buckets
|
|
||||||
self.max_distance = max_distance
|
|
||||||
self.bidirectional = bidirectional
|
|
||||||
self.num_segments = num_segments
|
|
||||||
|
|
||||||
self.relative_attention_bias = torch.nn.parameter.Parameter(
|
|
||||||
torch.empty(num_segments * num_segments + num_buckets, num_heads, dtype=dtype)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
key_pos: torch.Tensor,
|
|
||||||
query_pos: torch.Tensor,
|
|
||||||
key_segment: torch.Tensor,
|
|
||||||
query_segment: torch.Tensor,
|
|
||||||
):
|
|
||||||
with torch.no_grad():
|
|
||||||
batch = key_pos.size(0)
|
|
||||||
keylen = key_pos.size(1)
|
|
||||||
querylen = query_pos.size(1)
|
|
||||||
|
|
||||||
assert key_pos.size(0) == query_pos.size(0)
|
|
||||||
assert keylen == key_segment.size(1) and querylen == query_segment.size(1)
|
|
||||||
|
|
||||||
key_pos = key_pos.view(batch, -1, keylen)
|
|
||||||
query_pos = query_pos.view(batch, querylen, -1)
|
|
||||||
key_segment = key_segment.view(batch, -1, keylen)
|
|
||||||
query_segment = query_segment.view(batch, querylen, -1)
|
|
||||||
|
|
||||||
relative_position_bucket = self._segment_relative_position_bucket(query_segment, key_segment)
|
|
||||||
relative_position_bucket = relative_position_bucket + self.num_buckets # 与相对位置编码区间不重叠
|
|
||||||
|
|
||||||
# b*q*k
|
|
||||||
absolute_position_bucket = self._position_bucket(
|
|
||||||
torch.arange(keylen, dtype=torch.int32, device=relative_position_bucket.device)[None, :]
|
|
||||||
- torch.arange(querylen, dtype=torch.int32, device=relative_position_bucket.device)[:, None],
|
|
||||||
bidirectional=self.bidirectional,
|
|
||||||
num_buckets=self.num_buckets,
|
|
||||||
max_distance=self.max_distance,
|
|
||||||
)
|
|
||||||
relative_position_bucket = torch.where(
|
|
||||||
(key_segment == query_segment),
|
|
||||||
absolute_position_bucket[None, :, :],
|
|
||||||
relative_position_bucket,
|
|
||||||
)
|
|
||||||
# (batch, len_q, len_k)
|
|
||||||
|
|
||||||
# (batch, len_q, len_k, num_heads)
|
|
||||||
embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
|
|
||||||
# (batch, num_heads, len_q, len_k)
|
|
||||||
embeds = embeds.permute(0, 3, 1, 2).contiguous()
|
|
||||||
return embeds
|
|
||||||
|
|
||||||
def _segment_relative_position_bucket(self, query_segment, key_segment):
|
|
||||||
return query_segment * self.num_segments + key_segment
|
|
||||||
|
|
||||||
def _position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128):
|
|
||||||
relative_buckets = 0
|
|
||||||
if bidirectional:
|
|
||||||
num_buckets //= 2
|
|
||||||
relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
|
|
||||||
relative_position = torch.abs(relative_position)
|
|
||||||
else:
|
|
||||||
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
|
|
||||||
max_exact = num_buckets // 2
|
|
||||||
is_small = relative_position < max_exact
|
|
||||||
relative_postion_if_large = max_exact + (
|
|
||||||
torch.log(relative_position.float() / max_exact)
|
|
||||||
/ math.log(max_distance / max_exact)
|
|
||||||
* (num_buckets - max_exact)
|
|
||||||
).to(torch.int32)
|
|
||||||
relative_postion_if_large = torch.min(
|
|
||||||
relative_postion_if_large,
|
|
||||||
torch.full_like(relative_postion_if_large, num_buckets - 1),
|
|
||||||
)
|
|
||||||
relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_postion_if_large)
|
|
||||||
return relative_buckets
|
|
||||||
|
|
||||||
|
|
||||||
class BucketPositionBias(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_heads: int,
|
|
||||||
num_buckets: int = 32,
|
|
||||||
num_segment_bucket: int = 32,
|
|
||||||
max_distance: int = 128,
|
|
||||||
dtype: torch.dtype = torch.half,
|
|
||||||
init_mean: float = 0.0,
|
|
||||||
init_std: float = 1,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.num_buckets = num_buckets
|
|
||||||
self.num_segment_bucket = num_segment_bucket
|
|
||||||
self.max_distance = max_distance
|
|
||||||
|
|
||||||
self.relative_attention_bias = torch.nn.parameter.Parameter(
|
|
||||||
torch.empty(num_buckets + num_segment_bucket, num_heads, dtype=dtype)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
query_pos: torch.Tensor, # (batch, len_q)
|
|
||||||
key_pos: torch.Tensor, # (batch, len_k)
|
|
||||||
rel_buckets: torch.Tensor, # (batch, len_q, len_k)
|
|
||||||
):
|
|
||||||
with torch.no_grad():
|
|
||||||
batch = key_pos.size(0)
|
|
||||||
keylen = key_pos.size(1)
|
|
||||||
querylen = query_pos.size(1)
|
|
||||||
|
|
||||||
assert key_pos.size(0) == query_pos.size(0)
|
|
||||||
assert rel_buckets.size(0) == batch and rel_buckets.size(1) == querylen and rel_buckets.size(2) == keylen
|
|
||||||
|
|
||||||
relative_position_bucket = rel_buckets - 1 + self.num_buckets # 与相对位置编码区间不重叠
|
|
||||||
|
|
||||||
# b*q*k
|
|
||||||
inner_segment_bucket = self._position_bucket(
|
|
||||||
key_pos[..., None, :] - query_pos[..., :, None],
|
|
||||||
num_buckets=self.num_buckets,
|
|
||||||
max_distance=self.max_distance,
|
|
||||||
)
|
|
||||||
relative_position_bucket = torch.where(
|
|
||||||
rel_buckets == 0,
|
|
||||||
inner_segment_bucket,
|
|
||||||
relative_position_bucket,
|
|
||||||
)
|
|
||||||
# (batch, len_q, len_k)
|
|
||||||
|
|
||||||
# (batch, len_q, len_k, num_heads)
|
|
||||||
embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
|
|
||||||
# (batch, num_heads, len_q, len_k)
|
|
||||||
embeds = embeds.permute(0, 3, 1, 2).contiguous()
|
|
||||||
return embeds
|
|
||||||
|
|
||||||
def _position_bucket(self, relative_position, num_buckets=32, max_distance=128):
|
|
||||||
relative_buckets = 0
|
|
||||||
num_buckets //= 2
|
|
||||||
relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
|
|
||||||
relative_position = torch.abs(relative_position)
|
|
||||||
|
|
||||||
max_exact = num_buckets // 2
|
|
||||||
is_small = relative_position < max_exact
|
|
||||||
relative_postion_if_large = max_exact + (
|
|
||||||
torch.log(relative_position.float() / max_exact)
|
|
||||||
/ math.log(max_distance / max_exact)
|
|
||||||
* (num_buckets - max_exact)
|
|
||||||
).to(torch.int32)
|
|
||||||
relative_postion_if_large = torch.min(
|
|
||||||
relative_postion_if_large,
|
|
||||||
torch.full_like(relative_postion_if_large, num_buckets - 1),
|
|
||||||
)
|
|
||||||
relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_postion_if_large)
|
|
||||||
return relative_buckets
|
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbedding(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim,
|
|
||||||
base=10000,
|
|
||||||
distance_scale: Union[int, float] = 1,
|
|
||||||
dtype: torch.dtype = torch.half,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device="cuda", dtype=torch.float32) / dim))
|
|
||||||
inv_freq = inv_freq.to(dtype)
|
|
||||||
self.distance_scale = distance_scale
|
|
||||||
self.dtype = dtype
|
|
||||||
self.inv_freq = inv_freq
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, x_pos: torch.Tensor):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x (:obj:`torch.Tensor` of shape ``(..., dim)``): Inputs.
|
|
||||||
x_pos (:obj:`torch.Tensor` of shape ``(...)``): Positions of inputs.
|
|
||||||
"""
|
|
||||||
x_pos = x_pos * self.distance_scale
|
|
||||||
freqs = x_pos[..., None].to(self.dtype) * self.inv_freq[None, :] # (..., dim/2)
|
|
||||||
|
|
||||||
# the same implementation as sat
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1) # (..., dim)
|
|
||||||
emb_cos = emb.cos() # (..., dim)
|
|
||||||
emb_sin = emb.sin() # (..., dim)
|
|
||||||
|
|
||||||
rotate_x = torch.cat([-x[..., x.size(-1) // 2 :], x[..., : x.size(-1) // 2]], dim=-1) # (..., dim)
|
|
||||||
|
|
||||||
return x * emb_cos + rotate_x * emb_sin
|
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
|
||||||
x1, x2 = x.chunk(2, dim=-1)
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(x, cos, sin, seq_dim, offset):
|
|
||||||
if x.size(seq_dim) < cos.size(seq_dim):
|
|
||||||
cos = cos.narrow(seq_dim, offset, x.size(seq_dim))
|
|
||||||
sin = sin.narrow(seq_dim, offset, x.size(seq_dim))
|
|
||||||
return (x * cos) + (rotate_half(x) * sin)
|
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbeddingESM(torch.nn.Module):
|
|
||||||
"""
|
|
||||||
Rotary position embeddings based on those in
|
|
||||||
[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
|
|
||||||
matrices which depend on their relative positions.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim: int,
|
|
||||||
base: Union[int, float] = 10000,
|
|
||||||
distance_scale: Union[int, float] = 1,
|
|
||||||
dtype=torch.half,
|
|
||||||
persistent=True,
|
|
||||||
mixed_precision=False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.base = base
|
|
||||||
self.distance_scale = distance_scale
|
|
||||||
self.dtype = dtype
|
|
||||||
|
|
||||||
# Generate and save the inverse frequency buffer (non trainable)
|
|
||||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device="cuda", dtype=torch.float32) / dim))
|
|
||||||
if mixed_precision:
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=persistent)
|
|
||||||
else:
|
|
||||||
self.register_buffer("inv_freq", inv_freq.to(self.dtype), persistent=persistent)
|
|
||||||
|
|
||||||
self._seq_len_cached = -1
|
|
||||||
self._cos_cached = None
|
|
||||||
self._sin_cached = None
|
|
||||||
self.mixed_precision = mixed_precision
|
|
||||||
|
|
||||||
self.apply_rotary_pos_emb = apply_rotary_pos_emb
|
|
||||||
|
|
||||||
def _update_cos_sin_tables(self, x, seq_dim, offset):
|
|
||||||
seq_len = x.size(seq_dim) + offset
|
|
||||||
if seq_len > self._seq_len_cached or self._cos_cached.device != x.device:
|
|
||||||
self._seq_len_cached = seq_len
|
|
||||||
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
|
|
||||||
freqs = torch.outer(t * self.distance_scale, self.inv_freq)
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
|
||||||
for i in range(x.dim() - 1):
|
|
||||||
if i != seq_dim:
|
|
||||||
emb = emb.unsqueeze_(i)
|
|
||||||
if self.mixed_precision:
|
|
||||||
self._cos_cached = emb.cos().to(self.dtype)
|
|
||||||
self._sin_cached = emb.sin().to(self.dtype)
|
|
||||||
else:
|
|
||||||
self._cos_cached = emb.cos()
|
|
||||||
self._sin_cached = emb.sin()
|
|
||||||
return self._cos_cached, self._sin_cached
|
|
||||||
|
|
||||||
def forward(self, q: torch.Tensor, k: torch.Tensor, seq_dim, offset=0) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
seq_dim = (seq_dim + k.dim()) % k.dim()
|
|
||||||
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dim, offset)
|
|
||||||
return (
|
|
||||||
self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, seq_dim, offset),
|
|
||||||
self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, seq_dim, offset),
|
|
||||||
)
|
|
|
@ -1,132 +0,0 @@
|
||||||
from typing import List
|
|
||||||
from typing import Optional
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from .blocks import TransformerBlock
|
|
||||||
from .layernorm import LayerNorm
|
|
||||||
|
|
||||||
|
|
||||||
class Encoder(torch.nn.Module):
|
|
||||||
"""Layers of encoder transformer blocks plus an final layernorm.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
num_layers (int): number of layers.
|
|
||||||
dim_model (int): main dimension of modules in transformer blocks.
|
|
||||||
dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
|
|
||||||
num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
|
|
||||||
dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
|
|
||||||
dtype (optional): Defaults to torch.half.
|
|
||||||
eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
|
|
||||||
dropout_p (float, optional): Defaults to 0.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_layers: int,
|
|
||||||
dim_model: int,
|
|
||||||
dim_ff: int,
|
|
||||||
num_heads: int,
|
|
||||||
dim_head: int,
|
|
||||||
num_kv_heads: int = -1,
|
|
||||||
activate_fn: str = "gelu",
|
|
||||||
dtype: torch.dtype = torch.half,
|
|
||||||
eps: float = 1e-6,
|
|
||||||
dropout_p: Optional[float] = None,
|
|
||||||
scale: bool = True,
|
|
||||||
mask_modules: Optional[List[Tuple[bool, bool]]] = None,
|
|
||||||
use_flash_attn: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
if num_kv_heads == -1:
|
|
||||||
num_kv_heads = num_heads
|
|
||||||
self.num_layers = num_layers
|
|
||||||
|
|
||||||
if mask_modules is not None:
|
|
||||||
assert len(mask_modules) == num_layers, "The total number of masks should equal to num_layers"
|
|
||||||
for mask_module in mask_modules:
|
|
||||||
assert len(mask_module) == 2, "For encoder, each mask should be (mask_att, mask_ffn)"
|
|
||||||
else:
|
|
||||||
mask_modules = [(False, False)] * num_layers
|
|
||||||
|
|
||||||
self.layers = torch.nn.ModuleList(
|
|
||||||
[
|
|
||||||
TransformerBlock(
|
|
||||||
dim_model=dim_model,
|
|
||||||
dim_ff=dim_ff,
|
|
||||||
num_heads=num_heads,
|
|
||||||
num_kv_heads=num_kv_heads,
|
|
||||||
dim_head=dim_head,
|
|
||||||
activate_fn=activate_fn,
|
|
||||||
dtype=dtype,
|
|
||||||
eps=eps,
|
|
||||||
dropout_p=dropout_p,
|
|
||||||
scale=scale,
|
|
||||||
mask_att=mask_modules[ith][0],
|
|
||||||
mask_ffn=mask_modules[ith][1],
|
|
||||||
use_flash_attn=use_flash_attn,
|
|
||||||
)
|
|
||||||
for ith in range(num_layers)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.output_layernorm = LayerNorm(dim_norm=dim_model, dtype=dtype, eps=eps)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: torch.Tensor,
|
|
||||||
position_bias: torch.Tensor,
|
|
||||||
use_cache: bool = False,
|
|
||||||
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
|
||||||
pos_bias_type: Optional[str] = "relative",
|
|
||||||
length_mask: Optional[torch.Tensor] = None,
|
|
||||||
context_mask: Optional[torch.Tensor] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
hidden-states (:obj:`torch.Tensor` of shape ``(batch, seq_enc, dim_model)``): Input of encoder, might be the embedding of a batch of sequences.
|
|
||||||
attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_enc, seq_enc)``): Avoid invalid areas to participate in the calculation
|
|
||||||
position_bias(:obj:`torch.Tensor` of shape ``(num_heads, seq_enc, seq_enc)``) Provides position information to attention mechanism.
|
|
||||||
|
|
||||||
Return:
|
|
||||||
:obj:`torch.Tensor` of shape ``(batch, seq_enc, dim_model)``: The encoder output.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if not use_cache:
|
|
||||||
for layer in self.layers:
|
|
||||||
hidden_states = layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask,
|
|
||||||
position_bias,
|
|
||||||
pos_bias_type=pos_bias_type,
|
|
||||||
length_mask=length_mask,
|
|
||||||
context_mask=context_mask,
|
|
||||||
)
|
|
||||||
hidden_states = self.output_layernorm(hidden_states)
|
|
||||||
return hidden_states
|
|
||||||
else:
|
|
||||||
with torch.no_grad():
|
|
||||||
current_key_values = []
|
|
||||||
current_hidden_states = []
|
|
||||||
for i, module in enumerate(self.layers):
|
|
||||||
hidden_states = module(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask,
|
|
||||||
position_bias,
|
|
||||||
past_key_value=past_key_values[i] if past_key_values else None,
|
|
||||||
use_cache=use_cache,
|
|
||||||
pos_bias_type=pos_bias_type,
|
|
||||||
length_mask=length_mask,
|
|
||||||
context_mask=context_mask,
|
|
||||||
)
|
|
||||||
if use_cache:
|
|
||||||
current_key_values.append(hidden_states[1])
|
|
||||||
current_hidden_states.append(hidden_states[0])
|
|
||||||
hidden_states = hidden_states[0]
|
|
||||||
hidden_states = self.output_layernorm(hidden_states)
|
|
||||||
if use_cache:
|
|
||||||
return hidden_states, current_key_values, current_hidden_states
|
|
||||||
else:
|
|
||||||
return hidden_states
|
|
|
@ -1,62 +0,0 @@
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def pad(orig_items, key, padding_value=0, padding_side="left"):
|
|
||||||
items = []
|
|
||||||
if isinstance(orig_items[0][key], list):
|
|
||||||
assert isinstance(orig_items[0][key][0], torch.Tensor)
|
|
||||||
for it in orig_items:
|
|
||||||
for tr in it[key]:
|
|
||||||
items.append({key: tr})
|
|
||||||
else:
|
|
||||||
assert isinstance(orig_items[0][key], torch.Tensor)
|
|
||||||
items = orig_items
|
|
||||||
|
|
||||||
batch_size = len(items)
|
|
||||||
shape = items[0][key].shape
|
|
||||||
dim = len(shape)
|
|
||||||
assert dim <= 3
|
|
||||||
max_length = max(item[key].shape[-1] for item in items)
|
|
||||||
min_length = min(item[key].shape[-1] for item in items)
|
|
||||||
dtype = items[0][key].dtype
|
|
||||||
|
|
||||||
if dim == 1:
|
|
||||||
return torch.cat([item[key] for item in items], dim=0)
|
|
||||||
elif dim == 2:
|
|
||||||
if max_length == min_length:
|
|
||||||
return torch.cat([item[key] for item in items], dim=0)
|
|
||||||
tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
|
|
||||||
else:
|
|
||||||
tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
|
|
||||||
|
|
||||||
for i, item in enumerate(items):
|
|
||||||
if dim == 2:
|
|
||||||
if padding_side == "left":
|
|
||||||
tensor[i, -len(item[key][0]) :] = item[key][0].clone()
|
|
||||||
else:
|
|
||||||
tensor[i, : len(item[key][0])] = item[key][0].clone()
|
|
||||||
elif dim == 3:
|
|
||||||
if padding_side == "left":
|
|
||||||
tensor[i, -len(item[key][0]) :, :] = item[key][0].clone()
|
|
||||||
else:
|
|
||||||
tensor[i, : len(item[key][0]), :] = item[key][0].clone()
|
|
||||||
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
|
|
||||||
def pad_raw(orig_items, max_length=1024, padding_value=0, padding_side="left"):
|
|
||||||
max_cols = max(tensor.size(1) for tensor in orig_items)
|
|
||||||
|
|
||||||
padded_arrays = []
|
|
||||||
for tensor in orig_items:
|
|
||||||
pad_cols = max_cols - tensor.size(1)
|
|
||||||
if padding_side == "left":
|
|
||||||
padded_tensor = torch.cat([torch.zeros(tensor.size(0), pad_cols), tensor], dim=1)
|
|
||||||
elif padding_side == "right":
|
|
||||||
padded_tensor = torch.cat([tensor, torch.zeros(tensor.size(0), pad_cols)], dim=1)
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid 'side' parameter. Must be 'left' or 'right'.")
|
|
||||||
padded_arrays.append(padded_tensor)
|
|
||||||
|
|
||||||
padded_tensor = torch.cat(padded_arrays, dim=0).to(dtype=torch.int32)
|
|
||||||
return padded_tensor
|
|
|
@ -1,130 +0,0 @@
|
||||||
import functools
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
import time
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import bmtrain as bmt
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from .log import logger
|
|
||||||
|
|
||||||
|
|
||||||
def rename_if_exists(file_path):
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
return
|
|
||||||
timestamp = time.strftime("%Y%m%d%H%M%S")
|
|
||||||
file_dir, file_name = os.path.split(file_path)
|
|
||||||
file_root, file_ext = os.path.splitext(file_name)
|
|
||||||
new_file_name = f"{file_root}_bak_{timestamp}{file_ext}"
|
|
||||||
new_file_path = os.path.join(file_dir, new_file_name)
|
|
||||||
try:
|
|
||||||
os.rename(file_path, new_file_path)
|
|
||||||
logger.info(f"File '{file_name}' already exists. Renamed to '{new_file_name}'")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warn(
|
|
||||||
"rename file failed,file_path={file_path}, new_file_path={new_file_path},err={err}".format(
|
|
||||||
file_path=file_path, new_file_path=new_file_path, err=str(e)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def rename_if_exists_decorator(func):
|
|
||||||
@functools.wraps(func)
|
|
||||||
def wrapper(file_path, *args, **kwargs):
|
|
||||||
rename_if_exists(file_path)
|
|
||||||
return func(file_path, *args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
@rename_if_exists_decorator
|
|
||||||
def bmt_save(file_path: str, model: torch.nn.Module, export_files: List[str] = None):
|
|
||||||
bmt.save(model, file_path)
|
|
||||||
if export_files is not None:
|
|
||||||
export_files.append(file_path)
|
|
||||||
|
|
||||||
|
|
||||||
@rename_if_exists_decorator
|
|
||||||
def torch_save(file_path: str, obj: object, export_files: List[str] = None):
|
|
||||||
torch.save(obj, file_path)
|
|
||||||
if export_files is not None:
|
|
||||||
export_files.append(file_path)
|
|
||||||
|
|
||||||
|
|
||||||
@rename_if_exists_decorator
|
|
||||||
def json_save(file_path: str, obj: object, export_files: List[str] = None):
|
|
||||||
with open(file_path, "w") as data_f:
|
|
||||||
json.dump(obj, data_f)
|
|
||||||
if export_files is not None:
|
|
||||||
export_files.append(file_path)
|
|
||||||
|
|
||||||
|
|
||||||
def export(
|
|
||||||
model: torch.nn.Module, dataloader, optimizer: bmt.optim.AdamOffloadOptimizer, global_step, args, final_save=False
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
一次 ckpt 保存:
|
|
||||||
/{args.save}/
|
|
||||||
├── {save_name}-{global_step}.rank-0.opt
|
|
||||||
├── {save_name}-{global_step}.rank-n.opt
|
|
||||||
├── job_{job_id}_ckpt_{global_step}/ # checkpoint 导出为模型版本时,job_{job_id}_ckpt_{global_step}/ 路径下文件会一起导出,创建一个模型组版本
|
|
||||||
├── config.json
|
|
||||||
├── vocabs.txt
|
|
||||||
├── {args.save_name}-{global_step}.pt
|
|
||||||
├── {args.save_name}-{global_step}.data
|
|
||||||
├── {args.save_name}-{global_step}.data.json
|
|
||||||
└── {args.save_name}-{global_step}.success
|
|
||||||
|
|
||||||
"""
|
|
||||||
export_model_dir = os.path.join(args.save, f"l_{global_step}")
|
|
||||||
os.makedirs(export_model_dir, exist_ok=True)
|
|
||||||
base_file_name = f"{args.save_name}-{global_step}" if global_step > -1 else args.save_name
|
|
||||||
logger.info(f"start to export ckpt, save_dir={export_model_dir}, file prefix={base_file_name}")
|
|
||||||
export_files = []
|
|
||||||
|
|
||||||
# model checkpoint
|
|
||||||
bmt_save(
|
|
||||||
file_path=os.path.join(export_model_dir, base_file_name + ".pt"),
|
|
||||||
model=model,
|
|
||||||
export_files=export_files,
|
|
||||||
)
|
|
||||||
|
|
||||||
# opt is only used for continual pre-training, not the final opt
|
|
||||||
if not final_save:
|
|
||||||
grad_path = os.path.join(
|
|
||||||
args.save,
|
|
||||||
args.save_name + ("-%d.rank-%d.opt" % (global_step % (args.save_iters * 5), bmt.rank())),
|
|
||||||
)
|
|
||||||
torch.save(optimizer.state_dict(), grad_path)
|
|
||||||
logger.info(f"Successfully save grad file: {grad_path}")
|
|
||||||
|
|
||||||
all_states = dataloader.state_dict()
|
|
||||||
if bmt.rank() == 0:
|
|
||||||
# data checkpoint
|
|
||||||
# rank 0 writes the dataloader state
|
|
||||||
torch_save(
|
|
||||||
file_path=os.path.join(export_model_dir, base_file_name + ".data"),
|
|
||||||
obj=all_states,
|
|
||||||
export_files=export_files,
|
|
||||||
)
|
|
||||||
# data checkpoint json
|
|
||||||
# rank 0 writes the dataloader state into the json file
|
|
||||||
data_p_json = {k: v for k, v in all_states.items()}
|
|
||||||
for k in data_p_json:
|
|
||||||
data_p_json[k] = {k_of_v: data_p_json[k][k_of_v].tolist() for k_of_v in data_p_json[k]}
|
|
||||||
json_save(
|
|
||||||
file_path=os.path.join(export_model_dir, base_file_name + ".data.json"),
|
|
||||||
obj=data_p_json,
|
|
||||||
export_files=export_files,
|
|
||||||
)
|
|
||||||
# config 和 vocabs 和模型文件一起存储
|
|
||||||
model_cfg_path = os.path.join(export_model_dir, "config.json")
|
|
||||||
model_vocab_path = os.path.join(export_model_dir, "vocabs.txt")
|
|
||||||
export_files.extend([model_cfg_path, model_vocab_path])
|
|
||||||
shutil.copy(args.model_config, model_cfg_path)
|
|
||||||
shutil.copy(args.vocab, model_vocab_path)
|
|
||||||
logger.info(f"Successfully save model files! {export_files}")
|
|
||||||
del all_states
|
|
||||||
return export_model_dir
|
|
|
@ -0,0 +1,4 @@
|
||||||
|
# !/usr/bin/python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Copyright @2024, QiYuan Inc
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,20 @@
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
def rand(n: int, r: random.Random):
|
||||||
|
return int(r.random() * n)
|
||||||
|
|
||||||
|
def transform(data, num_sample: int, r: random.Random):
|
||||||
|
if 'input' in data:
|
||||||
|
_input = "<用户>"+data['input']+"<AI>"
|
||||||
|
else:
|
||||||
|
_input = ""
|
||||||
|
|
||||||
|
if 'output' in data:
|
||||||
|
_output = data['output']
|
||||||
|
else:
|
||||||
|
_output = ""
|
||||||
|
return {"input": _input,
|
||||||
|
"output": _output,
|
||||||
|
}
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,20 @@
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
def rand(n: int, r: random.Random):
|
||||||
|
return int(r.random() * n)
|
||||||
|
|
||||||
|
def transform(data, num_sample: int, r: random.Random):
|
||||||
|
if 'input' in data:
|
||||||
|
_input = data['input']
|
||||||
|
else:
|
||||||
|
_input = ""
|
||||||
|
|
||||||
|
if 'output' in data:
|
||||||
|
_output = data['output']
|
||||||
|
else:
|
||||||
|
_output = ""
|
||||||
|
return {"input": _input,
|
||||||
|
"output": _output,
|
||||||
|
}
|
|
@ -0,0 +1,134 @@
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"dataset_name": "humanevallike_clean_dedup",
|
||||||
|
"task_name": "humanevallike_clean_dedup",
|
||||||
|
"abs_weight": 0.2,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/humanevallike_clean_dedup",
|
||||||
|
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 995339,
|
||||||
|
"ave_tokens_per_line": 100,
|
||||||
|
"total_tokens": 0.1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dataset_name": "leetcode_pass_code_0125",
|
||||||
|
"task_name": "leetcode_pass_code_0125",
|
||||||
|
"abs_weight": 0.006,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/leetcode_pass_code_0125",
|
||||||
|
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 10724,
|
||||||
|
"ave_tokens_per_line": 200,
|
||||||
|
"total_tokens": 0.002
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dataset_name": "logiv2Annotate",
|
||||||
|
"task_name": "logiv2Annotate",
|
||||||
|
"abs_weight": 0.004,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logiv2Annotate",
|
||||||
|
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 12566,
|
||||||
|
"ave_tokens_per_line": 512,
|
||||||
|
"total_tokens": 0.006
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dataset_name": "mmlu_enhance",
|
||||||
|
"task_name": "mmlu_enhance",
|
||||||
|
"abs_weight": 0.1,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mmlu_enhance",
|
||||||
|
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 169771,
|
||||||
|
"ave_tokens_per_line": 300,
|
||||||
|
"total_tokens": 0.05
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dataset_name": "mtbench_like",
|
||||||
|
"task_name": "mtbench_like",
|
||||||
|
"abs_weight": 0.2,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mtbench_like",
|
||||||
|
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 319080,
|
||||||
|
"ave_tokens_per_line": 500,
|
||||||
|
"total_tokens": 0.15
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dataset_name": "ultra_dataset_new",
|
||||||
|
"task_name": "ultra_dataset_new",
|
||||||
|
"abs_weight": 2.0,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/ultra_dataset_new",
|
||||||
|
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 385045,
|
||||||
|
"ave_tokens_per_line": 200.296266559615,
|
||||||
|
"total_tokens": 2.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dataset_name": "sft_data_zh_wowru",
|
||||||
|
"task_name": "sft_data_zh_wowru",
|
||||||
|
"abs_weight": 1.0,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/sft_data_zh_wowru",
|
||||||
|
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 2963260,
|
||||||
|
"ave_tokens_per_line": 200.296266559615,
|
||||||
|
"total_tokens": 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dataset_name": "math_data",
|
||||||
|
"task_name": "math_data",
|
||||||
|
"abs_weight": 0.003,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/math_data",
|
||||||
|
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 2963260,
|
||||||
|
"ave_tokens_per_line": 200.296266559615,
|
||||||
|
"total_tokens": 0.005
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dataset_name": "t0",
|
||||||
|
"task_name": "t0",
|
||||||
|
"abs_weight": 0.1,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/t0",
|
||||||
|
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 1650309,
|
||||||
|
"ave_tokens_per_line": 500.296266559615,
|
||||||
|
"total_tokens": 0.82
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dataset_name": "wikihow",
|
||||||
|
"task_name": "wikihow",
|
||||||
|
"abs_weight": 0.1,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/wikihow",
|
||||||
|
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 180128,
|
||||||
|
"ave_tokens_per_line": 900.296266559615,
|
||||||
|
"total_tokens": 0.16
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dataset_name": "reclor",
|
||||||
|
"task_name": "reclor",
|
||||||
|
"abs_weight": 0.002,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/reclor",
|
||||||
|
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 4174,
|
||||||
|
"ave_tokens_per_line": 700.296266559615,
|
||||||
|
"total_tokens": 0.003
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dataset_name": "logic_test_lx_0127",
|
||||||
|
"task_name": "logic_test_lx_0127",
|
||||||
|
"abs_weight": 0.001,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logic_test_lx_0127",
|
||||||
|
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 2800,
|
||||||
|
"ave_tokens_per_line": 200.96266559615,
|
||||||
|
"total_tokens": 0.0004
|
||||||
|
}
|
||||||
|
]
|
|
@ -0,0 +1,28 @@
|
||||||
|
{
|
||||||
|
"vocab_size": 122753,
|
||||||
|
"dropout_p": 0.0,
|
||||||
|
"eps": 1e-05,
|
||||||
|
"half": true,
|
||||||
|
"half_type": "bf16",
|
||||||
|
"use_flash_attn": true,
|
||||||
|
"flash_attn_mask_shape": "2d",
|
||||||
|
"dim_model": 2304,
|
||||||
|
"dim_ff": 5760,
|
||||||
|
"dim_head": 64,
|
||||||
|
"num_heads": 36,
|
||||||
|
"num_kv_heads": 36,
|
||||||
|
"num_layers": 40,
|
||||||
|
"activate_fn": "silu",
|
||||||
|
"init_std": 0.10,
|
||||||
|
"scale": true,
|
||||||
|
"scale_emb": 12,
|
||||||
|
"scale_depth": 1.4,
|
||||||
|
"dim_model_base": 256,
|
||||||
|
"model_type": "fm9g",
|
||||||
|
"architectures": [
|
||||||
|
"FM9GForCausalLM"
|
||||||
|
],
|
||||||
|
"qk_norm": false,
|
||||||
|
"tie_lm_head": true,
|
||||||
|
"ffn_gated": true
|
||||||
|
}
|
|
@ -0,0 +1,548 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 QiYuan Inc.
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
from itertools import chain
|
||||||
|
from typing import Any
|
||||||
|
from typing import Dict
|
||||||
|
from typing import List
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import bmtrain as bmt
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from bmtrain import nccl
|
||||||
|
from bmtrain.global_var import config as bmt_config
|
||||||
|
|
||||||
|
sys.path.append("../../")
|
||||||
|
from fm9g.arguments import get_args
|
||||||
|
from fm9g.dragonfly.modeling_dragonfly import Dragonfly
|
||||||
|
from fm9g.dragonfly.modeling_dragonfly import DragonflyConfig
|
||||||
|
from fm9g.dragonfly.training_tasks.pretrain_indexed import CudaPrefetcher
|
||||||
|
from fm9g.dragonfly.training_tasks.pretrain_indexed import MixedIndexedDataset
|
||||||
|
from fm9g.dragonfly.training_tasks.pretrain_indexed import UnpadBatchedMixedDataset
|
||||||
|
from fm9g.utils import exporter
|
||||||
|
from fm9g.utils import logger
|
||||||
|
from fm9g.utils.exporter import save_every_step_stats
|
||||||
|
from fm9g.utils.training_stats import num_non_embedding_parameters
|
||||||
|
from fm9g.utils.training_stats import num_parameters
|
||||||
|
|
||||||
|
|
||||||
|
def get_tokenizer(args):
|
||||||
|
from transformers import LlamaTokenizerFast
|
||||||
|
|
||||||
|
tokenizer = LlamaTokenizerFast(vocab_file=args.tokenizer_path)
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(args):
|
||||||
|
config = DragonflyConfig.from_json_file(args.model_config)
|
||||||
|
config.tp = 1 if args.tp_size != 1 else 0 # TODO
|
||||||
|
config.pose_prob = args.pose_prob
|
||||||
|
config.pose_scaling_factor = args.pose_scaling_factor
|
||||||
|
config.rope_scaling_type = args.rope_scaling_type
|
||||||
|
config.rope_scaling_factor = args.rope_scaling_factor
|
||||||
|
config.orig_max_length = args.orig_max_length
|
||||||
|
|
||||||
|
bmt.print_rank("model config: {}".format(config))
|
||||||
|
bmt.print_rank("bmt config: {}".format(bmt.config))
|
||||||
|
|
||||||
|
model = Dragonfly(config)
|
||||||
|
if args.load is not None:
|
||||||
|
bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
|
||||||
|
exporter.load_model_ckpt(args, model)
|
||||||
|
else:
|
||||||
|
bmt.print_rank("args.load is None, start to initialize parameters")
|
||||||
|
bmt.init_parameters(model)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def get_optimizer(args, model):
|
||||||
|
scale_lr_group = []
|
||||||
|
normal_group = []
|
||||||
|
scale_lr_group_name, normal_group_name = [], []
|
||||||
|
for n, p in model.named_parameters():
|
||||||
|
if n.endswith(".weight") and "layernorm" not in n and "embedding" not in n and "lm_head" not in n:
|
||||||
|
scale_lr_group.append(p)
|
||||||
|
scale_lr_group_name.append(n)
|
||||||
|
else:
|
||||||
|
normal_group.append(p)
|
||||||
|
normal_group_name.append(n)
|
||||||
|
bmt.print_rank(scale_lr_group_name, normal_group_name)
|
||||||
|
param_groups = [
|
||||||
|
{"params": scale_lr_group, "lr": args.lr / model.config.scale_width},
|
||||||
|
{"params": normal_group, "lr": args.lr},
|
||||||
|
]
|
||||||
|
|
||||||
|
if args.offload:
|
||||||
|
optimizer = bmt.optim.AdamOffloadOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
||||||
|
else:
|
||||||
|
optimizer = bmt.optim.AdamOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
||||||
|
if args.load is not None and args.load_grad:
|
||||||
|
exporter.load_optimizer_ckpt(args, optimizer)
|
||||||
|
bmt.print_rank("optimizer is loaded!")
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
|
def get_learning_rate_scheduler(args, optimizer):
|
||||||
|
from fm9g.training_utils.lr_scheduler import Cosine
|
||||||
|
from fm9g.training_utils.lr_scheduler import WarmupStableDrop
|
||||||
|
|
||||||
|
end_iter = args.train_iters
|
||||||
|
if 0 < args.warmup_iters < 1: # 需要支持按固定比例step用来做warmup的
|
||||||
|
warmup_iters = int(end_iter * args.warmup_iters)
|
||||||
|
else:
|
||||||
|
warmup_iters = int(args.warmup_iters)
|
||||||
|
|
||||||
|
if 0 < args.drop_iters < 1: # 需要支持按固定比例step用来做drop的
|
||||||
|
drop_iters = int(end_iter * args.drop_iters)
|
||||||
|
else:
|
||||||
|
drop_iters = int(args.drop_iters)
|
||||||
|
|
||||||
|
if args.lr_scheduler == "cosine":
|
||||||
|
lr_scheduler = Cosine(
|
||||||
|
optimizer,
|
||||||
|
start_lr=args.lr,
|
||||||
|
warmup_iter=warmup_iters,
|
||||||
|
end_iter=end_iter, # 原来是lr_decay_iter
|
||||||
|
num_iter=args.start_step,
|
||||||
|
#lr_end_restart=args.lr_end_restart,
|
||||||
|
#resume_no_optimze=args.resume_no_optimze,
|
||||||
|
)
|
||||||
|
elif args.lr_scheduler == "warmupstabledrop":
|
||||||
|
lr_scheduler = WarmupStableDrop(
|
||||||
|
optimizer,
|
||||||
|
start_lr=args.lr,
|
||||||
|
warmup_iter=warmup_iters,
|
||||||
|
end_iter=end_iter, # 原来是lr_decay_iter
|
||||||
|
drop_iter=drop_iters,
|
||||||
|
num_iter=args.start_step,
|
||||||
|
resume_no_optimze=args.resume_no_optimze,
|
||||||
|
)
|
||||||
|
return lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
def setup_model_and_optimizer(args):
|
||||||
|
start = time.time()
|
||||||
|
tokenizer = get_tokenizer(args)
|
||||||
|
bmt.synchronize()
|
||||||
|
logger.info("load tokenizer in {:.2f}s".format(time.time() - start))
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
model = get_model(args)
|
||||||
|
logger.info("load model in {:.2f}s".format(time.time() - start))
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
optimizer = get_optimizer(args, model)
|
||||||
|
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
|
||||||
|
bmt.synchronize()
|
||||||
|
logger.info("load lr_scheduler in {:.2f}s".format(time.time() - start))
|
||||||
|
|
||||||
|
return tokenizer, model, optimizer, lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
def resume_training(args):
|
||||||
|
ckpts = sorted(
|
||||||
|
[z for z in chain(*[[os.path.join(x[0], y) for y in x[2]] for x in os.walk(args.save)]) if z.endswith(".pt")],
|
||||||
|
reverse=True,
|
||||||
|
key=lambda x: (int)(re.search("(\d+).pt", x)[1]),
|
||||||
|
)
|
||||||
|
# find newest job
|
||||||
|
ckpts = sorted(
|
||||||
|
ckpts,
|
||||||
|
reverse=True,
|
||||||
|
key=lambda x: (int)(re.search("job_(\d+)_ckpt", x)[1]),
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(ckpts) > 0:
|
||||||
|
bmt.print_rank(f"resuming with last checkpoint: {ckpts[0]}")
|
||||||
|
args.load = ckpts[0]
|
||||||
|
# by default, do not load grad file
|
||||||
|
args.load_grad = False
|
||||||
|
args.start_step = 0
|
||||||
|
else:
|
||||||
|
# no ckpts, nothing we can do
|
||||||
|
os._exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize():
|
||||||
|
args = get_args(pretrain=True)
|
||||||
|
bmt.init_distributed(seed=args.seed, tp_size=args.tp_size)
|
||||||
|
|
||||||
|
if args.save is not None:
|
||||||
|
os.makedirs(args.save, exist_ok=True)
|
||||||
|
if args.load is not None:
|
||||||
|
if args.only_load_model == 0:
|
||||||
|
if args.start_step == 0:
|
||||||
|
log_ckpt = exporter.load_log_ckpt(args)
|
||||||
|
if "iteration" in log_ckpt:
|
||||||
|
args.start_step = log_ckpt["iteration"]
|
||||||
|
else:
|
||||||
|
args.start_step = (int)(re.findall("(\d+)", args.load)[-1])
|
||||||
|
logger.info("Start from step {}".format(args.start_step))
|
||||||
|
elif args.only_load_model == 1:
|
||||||
|
logger.info("You load model ckpt, and choose to completely start the 0 step.")
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
else:
|
||||||
|
logger.info("You do not load model")
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def see_memory(detail=False):
|
||||||
|
if detail:
|
||||||
|
res = torch.cuda.memory_summary()
|
||||||
|
else:
|
||||||
|
res = (
|
||||||
|
round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2),
|
||||||
|
round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024), 2),
|
||||||
|
round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024), 2),
|
||||||
|
)
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def add_mem_time(info, mem_usage, tim_usage):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
bmt.synchronize()
|
||||||
|
mem_usage[info] = see_memory()
|
||||||
|
tim_usage[info] = time.time()
|
||||||
|
return mem_usage, tim_usage
|
||||||
|
|
||||||
|
|
||||||
|
def get_task_loss_and_token(loss, task_ids, task_num, targets):
|
||||||
|
# task_ids 可能有-1 来代表无效token
|
||||||
|
_task_num = task_num + 1
|
||||||
|
_task_ids = (task_ids.clone() + 1).to(torch.int64) # [batch_size, seq_len]
|
||||||
|
# gen masks
|
||||||
|
_task_mask = torch.zeros((_task_num, *_task_ids.shape), device=_task_ids.device)
|
||||||
|
_task_mask.scatter_(0, _task_ids.unsqueeze(0), 1) # [task_num, batch_size, seq_len]
|
||||||
|
_loss_mask = torch.ne(targets, -100).to(torch.int32)
|
||||||
|
_mask = _task_mask * _loss_mask.unsqueeze(0) # [task_num, batch_size, seq_len]
|
||||||
|
# calc loss and tokens
|
||||||
|
_task_losses = (loss.unsqueeze(0) * _mask).view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
|
||||||
|
_task_tokens = _mask.view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
|
||||||
|
# return token-wise avg losses and tokens
|
||||||
|
return torch.nan_to_num(_task_losses / _task_tokens, nan=0.0), _task_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkAve:
|
||||||
|
def __init__(self, chunk_size=100):
|
||||||
|
self.ave_list = []
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
|
||||||
|
def record(self, time):
|
||||||
|
self.ave_list.append(time)
|
||||||
|
self.ave_list = self.ave_list[-self.chunk_size :]
|
||||||
|
|
||||||
|
def get(self):
|
||||||
|
return sum(self.ave_list) / len(self.ave_list)
|
||||||
|
|
||||||
|
|
||||||
|
def pretrain(
|
||||||
|
args,
|
||||||
|
tokenizer,
|
||||||
|
model: Dragonfly,
|
||||||
|
optimizer,
|
||||||
|
lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
|
||||||
|
):
|
||||||
|
ave_model_time = ChunkAve(chunk_size=100)
|
||||||
|
ave_iter_time = ChunkAve(chunk_size=100)
|
||||||
|
|
||||||
|
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, reduction="none")
|
||||||
|
optim_manager = bmt.optim.OptimManager(
|
||||||
|
loss_scale=None,
|
||||||
|
loss_scale_steps=args.loss_scale_steps,
|
||||||
|
loss_scale_factor=2,
|
||||||
|
max_loss_scale=args.max_loss_scale,
|
||||||
|
min_loss_scale=args.min_loss_scale,
|
||||||
|
)
|
||||||
|
optim_manager.add_optimizer(optimizer, lr_scheduler)
|
||||||
|
|
||||||
|
start_step = args.start_step
|
||||||
|
|
||||||
|
if args.tensorboard is not None and bmt.rank() == 0:
|
||||||
|
import distutils.version # noqa: F401
|
||||||
|
|
||||||
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
|
if not os.path.exists(args.tensorboard):
|
||||||
|
os.makedirs(args.tensorboard)
|
||||||
|
writer = SummaryWriter(log_dir=args.tensorboard)
|
||||||
|
|
||||||
|
if args.load is not None:
|
||||||
|
log_ckpt = exporter.load_log_ckpt(args)
|
||||||
|
else:
|
||||||
|
log_ckpt = {}
|
||||||
|
global_token_pass = log_ckpt.get("global_token_pass", 0.0)
|
||||||
|
global_total_task_token = defaultdict(int, log_ckpt.get("global_total_task_token", {})) # token by task
|
||||||
|
|
||||||
|
global_world_size = bmt.world_size()
|
||||||
|
bmt.print_rank("Begin preparing dataset")
|
||||||
|
if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
|
||||||
|
mixed_indexed_dataset = MixedIndexedDataset(
|
||||||
|
cfg_path=args.dataset,
|
||||||
|
cfg_json_str=None,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_length=args.max_length,
|
||||||
|
nthreads=args.dataloader_num_threads,
|
||||||
|
prefetch_slice=args.dataloader_prefetch,
|
||||||
|
weight_by_size=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.load is not None and args.only_load_model == 0 and args.load_dataloader_ckpt == 1:
|
||||||
|
exporter.load_dataloader_ckpt(args, mixed_indexed_dataset)
|
||||||
|
|
||||||
|
batched_dataset = UnpadBatchedMixedDataset(mixed_indexed_dataset, args.batch_size, args.max_length)
|
||||||
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
batched_dataset,
|
||||||
|
batch_size=None,
|
||||||
|
collate_fn=lambda x: x,
|
||||||
|
num_workers=args.dataloader_num_workers,
|
||||||
|
prefetch_factor=args.dataloader_prefetch_factor,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
|
||||||
|
def dummy_generator():
|
||||||
|
while True:
|
||||||
|
yield None
|
||||||
|
|
||||||
|
mixed_indexed_dataset = dummy_generator()
|
||||||
|
dataloader = mixed_indexed_dataset
|
||||||
|
|
||||||
|
DataIterator = CudaPrefetcher(dataloader, tp_size=args.tp_size, tp_rank=bmt.config["tp_rank"])
|
||||||
|
|
||||||
|
bmt.print_rank("Preparing dataset done.")
|
||||||
|
|
||||||
|
# inspect at init
|
||||||
|
model_inspect = bmt.inspect.inspect_model(model, "*")
|
||||||
|
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
||||||
|
|
||||||
|
try:
|
||||||
|
mem_usage, tim_usage = {}, {}
|
||||||
|
mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)
|
||||||
|
|
||||||
|
for iteration, data in enumerate(DataIterator, start=start_step + 1):
|
||||||
|
if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
|
||||||
|
mixed_indexed_dataset.update_states(data["task_ids"], data["indexes"])
|
||||||
|
|
||||||
|
mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)
|
||||||
|
|
||||||
|
logits = model(
|
||||||
|
input=data["inputs"],
|
||||||
|
cu_seqlens=data["cu_seqlens"],
|
||||||
|
max_seqlen=data["max_seqlen"],
|
||||||
|
position_ids=data["position_ids"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# chunk targets and task_ids
|
||||||
|
data["targets"] = (
|
||||||
|
data["targets"]
|
||||||
|
.view(-1)
|
||||||
|
.chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
|
||||||
|
.view(data["targets"].shape[0], -1)
|
||||||
|
)
|
||||||
|
data["task_ids"] = (
|
||||||
|
data["task_ids"]
|
||||||
|
.view(-1)
|
||||||
|
.chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
|
||||||
|
.view(data["task_ids"].shape[0], -1)
|
||||||
|
)
|
||||||
|
|
||||||
|
_target = data["targets"].view(-1)
|
||||||
|
non_reduced_loss = loss_func(logits.view(-1, logits.size(-1)), _target)
|
||||||
|
_w = (_target != -100).int()
|
||||||
|
loss = non_reduced_loss.sum() / _w.sum().float()
|
||||||
|
|
||||||
|
global_loss = bmt.sum_loss(loss).item()
|
||||||
|
mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)
|
||||||
|
|
||||||
|
optim_manager.backward(loss)
|
||||||
|
mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)
|
||||||
|
|
||||||
|
if iteration % args.grad_accum == 0 or iteration == args.train_iters:
|
||||||
|
grad_accum_init_time = tim_usage["init"]
|
||||||
|
|
||||||
|
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
|
||||||
|
optim_manager.step()
|
||||||
|
optim_manager.zero_grad()
|
||||||
|
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
|
||||||
|
model_time = tim_usage["optim"] - grad_accum_init_time
|
||||||
|
ave_model_time.record(model_time)
|
||||||
|
else:
|
||||||
|
# dummy optim step
|
||||||
|
grad_norm = torch.Tensor([0.0]).cuda()
|
||||||
|
tim_usage["optim"] = tim_usage["backward"]
|
||||||
|
mem_usage["optim"] = mem_usage["backward"]
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
task_num = len(data["task_names"])
|
||||||
|
task_loss, task_token = get_task_loss_and_token(
|
||||||
|
non_reduced_loss, data["task_ids"], task_num, data["targets"]
|
||||||
|
)
|
||||||
|
task_loss_map: Dict[str, float] = {}
|
||||||
|
gatherd_task_loss_map = bmt.distributed.all_gather(task_loss)
|
||||||
|
gatherd_task_token_map = bmt.distributed.all_gather(task_token)
|
||||||
|
gatherd_task_loss_token_map = gatherd_task_loss_map * gatherd_task_token_map
|
||||||
|
sum_task_loss = gatherd_task_loss_token_map.sum(dim=0)
|
||||||
|
tot_task_token = gatherd_task_token_map.sum(dim=0)
|
||||||
|
ave_task_loss = sum_task_loss / tot_task_token
|
||||||
|
for i in range(task_num):
|
||||||
|
task_loss_map[data["task_names"][i]] = ave_task_loss[i].item()
|
||||||
|
global_total_task_token[data["task_names"][i]] += tot_task_token[i].item()
|
||||||
|
|
||||||
|
local_total_rate = torch.Tensor([data["lengths"].float().mean() / args.max_length]).cuda()
|
||||||
|
local_total_rate = bmt.sum_loss(local_total_rate).item()
|
||||||
|
global_token_pass += (
|
||||||
|
(global_world_size // args.tp_size) * local_total_rate * args.max_length * args.batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
bmt.print_rank(
|
||||||
|
"=========================================" + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||||
|
)
|
||||||
|
last_before_log_time = tim_usage["before_log"]
|
||||||
|
mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)
|
||||||
|
|
||||||
|
iter_time = tim_usage["before_log"] - last_before_log_time
|
||||||
|
|
||||||
|
ave_iter_time.record(iter_time)
|
||||||
|
|
||||||
|
train_info = {
|
||||||
|
"time": iter_time,
|
||||||
|
"iteration": iteration,
|
||||||
|
"loss": global_loss,
|
||||||
|
"lr": lr_scheduler.current_lr,
|
||||||
|
"token_max": local_total_rate,
|
||||||
|
"token_pass": global_token_pass,
|
||||||
|
"throughout": args.max_length * args.batch_size * local_total_rate / ave_iter_time.get() / args.tp_size,
|
||||||
|
"grad_norm": grad_norm.item(),
|
||||||
|
"mask_max": ((data["targets"] >= 0).sum(-1).float().mean() / args.max_length).item(),
|
||||||
|
"task_loss": task_loss_map,
|
||||||
|
"total_task_token": global_total_task_token,
|
||||||
|
}
|
||||||
|
global_token_pass_str = convert_to_k_and_b(global_token_pass)
|
||||||
|
|
||||||
|
bmt.print_rank(
|
||||||
|
(
|
||||||
|
"| Iter: {iteration:6d} | loss: {loss:.4f} | lr: {lr:.4e} | model_time: {model_time:.2f} | iter_time: {iter_time:.2f}| chunk_ave_time: {chunk_ave_time:.2f}"
|
||||||
|
+ " token/max: {tokenrate:.4f} | mask/max: {maskrate:.4f} | grad_norm: {grad_norm:.4f} | global_token_pass (B):"
|
||||||
|
+ "{global_token_pass} | mem_usage {mem_usage} | "
|
||||||
|
).format(
|
||||||
|
iteration=iteration,
|
||||||
|
loss=global_loss,
|
||||||
|
lr=lr_scheduler.current_lr,
|
||||||
|
model_time=model_time,
|
||||||
|
iter_time=iter_time,
|
||||||
|
chunk_ave_time=ave_iter_time.get(),
|
||||||
|
tokenrate=data["lengths"].float().mean() / args.max_length / args.batch_size,
|
||||||
|
maskrate=(data["targets"] >= 0).sum(-1).float().mean() / args.max_length / args.batch_size,
|
||||||
|
grad_norm=grad_norm.item(),
|
||||||
|
global_token_pass=global_token_pass_str,
|
||||||
|
mem_usage=max([value for key, value in mem_usage.items()]),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
bmt.print_rank(
|
||||||
|
"task_loss:\t| "
|
||||||
|
+ " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
|
||||||
|
+ " |"
|
||||||
|
)
|
||||||
|
|
||||||
|
if iteration % 10 == 0:
|
||||||
|
bmt.print_rank(
|
||||||
|
"task_tokens (B):\t| "
|
||||||
|
+ " | ".join(
|
||||||
|
[
|
||||||
|
"{}: {:.4f}".format(task_name, task_token / 10**9)
|
||||||
|
for task_name, task_token in global_total_task_token.items()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
+ " |"
|
||||||
|
)
|
||||||
|
|
||||||
|
if iteration % args.inspect_iters == 0:
|
||||||
|
model_inspect = bmt.inspect.inspect_model(model, "*")
|
||||||
|
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
||||||
|
|
||||||
|
if args.log_dir is not None and bmt.rank() == 0:
|
||||||
|
if args.save is not None:
|
||||||
|
save_every_step_stats(train_info, args.save)
|
||||||
|
|
||||||
|
if args.tensorboard is not None and bmt.rank() == 0:
|
||||||
|
writer.add_scalar("Loss/train", global_loss, iteration)
|
||||||
|
writer.add_scalar("Optimizer/lr", lr_scheduler.current_lr, iteration)
|
||||||
|
writer.add_scalar("Optimizer/scale", optim_manager.loss_scale, iteration)
|
||||||
|
writer.add_scalar("Optimizer/grad_norm", grad_norm.item(), iteration)
|
||||||
|
for task_name, loss in task_loss_map.items():
|
||||||
|
if not math.isnan(loss):
|
||||||
|
writer.add_scalar("Loss/train/{}".format(task_name), loss, iteration)
|
||||||
|
|
||||||
|
# -------- save file. If need to backup by Klara platform, use export.xx_save --------
|
||||||
|
log_ckpt = {
|
||||||
|
"global_total_task_token": global_total_task_token,
|
||||||
|
"global_token_pass": global_token_pass,
|
||||||
|
"iteration": iteration,
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.save is not None and iteration % args.save_iters == 0:
|
||||||
|
exporter.export(
|
||||||
|
model,
|
||||||
|
mixed_indexed_dataset,
|
||||||
|
tokenizer,
|
||||||
|
optimizer,
|
||||||
|
iteration,
|
||||||
|
args,
|
||||||
|
log_ckpt=log_ckpt,
|
||||||
|
final_save=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if iteration == args.train_iters and args.stop_when_end == 1:
|
||||||
|
break
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"train loop err: {e}")
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
pass
|
||||||
|
|
||||||
|
exporter.export(model, mixed_indexed_dataset, tokenizer, optimizer, -1, args, final_save=False)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_k_and_b(number):
|
||||||
|
if number >= 1e9: # 大于或等于10亿
|
||||||
|
b_number = number / 1e9
|
||||||
|
return f"{b_number:.2f}B"
|
||||||
|
elif number >= 1e6: # 大于或等于1百万
|
||||||
|
k_number = number / 1e6
|
||||||
|
return f"{k_number:.2f}M"
|
||||||
|
elif number >= 1e3:
|
||||||
|
k_number = number / 1e3
|
||||||
|
return f"{k_number:.2f}K"
|
||||||
|
else:
|
||||||
|
return str(number)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = initialize()
|
||||||
|
bmt.synchronize()
|
||||||
|
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
|
||||||
|
bmt.print_rank("finish loading")
|
||||||
|
bmt.print_rank(
|
||||||
|
"Number of parameter {}, Number of non-e parameter {}".format(
|
||||||
|
num_parameters(model), num_non_embedding_parameters(model)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
bmt.print_rank("args: {}".format(args))
|
||||||
|
|
||||||
|
pretrain(args, tokenizer, model, optimizer, lr_scheduler)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -0,0 +1,234 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
#export OMP_NUM_THREADS=16
|
||||||
|
|
||||||
|
declare -A args # Declare an associative array to store arguments and values
|
||||||
|
|
||||||
|
args["model_unique"]="2b_0701"
|
||||||
|
args["resume_ckpt"]=""
|
||||||
|
args["config"]="2.4b"
|
||||||
|
args["flash"]="cuda"
|
||||||
|
args["batch_size"]="1"
|
||||||
|
args["max_length"]="4096"
|
||||||
|
args["save_iters"]="500"
|
||||||
|
args["train_iters"]="10"
|
||||||
|
args["dataset_config"]="fm9g_sft"
|
||||||
|
args["local"]="False"
|
||||||
|
args["dataloader"]="indexed"
|
||||||
|
args["save"]="True"
|
||||||
|
args["dataloader_num_threads"]=1
|
||||||
|
args["dataloader_prefetch"]=1
|
||||||
|
args["dataloader_prefetch_factor"]=1
|
||||||
|
args["dataloader_num_workers"]=1
|
||||||
|
args["lr"]="1e-5"
|
||||||
|
args["warmup_iters"]="20"
|
||||||
|
args["drop_iters"]="0.1"
|
||||||
|
args["tokenizer_path"]="./tokenizer/tokenizer.model" # /user/tc_agi/klara/baichuan2/baichuan2.tokenizer.model
|
||||||
|
args["load_grad"]="False"
|
||||||
|
args["grad_ckpt_num"]="160"
|
||||||
|
args["exp_group"]=""
|
||||||
|
args["ignore_cuda_oom"]="1"
|
||||||
|
args["tensorboard_all_tasks"]="0"
|
||||||
|
args["stop_when_end"]="0"
|
||||||
|
args["only_run_dataloader"]="0"
|
||||||
|
args["eps"]="1e-6"
|
||||||
|
args["inspect_iters"]="100"
|
||||||
|
args["strict_state_dict"]="1"
|
||||||
|
args["only_load_model"]="1"
|
||||||
|
args["lr_scheduler"]="cosine"
|
||||||
|
args["resume_no_optimze"]="0"
|
||||||
|
args["tp_size"]="1"
|
||||||
|
args["parallel_load_datastate"]="8"
|
||||||
|
args["async_save"]="False"
|
||||||
|
args["load_dataloader_ckpt"]="0"
|
||||||
|
args["drop_begin"]="-1"
|
||||||
|
args["drop_rate"]="0.5"
|
||||||
|
args["use_checkpoint"]="0"
|
||||||
|
|
||||||
|
|
||||||
|
# Loop through the arguments
|
||||||
|
for ((i=1; i<=$#; i++)); do
|
||||||
|
arg="${!i}"
|
||||||
|
# Check if the argument starts with "--"
|
||||||
|
if [[ "$arg" == --* ]]; then
|
||||||
|
arg_name="${arg:2}" # Remove leading "--"
|
||||||
|
valueid=$((i+1))
|
||||||
|
# Get the value of the argument if it exists
|
||||||
|
if ((i+1 <= $#)); then
|
||||||
|
args["$arg_name"]="${!valueid}"
|
||||||
|
i=$((i+1)) # Skip the next argument (its value)
|
||||||
|
else
|
||||||
|
args["$arg_name"]="" # Set empty value if no value provided
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# 使用 Python 读取 JSON 文件并更新 Bash 字典
|
||||||
|
while read -r key value; do
|
||||||
|
args["$key"]="$value"
|
||||||
|
done < <(python -c 'import json, sys; obj = json.load(open("train_configs/'${args['config']}'.json"))["pretrain"]; print("\n".join(["{} {}".format(k, v) for k, v in obj.items()]))')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 用cmd arg 再更新一次
|
||||||
|
# Loop through the arguments
|
||||||
|
for ((i=1; i<=$#; i++)); do
|
||||||
|
arg="${!i}"
|
||||||
|
# Check if the argument starts with "--"
|
||||||
|
if [[ "$arg" == --* ]]; then
|
||||||
|
arg_name="${arg:2}" # Remove leading "--"
|
||||||
|
valueid=$((i+1))
|
||||||
|
|
||||||
|
# Get the value of the argument if it exists
|
||||||
|
if ((i+1 <= $#)); then
|
||||||
|
args["$arg_name"]="${!valueid}"
|
||||||
|
i=$((i+1)) # Skip the next argument (its value)
|
||||||
|
else
|
||||||
|
args["$arg_name"]="" # Set empty value if no value provided
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# Print the values of the arguments
|
||||||
|
echo "----------- CMD args ----------"
|
||||||
|
for key in "${!args[@]}"; do
|
||||||
|
echo "$key: ${args[$key]}"
|
||||||
|
done
|
||||||
|
echo "--------- END CMD args --------"
|
||||||
|
|
||||||
|
|
||||||
|
if [[ ${args["flash"]} == "triton" ]]; then
|
||||||
|
sudo cp /usr/local/cuda-11.6/compat/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03
|
||||||
|
sudo ln /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so
|
||||||
|
echo "triton flash"
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
GPUS_PER_NODE=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader | wc -l)
|
||||||
|
# GPUS_PER_NODE=1
|
||||||
|
echo "Using ${GPUS_PER_NODE} GPU each machine"
|
||||||
|
|
||||||
|
|
||||||
|
if [[ ${args["model_unique"]} == "" ]]; then
|
||||||
|
MODEL_UNIQUE=${JEEVES_JOB_ID} # 写入的位置,没传的话自动构造
|
||||||
|
# JOBID+CreateTime, 本次run的唯一标识符。在白箱里可以通过/projects/${PROJECTID}-${PROJECTNAME}/checkpoints/${MODEL_UNIQUE} 拿到 checkpoint
|
||||||
|
# 通过/projects/${PROJECTID}-${PROJECTNAME}/tensorboard/${MODEL_UNIQUE} 拿到 tensorboard
|
||||||
|
else
|
||||||
|
MODEL_UNIQUE=${args["model_unique"]} # 给了写入的位置
|
||||||
|
fi
|
||||||
|
echo "model_unique: "$MODEL_UNIQUE
|
||||||
|
|
||||||
|
# --------------- 运行参数 ---------------
|
||||||
|
|
||||||
|
OPTS+=" --model-config model_configs/"${args['config']}".json" # [CHANGE]
|
||||||
|
OPTS+=" --batch-size ${args["batch_size"]}"
|
||||||
|
OPTS+=" --train-iters ${args["train_iters"]}"
|
||||||
|
OPTS+=" --save-iters ${args["save_iters"]}"
|
||||||
|
OPTS+=" --save-name fm9g_live_checkpoint"
|
||||||
|
OPTS+=" --max-length ${args["max_length"]}"
|
||||||
|
OPTS+=" --lr ${args["lr"]}"
|
||||||
|
OPTS+=" --inspect-iters ${args["inspect_iters"]}"
|
||||||
|
OPTS+=" --warmup-iters ${args["warmup_iters"]}"
|
||||||
|
OPTS+=" --drop-iters ${args["drop_iters"]}"
|
||||||
|
OPTS+=" --lr_scheduler ${args["lr_scheduler"]}"
|
||||||
|
OPTS+=" --offload"
|
||||||
|
#OPTS+=" --vocab ./tokenizer/vocab.txt"
|
||||||
|
OPTS+=" --flash ${args["flash"]}"
|
||||||
|
OPTS+=" --tensorboard_all_tasks ${args["tensorboard_all_tasks"]}"
|
||||||
|
OPTS+=" --ignore_cuda_oom ${args["ignore_cuda_oom"]}"
|
||||||
|
OPTS+=" --stop_when_end ${args["stop_when_end"]}"
|
||||||
|
OPTS+=" --only_run_dataloader ${args["only_run_dataloader"]}"
|
||||||
|
OPTS+=" --eps ${args["eps"]}"
|
||||||
|
OPTS+=" --strict_state_dict ${args["strict_state_dict"]}"
|
||||||
|
OPTS+=" --only_load_model ${args["only_load_model"]}"
|
||||||
|
OPTS+=" --resume_no_optimze ${args["resume_no_optimze"]}"
|
||||||
|
OPTS+=" --tokenizer_path ${args["tokenizer_path"]}"
|
||||||
|
OPTS+=" --weight-decay 0.1"
|
||||||
|
OPTS+=" --tp-size ${args["tp_size"]}"
|
||||||
|
OPTS+=" --parallel_load_datastate ${args["parallel_load_datastate"]}"
|
||||||
|
OPTS+=" --load_dataloader_ckpt ${args["load_dataloader_ckpt"]}"
|
||||||
|
OPTS+=" --drop_begin ${args["drop_begin"]}"
|
||||||
|
OPTS+=" --drop_rate ${args["drop_rate"]}"
|
||||||
|
OPTS+=" --use_checkpoint ${args["use_checkpoint"]}"
|
||||||
|
|
||||||
|
if [[ ${args["load_grad"]} == "True" ]]; then
|
||||||
|
OPTS+=" --load-grad"
|
||||||
|
OPTS+=" --grad-ckpt-num ${args["grad_ckpt_num"]}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
if [[ ${args["async_save"]} == "True" ]]; then
|
||||||
|
OPTS+=" --async_save"
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
if [[ ${args["dataloader"]} == "indexed" ]]; then
|
||||||
|
OPTS+=" --dataloader_num_threads ${args["dataloader_num_threads"]}"
|
||||||
|
OPTS+=" --dataloader_prefetch ${args["dataloader_prefetch"]}"
|
||||||
|
OPTS+=" --dataloader_num_workers ${args["dataloader_num_workers"]}"
|
||||||
|
OPTS+=" --dataloader_prefetch_factor ${args["dataloader_prefetch_factor"]}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
# --------------- 写文件路径 ---------------
|
||||||
|
## checkpoint
|
||||||
|
if [[ ${args["save"]} == "True" ]]; then
|
||||||
|
|
||||||
|
OPTS+=" --save ./data/checkpoints/${MODEL_UNIQUE}/"
|
||||||
|
OPTS+=" --save-model ./not_exist/${MODEL_UNIQUE}/"
|
||||||
|
else
|
||||||
|
echo "won't save model"
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
## logs,/local/logs 等价于 ./datalogs(软链)
|
||||||
|
mkdir -p ./data/checkpoints/logs/${MODEL_UNIQUE}
|
||||||
|
OPTS+=" --log-dir ./data/checkpoints/logs/${MODEL_UNIQUE}"
|
||||||
|
OPTS+=" --tensorboard ./data/tensorboard/${args["exp_group"]}${MODEL_UNIQUE}/"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if [[ ${args["local"]} == "True" ]]; then
|
||||||
|
current_dir=$(pwd)
|
||||||
|
OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
|
||||||
|
else
|
||||||
|
current_dir=$(pwd)
|
||||||
|
OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
|
||||||
|
echo "Platform config:"${PLATFORM_CONFIG_PATH}
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
## checkpoint,兼容 CHECKPOINT 和 LATEST_CHECKPOINT。debug 时建议不加载 checkpoint,启动会比较快
|
||||||
|
if [ "${args["resume_ckpt"]}" != "" ]; then
|
||||||
|
OPTS+=" --load ./data/checkpoints/${MODEL_UNIQUE}/${args["resume_ckpt"]}"
|
||||||
|
else
|
||||||
|
echo "No checkpoint to load"
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
filename="pretrain_dragonfly"
|
||||||
|
|
||||||
|
if [[ ${args["local"]} == "True" ]]; then
|
||||||
|
PRETRAIN_ENTRY="$filename.py"
|
||||||
|
else
|
||||||
|
PRETRAIN_ENTRY="$filename.py"
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
GPUS_PER_NODE=1
|
||||||
|
NNODES=1
|
||||||
|
RANK=0
|
||||||
|
MASTER_ENDPOINT=g3006
|
||||||
|
MASTER_PORT=23456
|
||||||
|
#CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --master_addr=${MASTER_ENDPOINT} --master_port=${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
|
||||||
|
CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ENDPOINT}:${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
|
||||||
|
|
||||||
|
echo "-------final CMD is------"
|
||||||
|
echo "${CMD}"
|
||||||
|
echo "-------final CMD end------"
|
||||||
|
|
||||||
|
$CMD
|
File diff suppressed because it is too large
Load Diff
Binary file not shown.
|
@ -0,0 +1,9 @@
|
||||||
|
{
|
||||||
|
"pretrain": {
|
||||||
|
"train_iters": 1000000000,
|
||||||
|
"batch_size": 1,
|
||||||
|
"max_length": 4096,
|
||||||
|
"n_gpus": 8,
|
||||||
|
"lr": 0.01
|
||||||
|
}
|
||||||
|
}
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,20 @@
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
def rand(n: int, r: random.Random):
|
||||||
|
return int(r.random() * n)
|
||||||
|
|
||||||
|
def transform(data, num_sample: int, r: random.Random):
|
||||||
|
if 'input' in data:
|
||||||
|
_input = "<用户>"+data['input']+"<AI>"
|
||||||
|
else:
|
||||||
|
_input = ""
|
||||||
|
|
||||||
|
if 'output' in data:
|
||||||
|
_output = data['output']
|
||||||
|
else:
|
||||||
|
_output = ""
|
||||||
|
return {"input": _input,
|
||||||
|
"output": _output,
|
||||||
|
}
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,20 @@
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
def rand(n: int, r: random.Random):
|
||||||
|
return int(r.random() * n)
|
||||||
|
|
||||||
|
def transform(data, num_sample: int, r: random.Random):
|
||||||
|
if 'input' in data:
|
||||||
|
_input = data['input']
|
||||||
|
else:
|
||||||
|
_input = ""
|
||||||
|
|
||||||
|
if 'output' in data:
|
||||||
|
_output = data['output']
|
||||||
|
else:
|
||||||
|
_output = ""
|
||||||
|
return {"input": _input,
|
||||||
|
"output": _output,
|
||||||
|
}
|
|
@ -0,0 +1,134 @@
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"dataset_name": "humanevallike_clean_dedup",
|
||||||
|
"task_name": "humanevallike_clean_dedup",
|
||||||
|
"abs_weight": 0.2,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/humanevallike_clean_dedup",
|
||||||
|
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 995339,
|
||||||
|
"ave_tokens_per_line": 100,
|
||||||
|
"total_tokens": 0.1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dataset_name": "leetcode_pass_code_0125",
|
||||||
|
"task_name": "leetcode_pass_code_0125",
|
||||||
|
"abs_weight": 0.006,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/leetcode_pass_code_0125",
|
||||||
|
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 10724,
|
||||||
|
"ave_tokens_per_line": 200,
|
||||||
|
"total_tokens": 0.002
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dataset_name": "logiv2Annotate",
|
||||||
|
"task_name": "logiv2Annotate",
|
||||||
|
"abs_weight": 0.004,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logiv2Annotate",
|
||||||
|
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 12566,
|
||||||
|
"ave_tokens_per_line": 512,
|
||||||
|
"total_tokens": 0.006
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dataset_name": "mmlu_enhance",
|
||||||
|
"task_name": "mmlu_enhance",
|
||||||
|
"abs_weight": 0.1,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mmlu_enhance",
|
||||||
|
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 169771,
|
||||||
|
"ave_tokens_per_line": 300,
|
||||||
|
"total_tokens": 0.05
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dataset_name": "mtbench_like",
|
||||||
|
"task_name": "mtbench_like",
|
||||||
|
"abs_weight": 0.2,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mtbench_like",
|
||||||
|
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 319080,
|
||||||
|
"ave_tokens_per_line": 500,
|
||||||
|
"total_tokens": 0.15
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dataset_name": "ultra_dataset_new",
|
||||||
|
"task_name": "ultra_dataset_new",
|
||||||
|
"abs_weight": 2.0,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/ultra_dataset_new",
|
||||||
|
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 385045,
|
||||||
|
"ave_tokens_per_line": 200.296266559615,
|
||||||
|
"total_tokens": 2.0
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dataset_name": "sft_data_zh_wowru",
|
||||||
|
"task_name": "sft_data_zh_wowru",
|
||||||
|
"abs_weight": 1.0,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/sft_data_zh_wowru",
|
||||||
|
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 2963260,
|
||||||
|
"ave_tokens_per_line": 200.296266559615,
|
||||||
|
"total_tokens": 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dataset_name": "math_data",
|
||||||
|
"task_name": "math_data",
|
||||||
|
"abs_weight": 0.003,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/math_data",
|
||||||
|
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 2963260,
|
||||||
|
"ave_tokens_per_line": 200.296266559615,
|
||||||
|
"total_tokens": 0.005
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dataset_name": "t0",
|
||||||
|
"task_name": "t0",
|
||||||
|
"abs_weight": 0.1,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/t0",
|
||||||
|
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 1650309,
|
||||||
|
"ave_tokens_per_line": 500.296266559615,
|
||||||
|
"total_tokens": 0.82
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dataset_name": "wikihow",
|
||||||
|
"task_name": "wikihow",
|
||||||
|
"abs_weight": 0.1,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/wikihow",
|
||||||
|
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 180128,
|
||||||
|
"ave_tokens_per_line": 900.296266559615,
|
||||||
|
"total_tokens": 0.16
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dataset_name": "reclor",
|
||||||
|
"task_name": "reclor",
|
||||||
|
"abs_weight": 0.002,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/reclor",
|
||||||
|
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 4174,
|
||||||
|
"ave_tokens_per_line": 700.296266559615,
|
||||||
|
"total_tokens": 0.003
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dataset_name": "logic_test_lx_0127",
|
||||||
|
"task_name": "logic_test_lx_0127",
|
||||||
|
"abs_weight": 0.001,
|
||||||
|
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logic_test_lx_0127",
|
||||||
|
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
||||||
|
"allow_repeat": true,
|
||||||
|
"nlines": 2800,
|
||||||
|
"ave_tokens_per_line": 200.96266559615,
|
||||||
|
"total_tokens": 0.0004
|
||||||
|
}
|
||||||
|
]
|
|
@ -0,0 +1,27 @@
|
||||||
|
{
|
||||||
|
"vocab_size": 119696,
|
||||||
|
"dropout_p": 0.0,
|
||||||
|
"eps": 1e-05,
|
||||||
|
"half": true,
|
||||||
|
"half_type": "bf16",
|
||||||
|
"use_flash_attn": true,
|
||||||
|
"flash_attn_mask_shape": "2d",
|
||||||
|
"dim_model": 4096,
|
||||||
|
"dim_ff": 14336,
|
||||||
|
"dim_head": 128,
|
||||||
|
"num_heads": 32,
|
||||||
|
"num_kv_heads": 32,
|
||||||
|
"num_layers": 32,
|
||||||
|
"activate_fn": "silu",
|
||||||
|
"init_std": 0.10,
|
||||||
|
"scale": false,
|
||||||
|
"scale_emb": 12,
|
||||||
|
"scale_depth": -1,
|
||||||
|
"model_type": "fm9g",
|
||||||
|
"architectures": [
|
||||||
|
"FM9GForCausalLM"
|
||||||
|
],
|
||||||
|
"qk_norm": false,
|
||||||
|
"tie_lm_head": false,
|
||||||
|
"ffn_gated": true
|
||||||
|
}
|
|
@ -0,0 +1,568 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 ModelBest Inc.
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
from itertools import chain
|
||||||
|
from typing import Any
|
||||||
|
from typing import Dict
|
||||||
|
from typing import List
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import bmtrain as bmt
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from bmtrain import nccl
|
||||||
|
from bmtrain.global_var import config as bmt_config
|
||||||
|
|
||||||
|
sys.path.append("../../")
|
||||||
|
from fm9g.arguments import get_args
|
||||||
|
from fm9g.dragonfly.modeling_dragonfly import Dragonfly
|
||||||
|
from fm9g.dragonfly.modeling_dragonfly import DragonflyConfig
|
||||||
|
from fm9g.dragonfly.training_tasks.pretrain_indexed_9g import CudaPrefetcher
|
||||||
|
from fm9g.dragonfly.training_tasks.pretrain_indexed_9g import MixedIndexedDataset
|
||||||
|
from fm9g.dragonfly.training_tasks.pretrain_indexed_9g import UnpadBatchedMixedDataset
|
||||||
|
from fm9g.utils import exporter
|
||||||
|
from fm9g.utils import logger
|
||||||
|
from fm9g.utils.exporter import save_every_step_stats
|
||||||
|
from fm9g.utils.training_stats import num_non_embedding_parameters
|
||||||
|
from fm9g.utils.training_stats import num_parameters
|
||||||
|
|
||||||
|
|
||||||
|
def get_tokenizer(args):
|
||||||
|
from fm9g.tokenizer import FM9GTokenizer
|
||||||
|
tokenizer = FM9GTokenizer(path=args.vocab)
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(args):
|
||||||
|
config = DragonflyConfig.from_json_file(args.model_config)
|
||||||
|
config.tp = 1 if args.tp_size != 1 else 0 # TODO
|
||||||
|
config.pose_prob = args.pose_prob
|
||||||
|
config.pose_scaling_factor = args.pose_scaling_factor
|
||||||
|
config.rope_scaling_type = args.rope_scaling_type
|
||||||
|
config.rope_scaling_factor = args.rope_scaling_factor
|
||||||
|
config.orig_max_length = args.orig_max_length
|
||||||
|
config.use_checkpoint = True if args.use_checkpoint == 1 else False
|
||||||
|
|
||||||
|
bmt.print_rank("model config: {}".format(config))
|
||||||
|
bmt.print_rank("bmt config: {}".format(bmt.config))
|
||||||
|
|
||||||
|
model = Dragonfly(config)
|
||||||
|
if args.load is not None:
|
||||||
|
bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
|
||||||
|
exporter.load_model_ckpt(args, model)
|
||||||
|
else:
|
||||||
|
bmt.print_rank("args.load is None, start to initialize parameters")
|
||||||
|
bmt.init_parameters(model)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def get_optimizer(args, model):
|
||||||
|
scale_lr_group = []
|
||||||
|
normal_group = []
|
||||||
|
scale_lr_group_name, normal_group_name = [], []
|
||||||
|
for n, p in model.named_parameters():
|
||||||
|
if n.endswith(".weight") and "layernorm" not in n and "embedding" not in n and "lm_head" not in n:
|
||||||
|
scale_lr_group.append(p)
|
||||||
|
scale_lr_group_name.append(n)
|
||||||
|
else:
|
||||||
|
normal_group.append(p)
|
||||||
|
normal_group_name.append(n)
|
||||||
|
bmt.print_rank(scale_lr_group_name, normal_group_name)
|
||||||
|
param_groups = [
|
||||||
|
{"params": scale_lr_group, "lr": args.lr / model.config.scale_width},
|
||||||
|
{"params": normal_group, "lr": args.lr},
|
||||||
|
]
|
||||||
|
|
||||||
|
if args.offload:
|
||||||
|
optimizer = bmt.optim.AdamOffloadOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
||||||
|
else:
|
||||||
|
optimizer = bmt.optim.AdamOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
||||||
|
if args.load is not None and args.load_grad:
|
||||||
|
exporter.load_optimizer_ckpt(args, optimizer)
|
||||||
|
bmt.print_rank("optimizer is loaded!")
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
|
def get_learning_rate_scheduler(args, optimizer):
|
||||||
|
from fm9g.training_utils.lr_scheduler import Cosine
|
||||||
|
from fm9g.training_utils.lr_scheduler import WarmupStableDrop
|
||||||
|
from fm9g.training_utils.lr_scheduler import WarmupStableExp
|
||||||
|
|
||||||
|
end_iter = args.train_iters
|
||||||
|
if 0 < args.warmup_iters < 1: # 需要支持按固定比例step用来做warmup的
|
||||||
|
warmup_iters = int(end_iter * args.warmup_iters)
|
||||||
|
else:
|
||||||
|
warmup_iters = int(args.warmup_iters)
|
||||||
|
|
||||||
|
if 0 < args.drop_iters < 1: # 需要支持按固定比例step用来做drop的
|
||||||
|
drop_iters = int(end_iter * args.drop_iters)
|
||||||
|
else:
|
||||||
|
drop_iters = int(args.drop_iters)
|
||||||
|
|
||||||
|
if args.lr_scheduler == "cosine":
|
||||||
|
lr_scheduler = Cosine(
|
||||||
|
optimizer,
|
||||||
|
start_lr=args.lr,
|
||||||
|
warmup_iter=warmup_iters,
|
||||||
|
end_iter=end_iter, # 原来是lr_decay_iter
|
||||||
|
num_iter=args.start_step,
|
||||||
|
)
|
||||||
|
# lr_end_restart=args.lr_end_restart,
|
||||||
|
# resume_no_optimze=args.resume_no_optimze,
|
||||||
|
#)
|
||||||
|
elif args.lr_scheduler == "warmupstabledrop":
|
||||||
|
lr_scheduler = WarmupStableDrop(
|
||||||
|
optimizer,
|
||||||
|
start_lr=args.lr,
|
||||||
|
warmup_iter=warmup_iters,
|
||||||
|
end_iter=end_iter, # 原来是lr_decay_iter
|
||||||
|
drop_iter=drop_iters,
|
||||||
|
num_iter=args.start_step,
|
||||||
|
resume_no_optimze=args.resume_no_optimze,
|
||||||
|
)
|
||||||
|
elif args.lr_scheduler == "warmupstableexp":
|
||||||
|
lr_scheduler = WarmupStableExp(
|
||||||
|
optimizer,
|
||||||
|
start_lr=args.lr,
|
||||||
|
warmup_iter=warmup_iters,
|
||||||
|
drop_begin=args.drop_begin, # 原来是lr_decay_iter
|
||||||
|
drop_iter=drop_iters,
|
||||||
|
drop_rate=args.drop_rate,
|
||||||
|
num_iter=args.start_step,
|
||||||
|
resume_no_optimze=args.resume_no_optimze,
|
||||||
|
)
|
||||||
|
return lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
def setup_model_and_optimizer(args):
|
||||||
|
start = time.time()
|
||||||
|
tokenizer = get_tokenizer(args)
|
||||||
|
bmt.synchronize()
|
||||||
|
logger.info("load tokenizer in {:.2f}s".format(time.time() - start))
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
model = get_model(args)
|
||||||
|
logger.info("load model in {:.2f}s".format(time.time() - start))
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
optimizer = get_optimizer(args, model)
|
||||||
|
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
|
||||||
|
bmt.synchronize()
|
||||||
|
logger.info("load lr_scheduler in {:.2f}s".format(time.time() - start))
|
||||||
|
|
||||||
|
return tokenizer, model, optimizer, lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
def resume_training(args):
|
||||||
|
ckpts = sorted(
|
||||||
|
[z for z in chain(*[[os.path.join(x[0], y) for y in x[2]] for x in os.walk(args.save)]) if z.endswith(".pt")],
|
||||||
|
reverse=True,
|
||||||
|
key=lambda x: (int)(re.search("(\d+).pt", x)[1]),
|
||||||
|
)
|
||||||
|
# find newest job
|
||||||
|
ckpts = sorted(
|
||||||
|
ckpts,
|
||||||
|
reverse=True,
|
||||||
|
key=lambda x: (int)(re.search("job_(\d+)_ckpt", x)[1]),
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(ckpts) > 0:
|
||||||
|
bmt.print_rank(f"resuming with last checkpoint: {ckpts[0]}")
|
||||||
|
args.load = ckpts[0]
|
||||||
|
# by default, do not load grad file
|
||||||
|
args.load_grad = False
|
||||||
|
args.start_step = 0
|
||||||
|
else:
|
||||||
|
# no ckpts, nothing we can do
|
||||||
|
os._exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize():
|
||||||
|
args = get_args(pretrain=True)
|
||||||
|
bmt.init_distributed(seed=args.seed, tp_size=args.tp_size)
|
||||||
|
|
||||||
|
if args.save is not None:
|
||||||
|
os.makedirs(args.save, exist_ok=True)
|
||||||
|
if args.load is not None:
|
||||||
|
if args.only_load_model == 0:
|
||||||
|
if args.start_step == 0:
|
||||||
|
log_ckpt = exporter.load_log_ckpt(args)
|
||||||
|
if "iteration" in log_ckpt:
|
||||||
|
args.start_step = log_ckpt["iteration"]
|
||||||
|
else:
|
||||||
|
args.start_step = (int)(re.findall("(\d+)", args.load)[-1])
|
||||||
|
logger.info("Start from step {}".format(args.start_step))
|
||||||
|
elif args.only_load_model == 1:
|
||||||
|
logger.info("You load model ckpt, and choose to completely start the 0 step.")
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
else:
|
||||||
|
logger.info("You do not load model")
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def see_memory(detail=False):
|
||||||
|
if detail:
|
||||||
|
res = torch.cuda.memory_summary()
|
||||||
|
else:
|
||||||
|
res = (
|
||||||
|
round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2),
|
||||||
|
round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024), 2),
|
||||||
|
round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024), 2),
|
||||||
|
)
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def add_mem_time(info, mem_usage, tim_usage):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
bmt.synchronize()
|
||||||
|
mem_usage[info] = see_memory()
|
||||||
|
tim_usage[info] = time.time()
|
||||||
|
return mem_usage, tim_usage
|
||||||
|
|
||||||
|
|
||||||
|
def get_task_loss_and_token(loss, task_ids, task_num, targets):
|
||||||
|
# task_ids 可能有-1 来代表无效token
|
||||||
|
_task_num = task_num + 1
|
||||||
|
_task_ids = (task_ids.clone() + 1).to(torch.int64) # [batch_size, seq_len]
|
||||||
|
# gen masks
|
||||||
|
_task_mask = torch.zeros((_task_num, *_task_ids.shape), device=_task_ids.device)
|
||||||
|
_task_mask.scatter_(0, _task_ids.unsqueeze(0), 1) # [task_num, batch_size, seq_len]
|
||||||
|
_loss_mask = torch.ne(targets, -100).to(torch.int32)
|
||||||
|
_mask = _task_mask * _loss_mask.unsqueeze(0) # [task_num, batch_size, seq_len]
|
||||||
|
# calc loss and tokens
|
||||||
|
_task_losses = (loss.unsqueeze(0) * _mask).view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
|
||||||
|
_task_tokens = _mask.view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
|
||||||
|
# return token-wise avg losses and tokens
|
||||||
|
return torch.nan_to_num(_task_losses / _task_tokens, nan=0.0), _task_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkAve:
|
||||||
|
def __init__(self, chunk_size=100):
|
||||||
|
self.ave_list = []
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
|
||||||
|
def record(self, time):
|
||||||
|
self.ave_list.append(time)
|
||||||
|
self.ave_list = self.ave_list[-self.chunk_size :]
|
||||||
|
|
||||||
|
def get(self):
|
||||||
|
return sum(self.ave_list) / len(self.ave_list)
|
||||||
|
|
||||||
|
|
||||||
|
def pretrain(
|
||||||
|
args,
|
||||||
|
tokenizer,
|
||||||
|
model: Dragonfly,
|
||||||
|
optimizer,
|
||||||
|
lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
|
||||||
|
):
|
||||||
|
ave_model_time = ChunkAve(chunk_size=100)
|
||||||
|
ave_iter_time = ChunkAve(chunk_size=100)
|
||||||
|
|
||||||
|
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, reduction="none")
|
||||||
|
optim_manager = bmt.optim.OptimManager(
|
||||||
|
loss_scale=bmt.world_size(),
|
||||||
|
loss_scale_steps=args.loss_scale_steps,
|
||||||
|
loss_scale_factor=2,
|
||||||
|
max_loss_scale=bmt.world_size(),
|
||||||
|
min_loss_scale=bmt.world_size(),
|
||||||
|
)
|
||||||
|
optim_manager.add_optimizer(optimizer, lr_scheduler)
|
||||||
|
|
||||||
|
start_step = args.start_step
|
||||||
|
|
||||||
|
if args.tensorboard is not None and bmt.rank() == 0:
|
||||||
|
import distutils.version # noqa: F401
|
||||||
|
|
||||||
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
|
if not os.path.exists(args.tensorboard):
|
||||||
|
os.makedirs(args.tensorboard)
|
||||||
|
writer = SummaryWriter(log_dir=args.tensorboard)
|
||||||
|
|
||||||
|
if args.load is not None:
|
||||||
|
log_ckpt = exporter.load_log_ckpt(args)
|
||||||
|
else:
|
||||||
|
log_ckpt = {}
|
||||||
|
|
||||||
|
global_token_pass = log_ckpt.get("global_token_pass", 0.0)
|
||||||
|
global_total_task_token = defaultdict(int, log_ckpt.get("global_total_task_token", {})) # token by task
|
||||||
|
|
||||||
|
global_world_size = bmt.world_size()
|
||||||
|
if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
|
||||||
|
mixed_indexed_dataset = MixedIndexedDataset(
|
||||||
|
cfg_path=args.dataset,
|
||||||
|
cfg_json_str=None,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_length=args.max_length,
|
||||||
|
nthreads=args.dataloader_num_threads,
|
||||||
|
prefetch_slice=args.dataloader_prefetch,
|
||||||
|
weight_by_size=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.load is not None and args.only_load_model == 0 and args.load_dataloader_ckpt == 1:
|
||||||
|
exporter.load_dataloader_ckpt(args, mixed_indexed_dataset)
|
||||||
|
|
||||||
|
batched_dataset = UnpadBatchedMixedDataset(mixed_indexed_dataset, args.batch_size, args.max_length)
|
||||||
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
batched_dataset,
|
||||||
|
batch_size=None,
|
||||||
|
collate_fn=lambda x: x,
|
||||||
|
num_workers=args.dataloader_num_workers,
|
||||||
|
prefetch_factor=args.dataloader_prefetch_factor,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
|
||||||
|
def dummy_generator():
|
||||||
|
while True:
|
||||||
|
yield None
|
||||||
|
|
||||||
|
mixed_indexed_dataset = dummy_generator()
|
||||||
|
dataloader = mixed_indexed_dataset
|
||||||
|
|
||||||
|
DataIterator = CudaPrefetcher(dataloader, tp_size=args.tp_size, tp_rank=bmt.config["tp_rank"])
|
||||||
|
|
||||||
|
bmt.print_rank("Preparing dataset done.")
|
||||||
|
|
||||||
|
# inspect at init
|
||||||
|
model_inspect = bmt.inspect.inspect_model(model, "*")
|
||||||
|
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
||||||
|
|
||||||
|
try:
|
||||||
|
mem_usage, tim_usage = {}, {}
|
||||||
|
mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)
|
||||||
|
|
||||||
|
for iteration, data in enumerate(DataIterator, start=start_step + 1):
|
||||||
|
if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
|
||||||
|
mixed_indexed_dataset.update_states(data["task_ids"], data["indexes"])
|
||||||
|
|
||||||
|
mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)
|
||||||
|
|
||||||
|
logits = model(
|
||||||
|
input=data["inputs"],
|
||||||
|
cu_seqlens=data["cu_seqlens"],
|
||||||
|
max_seqlen=data["max_seqlen"],
|
||||||
|
position_ids=data["position_ids"],
|
||||||
|
)
|
||||||
|
#print("logits: ", logits)
|
||||||
|
|
||||||
|
# chunk targets and task_ids
|
||||||
|
data["targets"] = (
|
||||||
|
data["targets"]
|
||||||
|
.view(-1)
|
||||||
|
.chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
|
||||||
|
.view(data["targets"].shape[0], -1)
|
||||||
|
)
|
||||||
|
data["task_ids"] = (
|
||||||
|
data["task_ids"]
|
||||||
|
.view(-1)
|
||||||
|
.chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
|
||||||
|
.view(data["task_ids"].shape[0], -1)
|
||||||
|
)
|
||||||
|
|
||||||
|
_target = data["targets"].view(-1)
|
||||||
|
non_reduced_loss = loss_func(logits.view(-1, logits.size(-1)), _target)
|
||||||
|
_w = (_target != -100).int()
|
||||||
|
loss = non_reduced_loss.sum() / _w.sum().float()
|
||||||
|
|
||||||
|
global_loss = bmt.sum_loss(loss).item()
|
||||||
|
mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)
|
||||||
|
|
||||||
|
optim_manager.backward(loss)
|
||||||
|
mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)
|
||||||
|
|
||||||
|
if iteration % args.grad_accum == 0 or iteration == args.train_iters:
|
||||||
|
grad_accum_init_time = tim_usage["init"]
|
||||||
|
|
||||||
|
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
|
||||||
|
optim_manager.step()
|
||||||
|
optim_manager.zero_grad()
|
||||||
|
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
|
||||||
|
model_time = tim_usage["optim"] - grad_accum_init_time
|
||||||
|
ave_model_time.record(model_time)
|
||||||
|
else:
|
||||||
|
# dummy optim step
|
||||||
|
grad_norm = torch.Tensor([0.0]).cuda()
|
||||||
|
tim_usage["optim"] = tim_usage["backward"]
|
||||||
|
mem_usage["optim"] = mem_usage["backward"]
|
||||||
|
model_time = tim_usage["optim"] - tim_usage['init']
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
task_num = len(data["task_names"])
|
||||||
|
task_loss, task_token = get_task_loss_and_token(
|
||||||
|
non_reduced_loss, data["task_ids"], task_num, data["targets"]
|
||||||
|
)
|
||||||
|
task_loss_map: Dict[str, float] = {}
|
||||||
|
gatherd_task_loss_map = bmt.distributed.all_gather(task_loss)
|
||||||
|
gatherd_task_token_map = bmt.distributed.all_gather(task_token)
|
||||||
|
gatherd_task_loss_token_map = gatherd_task_loss_map * gatherd_task_token_map
|
||||||
|
sum_task_loss = gatherd_task_loss_token_map.sum(dim=0)
|
||||||
|
tot_task_token = gatherd_task_token_map.sum(dim=0)
|
||||||
|
ave_task_loss = sum_task_loss / tot_task_token
|
||||||
|
for i in range(task_num):
|
||||||
|
task_loss_map[data["task_names"][i]] = ave_task_loss[i].item()
|
||||||
|
global_total_task_token[data["task_names"][i]] += tot_task_token[i].item()
|
||||||
|
|
||||||
|
local_total_rate = torch.Tensor(
|
||||||
|
[data["lengths"].float().mean() / (args.max_length * args.batch_size)]
|
||||||
|
).cuda()
|
||||||
|
local_total_rate = bmt.sum_loss(local_total_rate).item()
|
||||||
|
global_token_pass += (
|
||||||
|
(global_world_size // args.tp_size) * local_total_rate * args.max_length * args.batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
bmt.print_rank(
|
||||||
|
"=========================================" + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||||
|
)
|
||||||
|
last_before_log_time = tim_usage["before_log"]
|
||||||
|
mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)
|
||||||
|
|
||||||
|
iter_time = tim_usage["before_log"] - last_before_log_time
|
||||||
|
|
||||||
|
ave_iter_time.record(iter_time)
|
||||||
|
|
||||||
|
train_info = {
|
||||||
|
"time": iter_time,
|
||||||
|
"iteration": iteration,
|
||||||
|
"loss": global_loss,
|
||||||
|
"lr": lr_scheduler.current_lr,
|
||||||
|
"token_max": local_total_rate,
|
||||||
|
"token_pass": global_token_pass,
|
||||||
|
"throughout": args.max_length * args.batch_size * local_total_rate / ave_iter_time.get() / args.tp_size,
|
||||||
|
"grad_norm": grad_norm.item(),
|
||||||
|
"mask_max": ((data["targets"] >= 0).sum(-1).float().mean() / args.max_length).item(),
|
||||||
|
"task_loss": task_loss_map,
|
||||||
|
"total_task_token": global_total_task_token,
|
||||||
|
}
|
||||||
|
global_token_pass_str = convert_to_k_and_b(global_token_pass)
|
||||||
|
|
||||||
|
time_report_str = "{model_time:.2f}={forward_time:.2f}+{backward_time:.2f}+{optim_time:.2f}".format(model_time=model_time, forward_time=tim_usage['forward']-tim_usage['init'], backward_time=tim_usage['backward']-tim_usage['forward'], optim_time=tim_usage['optim'] - tim_usage['backward'])
|
||||||
|
bmt.print_rank(
|
||||||
|
(
|
||||||
|
"| Iter: {iteration:6d} | loss: {loss:.4f} | lr: {lr:.4e} | model_time: {model_time} | iter_time: {iter_time:.2f}| chunk_ave_time: {chunk_ave_time:.2f}"
|
||||||
|
+ " token/max: {tokenrate:.4f} | mask/max: {maskrate:.4f} | grad_norm: {grad_norm:.4f} | global_token_pass (B):"
|
||||||
|
+ "{global_token_pass} | mem_usage {mem_usage} | "
|
||||||
|
).format(
|
||||||
|
iteration=iteration,
|
||||||
|
loss=global_loss,
|
||||||
|
lr=lr_scheduler.current_lr,
|
||||||
|
model_time=time_report_str,
|
||||||
|
iter_time=iter_time,
|
||||||
|
chunk_ave_time=ave_iter_time.get(),
|
||||||
|
tokenrate=data["lengths"].float().mean() / args.max_length / args.batch_size,
|
||||||
|
maskrate=(data["targets"] >= 0).sum(-1).float().mean() / args.max_length / args.batch_size,
|
||||||
|
grad_norm=grad_norm.item(),
|
||||||
|
global_token_pass=global_token_pass_str,
|
||||||
|
mem_usage=max([value for key, value in mem_usage.items()]),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
bmt.print_rank(
|
||||||
|
"task_loss:\t| "
|
||||||
|
+ " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
|
||||||
|
+ " |"
|
||||||
|
)
|
||||||
|
|
||||||
|
if iteration % 10 == 0:
|
||||||
|
bmt.print_rank(
|
||||||
|
"task_tokens (B):\t| "
|
||||||
|
+ " | ".join(
|
||||||
|
[
|
||||||
|
"{}: {:.4f}".format(task_name, task_token / 10**9)
|
||||||
|
for task_name, task_token in global_total_task_token.items()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
+ " |"
|
||||||
|
)
|
||||||
|
|
||||||
|
if iteration % args.inspect_iters == 0:
|
||||||
|
model_inspect = bmt.inspect.inspect_model(model, "*")
|
||||||
|
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
||||||
|
|
||||||
|
if args.log_dir is not None and bmt.rank() == 0:
|
||||||
|
if args.save is not None:
|
||||||
|
save_every_step_stats(train_info, args.save)
|
||||||
|
|
||||||
|
if args.tensorboard is not None and bmt.rank() == 0:
|
||||||
|
writer.add_scalar("Loss/train", global_loss, iteration)
|
||||||
|
writer.add_scalar("Optimizer/lr", lr_scheduler.current_lr, iteration)
|
||||||
|
writer.add_scalar("Optimizer/scale", optim_manager.loss_scale, iteration)
|
||||||
|
writer.add_scalar("Optimizer/grad_norm", grad_norm.item(), iteration)
|
||||||
|
for task_name, loss in task_loss_map.items():
|
||||||
|
if not math.isnan(loss):
|
||||||
|
writer.add_scalar("Loss/train/{}".format(task_name), loss, iteration)
|
||||||
|
|
||||||
|
# -------- save file. If need to backup by Klara platform, use export.xx_save --------
|
||||||
|
log_ckpt = {
|
||||||
|
"global_total_task_token": global_total_task_token,
|
||||||
|
"global_token_pass": global_token_pass,
|
||||||
|
"iteration": iteration,
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.save is not None and iteration % args.save_iters == 0:
|
||||||
|
exporter.export(
|
||||||
|
model,
|
||||||
|
mixed_indexed_dataset,
|
||||||
|
tokenizer,
|
||||||
|
optimizer,
|
||||||
|
iteration,
|
||||||
|
args,
|
||||||
|
log_ckpt=log_ckpt,
|
||||||
|
final_save=False,
|
||||||
|
async_save=args.async_save,
|
||||||
|
)
|
||||||
|
|
||||||
|
if iteration == args.train_iters and args.stop_when_end == 1:
|
||||||
|
break
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"train loop err: {e}")
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
pass
|
||||||
|
|
||||||
|
exporter.export(model, mixed_indexed_dataset, tokenizer, optimizer, -1, args, final_save=False)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_k_and_b(number):
|
||||||
|
if number >= 1e9: # 大于或等于10亿
|
||||||
|
b_number = number / 1e9
|
||||||
|
return f"{b_number:.2f}B"
|
||||||
|
elif number >= 1e6: # 大于或等于1百万
|
||||||
|
k_number = number / 1e6
|
||||||
|
return f"{k_number:.2f}M"
|
||||||
|
elif number >= 1e3:
|
||||||
|
k_number = number / 1e3
|
||||||
|
return f"{k_number:.2f}K"
|
||||||
|
else:
|
||||||
|
return str(number)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = initialize()
|
||||||
|
bmt.synchronize()
|
||||||
|
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
|
||||||
|
bmt.print_rank("finish loading")
|
||||||
|
bmt.print_rank(
|
||||||
|
"Number of parameter {}, Number of non-e parameter {}".format(
|
||||||
|
num_parameters(model), num_non_embedding_parameters(model)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
bmt.print_rank("args: {}".format(args))
|
||||||
|
|
||||||
|
print("begining training")
|
||||||
|
pretrain(args, tokenizer, model, optimizer, lr_scheduler)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -0,0 +1,234 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
#export OMP_NUM_THREADS=16
|
||||||
|
|
||||||
|
declare -A args # Declare an associative array to store arguments and values
|
||||||
|
|
||||||
|
args["model_unique"]="8b_0702"
|
||||||
|
args["resume_ckpt"]=""
|
||||||
|
args["config"]="8b"
|
||||||
|
args["flash"]="cuda"
|
||||||
|
args["batch_size"]="1"
|
||||||
|
args["max_length"]="4096"
|
||||||
|
args["save_iters"]="500"
|
||||||
|
args["train_iters"]="10"
|
||||||
|
args["dataset_config"]="fm9g_sft"
|
||||||
|
args["local"]="False"
|
||||||
|
args["dataloader"]="indexed"
|
||||||
|
args["save"]="True"
|
||||||
|
args["dataloader_num_threads"]=1
|
||||||
|
args["dataloader_prefetch"]=2
|
||||||
|
args["dataloader_prefetch_factor"]=32
|
||||||
|
args["dataloader_num_workers"]=2
|
||||||
|
args["lr"]="1e-5"
|
||||||
|
args["warmup_iters"]="20"
|
||||||
|
args["drop_iters"]="0.1"
|
||||||
|
args["tokenizer_path"]="./tokenizer/tokenizer.model" # /user/tc_agi/klara/baichuan2/baichuan2.tokenizer.model
|
||||||
|
args["load_grad"]="False"
|
||||||
|
args["grad_ckpt_num"]="160"
|
||||||
|
args["exp_group"]=""
|
||||||
|
args["ignore_cuda_oom"]="1"
|
||||||
|
args["tensorboard_all_tasks"]="0"
|
||||||
|
args["stop_when_end"]="0"
|
||||||
|
args["only_run_dataloader"]="0"
|
||||||
|
args["eps"]="1e-6"
|
||||||
|
args["inspect_iters"]="100"
|
||||||
|
args["strict_state_dict"]="1"
|
||||||
|
args["only_load_model"]="1"
|
||||||
|
args["lr_scheduler"]="cosine"
|
||||||
|
args["resume_no_optimze"]="0"
|
||||||
|
args["tp_size"]="1"
|
||||||
|
args["parallel_load_datastate"]="16"
|
||||||
|
args["async_save"]="False"
|
||||||
|
args["load_dataloader_ckpt"]="0"
|
||||||
|
args["drop_begin"]="-1"
|
||||||
|
args["drop_rate"]="0.5"
|
||||||
|
args["use_checkpoint"]="1"
|
||||||
|
|
||||||
|
|
||||||
|
# Loop through the arguments
|
||||||
|
for ((i=1; i<=$#; i++)); do
|
||||||
|
arg="${!i}"
|
||||||
|
# Check if the argument starts with "--"
|
||||||
|
if [[ "$arg" == --* ]]; then
|
||||||
|
arg_name="${arg:2}" # Remove leading "--"
|
||||||
|
valueid=$((i+1))
|
||||||
|
# Get the value of the argument if it exists
|
||||||
|
if ((i+1 <= $#)); then
|
||||||
|
args["$arg_name"]="${!valueid}"
|
||||||
|
i=$((i+1)) # Skip the next argument (its value)
|
||||||
|
else
|
||||||
|
args["$arg_name"]="" # Set empty value if no value provided
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# 使用 Python 读取 JSON 文件并更新 Bash 字典
|
||||||
|
while read -r key value; do
|
||||||
|
args["$key"]="$value"
|
||||||
|
done < <(python -c 'import json, sys; obj = json.load(open("train_configs/'${args['config']}'.json"))["pretrain"]; print("\n".join(["{} {}".format(k, v) for k, v in obj.items()]))')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 用cmd arg 再更新一次
|
||||||
|
# Loop through the arguments
|
||||||
|
for ((i=1; i<=$#; i++)); do
|
||||||
|
arg="${!i}"
|
||||||
|
# Check if the argument starts with "--"
|
||||||
|
if [[ "$arg" == --* ]]; then
|
||||||
|
arg_name="${arg:2}" # Remove leading "--"
|
||||||
|
valueid=$((i+1))
|
||||||
|
|
||||||
|
# Get the value of the argument if it exists
|
||||||
|
if ((i+1 <= $#)); then
|
||||||
|
args["$arg_name"]="${!valueid}"
|
||||||
|
i=$((i+1)) # Skip the next argument (its value)
|
||||||
|
else
|
||||||
|
args["$arg_name"]="" # Set empty value if no value provided
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# Print the values of the arguments
|
||||||
|
echo "----------- CMD args ----------"
|
||||||
|
for key in "${!args[@]}"; do
|
||||||
|
echo "$key: ${args[$key]}"
|
||||||
|
done
|
||||||
|
echo "--------- END CMD args --------"
|
||||||
|
|
||||||
|
|
||||||
|
if [[ ${args["flash"]} == "triton" ]]; then
|
||||||
|
sudo cp /usr/local/cuda-11.6/compat/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03
|
||||||
|
sudo ln /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so
|
||||||
|
echo "triton flash"
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
GPUS_PER_NODE=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader | wc -l)
|
||||||
|
# GPUS_PER_NODE=1
|
||||||
|
echo "Using ${GPUS_PER_NODE} GPU each machine"
|
||||||
|
|
||||||
|
|
||||||
|
if [[ ${args["model_unique"]} == "" ]]; then
|
||||||
|
MODEL_UNIQUE=${JEEVES_JOB_ID} # 写入的位置,没传的话自动构造
|
||||||
|
# JOBID+CreateTime, 本次run的唯一标识符。在白箱里可以通过/projects/${PROJECTID}-${PROJECTNAME}/checkpoints/${MODEL_UNIQUE} 拿到 checkpoint
|
||||||
|
# 通过/projects/${PROJECTID}-${PROJECTNAME}/tensorboard/${MODEL_UNIQUE} 拿到 tensorboard
|
||||||
|
else
|
||||||
|
MODEL_UNIQUE=${args["model_unique"]} # 给了写入的位置
|
||||||
|
fi
|
||||||
|
echo "model_unique: "$MODEL_UNIQUE
|
||||||
|
|
||||||
|
# --------------- 运行参数 ---------------
|
||||||
|
|
||||||
|
OPTS+=" --model-config model_configs/"${args['config']}".json" # [CHANGE]
|
||||||
|
OPTS+=" --batch-size ${args["batch_size"]}"
|
||||||
|
OPTS+=" --train-iters ${args["train_iters"]}"
|
||||||
|
OPTS+=" --save-iters ${args["save_iters"]}"
|
||||||
|
OPTS+=" --save-name fm9g_live_checkpoint"
|
||||||
|
OPTS+=" --max-length ${args["max_length"]}"
|
||||||
|
OPTS+=" --lr ${args["lr"]}"
|
||||||
|
OPTS+=" --inspect-iters ${args["inspect_iters"]}"
|
||||||
|
OPTS+=" --warmup-iters ${args["warmup_iters"]}"
|
||||||
|
OPTS+=" --drop-iters ${args["drop_iters"]}"
|
||||||
|
OPTS+=" --lr_scheduler ${args["lr_scheduler"]}"
|
||||||
|
OPTS+=" --offload"
|
||||||
|
OPTS+=" --vocab ./tokenizer/vocab.txt"
|
||||||
|
OPTS+=" --flash ${args["flash"]}"
|
||||||
|
OPTS+=" --tensorboard_all_tasks ${args["tensorboard_all_tasks"]}"
|
||||||
|
OPTS+=" --ignore_cuda_oom ${args["ignore_cuda_oom"]}"
|
||||||
|
OPTS+=" --stop_when_end ${args["stop_when_end"]}"
|
||||||
|
OPTS+=" --only_run_dataloader ${args["only_run_dataloader"]}"
|
||||||
|
OPTS+=" --eps ${args["eps"]}"
|
||||||
|
OPTS+=" --strict_state_dict ${args["strict_state_dict"]}"
|
||||||
|
OPTS+=" --only_load_model ${args["only_load_model"]}"
|
||||||
|
OPTS+=" --resume_no_optimze ${args["resume_no_optimze"]}"
|
||||||
|
OPTS+=" --tokenizer_path ${args["tokenizer_path"]}"
|
||||||
|
OPTS+=" --weight-decay 0.1"
|
||||||
|
OPTS+=" --tp-size ${args["tp_size"]}"
|
||||||
|
OPTS+=" --parallel_load_datastate ${args["parallel_load_datastate"]}"
|
||||||
|
OPTS+=" --load_dataloader_ckpt ${args["load_dataloader_ckpt"]}"
|
||||||
|
OPTS+=" --drop_begin ${args["drop_begin"]}"
|
||||||
|
OPTS+=" --drop_rate ${args["drop_rate"]}"
|
||||||
|
OPTS+=" --use_checkpoint ${args["use_checkpoint"]}"
|
||||||
|
|
||||||
|
if [[ ${args["load_grad"]} == "True" ]]; then
|
||||||
|
OPTS+=" --load-grad"
|
||||||
|
OPTS+=" --grad-ckpt-num ${args["grad_ckpt_num"]}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
if [[ ${args["async_save"]} == "True" ]]; then
|
||||||
|
OPTS+=" --async_save"
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
if [[ ${args["dataloader"]} == "indexed" ]]; then
|
||||||
|
OPTS+=" --dataloader_num_threads ${args["dataloader_num_threads"]}"
|
||||||
|
OPTS+=" --dataloader_prefetch ${args["dataloader_prefetch"]}"
|
||||||
|
OPTS+=" --dataloader_num_workers ${args["dataloader_num_workers"]}"
|
||||||
|
OPTS+=" --dataloader_prefetch_factor ${args["dataloader_prefetch_factor"]}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
# --------------- 写文件路径 ---------------
|
||||||
|
## checkpoint
|
||||||
|
if [[ ${args["save"]} == "True" ]]; then
|
||||||
|
|
||||||
|
OPTS+=" --save ./data/checkpoints/${MODEL_UNIQUE}/"
|
||||||
|
OPTS+=" --save-model ./not_exist/${MODEL_UNIQUE}/"
|
||||||
|
else
|
||||||
|
echo "won't save model"
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
## logs,/local/logs 等价于 ./datalogs(软链)
|
||||||
|
mkdir -p ./data/checkpoints/logs/${MODEL_UNIQUE}
|
||||||
|
OPTS+=" --log-dir ./data/checkpoints/logs/${MODEL_UNIQUE}"
|
||||||
|
OPTS+=" --tensorboard ./data/tensorboard/${args["exp_group"]}${MODEL_UNIQUE}/"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if [[ ${args["local"]} == "True" ]]; then
|
||||||
|
current_dir=$(pwd)
|
||||||
|
OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
|
||||||
|
else
|
||||||
|
current_dir=$(pwd)
|
||||||
|
OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
|
||||||
|
echo "Platform config:"${PLATFORM_CONFIG_PATH}
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
## checkpoint,兼容 CHECKPOINT 和 LATEST_CHECKPOINT。debug 时建议不加载 checkpoint,启动会比较快
|
||||||
|
if [ "${args["resume_ckpt"]}" != "" ]; then
|
||||||
|
OPTS+=" --load ./data/checkpoints/${MODEL_UNIQUE}/${args["resume_ckpt"]}"
|
||||||
|
else
|
||||||
|
echo "No checkpoint to load"
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
filename="pretrain_dragonfly"
|
||||||
|
|
||||||
|
if [[ ${args["local"]} == "True" ]]; then
|
||||||
|
PRETRAIN_ENTRY="$filename.py"
|
||||||
|
else
|
||||||
|
PRETRAIN_ENTRY="$filename.py"
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
GPUS_PER_NODE=8
|
||||||
|
NNODES=1
|
||||||
|
RANK=0
|
||||||
|
MASTER_ENDPOINT=g3006
|
||||||
|
MASTER_PORT=12345
|
||||||
|
#CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --master_addr=${MASTER_ENDPOINT} --master_port=${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
|
||||||
|
CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ENDPOINT}:${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
|
||||||
|
|
||||||
|
echo "-------final CMD is------"
|
||||||
|
echo "${CMD}"
|
||||||
|
echo "-------final CMD end------"
|
||||||
|
|
||||||
|
$CMD
|
|
@ -119687,8 +119687,8 @@
|
||||||
"𠳐"
|
"𠳐"
|
||||||
"𥻗"
|
"𥻗"
|
||||||
"𬉼"
|
"𬉼"
|
||||||
"<pad_0>"
|
"<|im_start|>"
|
||||||
"<pad_1>"
|
"<|im_end|>"
|
||||||
"<pad_2>"
|
"<pad_2>"
|
||||||
"<pad_3>"
|
"<pad_3>"
|
||||||
"<pad_4>"
|
"<pad_4>"
|
|
@ -0,0 +1,9 @@
|
||||||
|
{
|
||||||
|
"pretrain": {
|
||||||
|
"train_iters": 20000,
|
||||||
|
"batch_size": 1,
|
||||||
|
"max_length": 4096,
|
||||||
|
"n_gpus": 8,
|
||||||
|
"lr": 1e-5
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,3 +1,18 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2020 The OpenBMB team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,6 +22,8 @@ def add_model_config_args(parser: argparse.ArgumentParser):
|
||||||
group = parser.add_argument_group("model", "model configuration")
|
group = parser.add_argument_group("model", "model configuration")
|
||||||
group.add_argument("--model-config", type=str, help="model configuration file")
|
group.add_argument("--model-config", type=str, help="model configuration file")
|
||||||
group.add_argument("--vocab", type=str, default=None, help="model vocabulary file")
|
group.add_argument("--vocab", type=str, default=None, help="model vocabulary file")
|
||||||
|
group.add_argument("--eps", type=float, default=1e-5, help="eps in layernorm")
|
||||||
|
# group.add_argument("--qk_norm", action="store_true", default=False, help="qk layernorm")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
@ -31,6 +48,13 @@ def add_training_args(parser: argparse.ArgumentParser):
|
||||||
help="Load the gradient states",
|
help="Load the gradient states",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--grad-ckpt-num",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="grad file num (only work when --load-grad from files less than world-size )",
|
||||||
|
)
|
||||||
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--load-start-step",
|
"--load-start-step",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
@ -66,7 +90,9 @@ def add_training_args(parser: argparse.ArgumentParser):
|
||||||
|
|
||||||
group.add_argument("--inspect-iters", type=int, default=1000, help="number of inspecting")
|
group.add_argument("--inspect-iters", type=int, default=1000, help="number of inspecting")
|
||||||
group.add_argument("--batch-size", type=int, default=32, help="Data Loader batch size")
|
group.add_argument("--batch-size", type=int, default=32, help="Data Loader batch size")
|
||||||
|
group.add_argument("--num-micro-batches", type=int, default=16)
|
||||||
group.add_argument("--clip-grad", type=float, default=1.0, help="gradient clipping")
|
group.add_argument("--clip-grad", type=float, default=1.0, help="gradient clipping")
|
||||||
|
group.add_argument("--grad-accum", type=int, default=1, help="gradient accum steps")
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--train-iters",
|
"--train-iters",
|
||||||
type=int,
|
type=int,
|
||||||
|
@ -74,11 +100,14 @@ def add_training_args(parser: argparse.ArgumentParser):
|
||||||
help="total number of iterations to train over all training runs",
|
help="total number of iterations to train over all training runs",
|
||||||
)
|
)
|
||||||
group.add_argument("--max-length", type=int, default=512, help="max length of input")
|
group.add_argument("--max-length", type=int, default=512, help="max length of input")
|
||||||
|
group.add_argument("--min-length", type=int, default=None, help="only for speed test")
|
||||||
|
|
||||||
group.add_argument("--seed", type=int, default=1234, help="random seed for reproducibility")
|
group.add_argument("--seed", type=int, default=1234, help="random seed for reproducibility")
|
||||||
|
|
||||||
# Learning rate.
|
# Learning rate.
|
||||||
group.add_argument("--lr", type=float, default=1.0e-4, help="initial learning rate")
|
group.add_argument("--lr", type=float, default=1.0e-4, help="initial learning rate")
|
||||||
|
group.add_argument("--lr_scheduler", type=str, default="cosine", help=" learning rate scheduler")
|
||||||
|
|
||||||
group.add_argument("--weight-decay", type=float, default=1.0e-2, help="weight decay rate")
|
group.add_argument("--weight-decay", type=float, default=1.0e-2, help="weight decay rate")
|
||||||
group.add_argument("--loss-scale", type=float, default=65536, help="loss scale")
|
group.add_argument("--loss-scale", type=float, default=65536, help="loss scale")
|
||||||
group.add_argument("--max-loss-scale", type=float, default=float("inf"), help="loss scale")
|
group.add_argument("--max-loss-scale", type=float, default=float("inf"), help="loss scale")
|
||||||
|
@ -92,21 +121,85 @@ def add_training_args(parser: argparse.ArgumentParser):
|
||||||
help="percentage of data to warmup on (.01 = 1% of all " "training iters). Default 0.01",
|
help="percentage of data to warmup on (.01 = 1% of all " "training iters). Default 0.01",
|
||||||
)
|
)
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--lr-decay-style",
|
"--drop-iters",
|
||||||
type=str,
|
type=float,
|
||||||
default="noam",
|
default=0.01,
|
||||||
choices=["constant", "linear", "cosine", "exponential", "noam"],
|
help="percentage of data to warmup on (.01 = 1% of all " "training iters). Default 0.01",
|
||||||
help="learning rate decay function",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
group.add_argument("--lr-decay-iters", type=int, default=None, help="lr decay steps")
|
group.add_argument("--lr-decay-iters", type=int, default=None, help="lr decay steps")
|
||||||
group.add_argument("--start-step", type=int, default=0, help="step to start or continue training")
|
group.add_argument("--start-step", type=int, default=0, help="step to start or continue training")
|
||||||
group.add_argument("--concat-data", action="store_true", help="whether we concatenate the dialogues")
|
group.add_argument("--concat-data", action="store_true", help="whether we concatenate the dialogues")
|
||||||
group.add_argument("--offload", action="store_true", help="whether we use offload_adam")
|
group.add_argument("--offload", action="store_true", help="whether we use offload_adam")
|
||||||
group.add_argument("--new-bmt", action="store_true", help="new bmt without ckpt")
|
group.add_argument("--new-bmt", action="store_true", help="new bmt without ckpt")
|
||||||
group.add_argument("--flash", default="none", choices=["none", "1d", "triton", "cuda"])
|
group.add_argument("--flash", default="none", choices=["none", "1d", "triton", "cuda"])
|
||||||
group.add_argument("--tp", default=1, type=int, help="whether we use tensor parallelism")
|
group.add_argument("--use-jfs-data", action="store_true", help="whether we use juicefs dataset")
|
||||||
|
group.add_argument("--tp-size", default=1, type=int)
|
||||||
|
group.add_argument("--pp-size", default=1, type=int)
|
||||||
group.add_argument("--bf16", action="store_true", help="whether we use bf16")
|
group.add_argument("--bf16", action="store_true", help="whether we use bf16")
|
||||||
group.add_argument("--gradient-accumulation-steps", type=int, default=1, help="gradient accumulation steps")
|
group.add_argument("--dataloader_num_threads", default=3, type=int, help="Only useful in indexed dataest.")
|
||||||
|
group.add_argument("--dataloader_prefetch", default=200, type=int, help="Only useful in indexed dataest.")
|
||||||
|
group.add_argument("--dataloader_num_workers", default=4, type=int, help="Only useful in indexed dataest.")
|
||||||
|
group.add_argument("--dataloader_prefetch_factor", default=50, type=int, help="Only useful in indexed dataest.")
|
||||||
|
group.add_argument(
|
||||||
|
"--dataloader",
|
||||||
|
default="indexed",
|
||||||
|
type=str,
|
||||||
|
help="dataloader type, 'indexed' for indexed dataset, 'normal' for normal dataset",
|
||||||
|
)
|
||||||
|
group.add_argument("--stop_when_end", default=0, type=int, help="Whether to stop training when we reach end_iter")
|
||||||
|
group.add_argument(
|
||||||
|
"--data_len_threshold",
|
||||||
|
default=512,
|
||||||
|
type=int,
|
||||||
|
help="If the average length of a sequence is less than this int, mean the sample is biased. ",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--only_run_dataloader", default=0, type=int, help="Whether to only run dataloader to check data. "
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--only_load_model", default=0, type=int, help="Whether to only load a model ckpt, without anything else."
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--load_dataloader_ckpt", default=1, type=int, help="Whether to only load a model ckpt, without anything else."
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--resume_no_optimze",
|
||||||
|
default=0,
|
||||||
|
type=int,
|
||||||
|
help="The number of steps that does not add optimization after resume",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--parallel_load_datastate",
|
||||||
|
default=256,
|
||||||
|
type=int,
|
||||||
|
help="The number of parallel workers to load dataset state",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--async_save",
|
||||||
|
action="store_true",
|
||||||
|
help="whether to save artifacts asynchronously",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--drop_begin",
|
||||||
|
default=-1,
|
||||||
|
type=int,
|
||||||
|
help="The number of steps that starts to drop lr"
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--drop_rate",
|
||||||
|
default=0.5,
|
||||||
|
type=float,
|
||||||
|
help="The number rate"
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--use_checkpoint",
|
||||||
|
default=1,
|
||||||
|
type=int,
|
||||||
|
help="Whether to use checkpointing."
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
@ -133,6 +226,17 @@ def add_pretrain_args(parser: argparse.ArgumentParser):
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def add_tokenizer_args(parser: argparse.ArgumentParser):
|
||||||
|
group = parser.add_argument_group("tokenizer", "tokenizer configurations")
|
||||||
|
group.add_argument(
|
||||||
|
"--tokenizer_path",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="tokenizer_path",
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def add_finetune_args(parser: argparse.ArgumentParser):
|
def add_finetune_args(parser: argparse.ArgumentParser):
|
||||||
group = parser.add_argument_group("finetune", "finetune configurations")
|
group = parser.add_argument_group("finetune", "finetune configurations")
|
||||||
group.add_argument("--epoch", type=int, default=1, help="number of training epochs")
|
group.add_argument("--epoch", type=int, default=1, help="number of training epochs")
|
||||||
|
@ -204,14 +308,53 @@ def add_feedback_learning_args(parser: argparse.ArgumentParser):
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def add_delta_args(parser: argparse.ArgumentParser):
|
def add_model_change_args(parser: argparse.ArgumentParser):
|
||||||
group = parser.add_argument_group("LoRA","LoRA configurations")
|
group = parser.add_argument_group("model_change", "model change during pretraining")
|
||||||
group.add_argument("--delta-type", type=str, default=None, help="delta-tuning-type")
|
group.add_argument("--strict_state_dict", type=int, default=1, help="strict_state_dict")
|
||||||
group.add_argument("--lora-r", type=int, default=8, help="lora-rank")
|
##
|
||||||
group.add_argument("--lora-alpha", type=int, default=8, help="lora-alpha")
|
return parser
|
||||||
group.add_argument("--lora-dropout", type=float, default=0.0, help="lora-dropout")
|
|
||||||
group.add_argument("--lora-layer", nargs='+', default=['project_q','project_k'], help="lora-layer")
|
|
||||||
group.add_argument("--save-origin-model", action="store_true", default=False)
|
def add_log_args(parser: argparse.ArgumentParser):
|
||||||
|
group = parser.add_argument_group("log", "log configurations")
|
||||||
|
group.add_argument("--tensorboard_all_tasks", type=int, default=0, help="log")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def add_error_handle_args(parser: argparse.ArgumentParser):
|
||||||
|
group = parser.add_argument_group("error_handle", "error_handle configurations")
|
||||||
|
group.add_argument(
|
||||||
|
"--ignore_cuda_oom", type=int, default=1, help="continue training by ingore the batch that causes oom"
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def add_runtime_eval_args(parser: argparse.ArgumentParser):
|
||||||
|
group = parser.add_argument_group("runtime eval args", "runtime evaluation by submitting a job")
|
||||||
|
group.add_argument(
|
||||||
|
"--runtime_eval",
|
||||||
|
action="store_true",
|
||||||
|
help="whether to use runtime_eval. Only if this is set to True, the following variables will be useful",
|
||||||
|
)
|
||||||
|
group.add_argument("--eval_jeeves_auth", type=str, default="", help="auth, press f12 on jeeves platform to get")
|
||||||
|
group.add_argument("--eval_project_id", type=str, default=None, help="project id")
|
||||||
|
group.add_argument("--eval_run_cmd", type=str, default="", help="cmd for eval")
|
||||||
|
group.add_argument(
|
||||||
|
"--eval_git_path",
|
||||||
|
type=str,
|
||||||
|
default="git@git.in.zhihu.com:luca/llm-bench.git",
|
||||||
|
help="git path of evaluation code",
|
||||||
|
)
|
||||||
|
group.add_argument("--eval_git_branch", type=str, default="master", help="git branch of evaluation code")
|
||||||
|
group.add_argument("--eval_node_num", type=int, default=1, help="using 1 node to evaluate")
|
||||||
|
group.add_argument("--eval_gpu_num", type=int, default=1, help="using 1 gpu per node to evaluate")
|
||||||
|
group.add_argument("--eval_tasks_config", type=str, default="", help="evaluate tasks' config")
|
||||||
|
group.add_argument("--eval_model_backend", default="torch", type=str, help="model_backend")
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--eval_at_start", action="store_true", help="whether to eval at the first epoch, default to false"
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
@ -222,6 +365,30 @@ def add_reward_args(parser: argparse.ArgumentParser):
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def add_long_context_extend_args(parser: argparse.ArgumentParser):
|
||||||
|
"""long context extending arguments."""
|
||||||
|
group = parser.add_argument_group("long_context_extend", "long context extend configurations")
|
||||||
|
group.add_argument("--pose-prob", default=0.0, type=float, help="Sample-level PoSE probability")
|
||||||
|
group.add_argument(
|
||||||
|
"--pose-scaling-factor",
|
||||||
|
default=1.0,
|
||||||
|
type=float,
|
||||||
|
help="PoSE scaling factor, simulate input length = max_length * pose_scaling_factor",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--rope-scaling-type",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
choices=["Linear", "NTK-aware", "Dynamic NTK", "NTK-by-parts", "YaRN", ""],
|
||||||
|
help="Context scaling type",
|
||||||
|
)
|
||||||
|
group.add_argument("--rope-scaling-factor", default=1, type=int, help="Context scaling factor")
|
||||||
|
group.add_argument(
|
||||||
|
"--orig-max-length", default=8192, type=int, help="Original context length before context extending"
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def get_args(
|
def get_args(
|
||||||
pretrain: bool = False,
|
pretrain: bool = False,
|
||||||
finetune: bool = False,
|
finetune: bool = False,
|
||||||
|
@ -235,9 +402,14 @@ def get_args(
|
||||||
parser = add_training_args(parser)
|
parser = add_training_args(parser)
|
||||||
if pretrain:
|
if pretrain:
|
||||||
parser = add_pretrain_args(parser)
|
parser = add_pretrain_args(parser)
|
||||||
|
parser = add_runtime_eval_args(parser)
|
||||||
|
parser = add_tokenizer_args(parser)
|
||||||
|
parser = add_log_args(parser)
|
||||||
|
parser = add_error_handle_args(parser)
|
||||||
|
parser = add_model_change_args(parser)
|
||||||
|
|
||||||
if finetune:
|
if finetune:
|
||||||
parser = add_finetune_args(parser)
|
parser = add_finetune_args(parser)
|
||||||
parser = add_delta_args(parser)
|
|
||||||
if rhlf:
|
if rhlf:
|
||||||
parser = add_rhlf_args(parser)
|
parser = add_rhlf_args(parser)
|
||||||
if simple_rlhf:
|
if simple_rlhf:
|
||||||
|
@ -246,6 +418,7 @@ def get_args(
|
||||||
parser = add_feedback_learning_args(parser)
|
parser = add_feedback_learning_args(parser)
|
||||||
if reward:
|
if reward:
|
||||||
parser = add_reward_args(parser)
|
parser = add_reward_args(parser)
|
||||||
|
parser = add_long_context_extend_args(parser)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -4,7 +4,8 @@ from .distributed_dataset import SimpleDataset
|
||||||
from .indexed_dataset import IndexedDataset
|
from .indexed_dataset import IndexedDataset
|
||||||
from .indexed_dataset import IndexedDatasetBuilder
|
from .indexed_dataset import IndexedDatasetBuilder
|
||||||
from .indexed_dataset import PrefetchDecodeDataset
|
from .indexed_dataset import PrefetchDecodeDataset
|
||||||
from .list_dataset import ListDataset
|
|
||||||
|
# from .list_dataset import ListDataset
|
||||||
from .utils import compact_dataset
|
from .utils import compact_dataset
|
||||||
from .utils import CudaPrefetcher
|
from .utils import CudaPrefetcher
|
||||||
from .utils import mask_dataset
|
from .utils import mask_dataset
|
|
@ -1,3 +1,18 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2020 The OpenBMB team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import bisect
|
import bisect
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
|
@ -281,7 +296,6 @@ class DistributedDataset:
|
||||||
info: List[FileInfo] = []
|
info: List[FileInfo] = []
|
||||||
if os.path.exists(meta_path):
|
if os.path.exists(meta_path):
|
||||||
info = _read_info_list(meta_path)
|
info = _read_info_list(meta_path)
|
||||||
|
|
||||||
old_len = len(self._file_info)
|
old_len = len(self._file_info)
|
||||||
if old_len > len(info):
|
if old_len > len(info):
|
||||||
raise RuntimeError("Dataset meta file: changed unexpectly")
|
raise RuntimeError("Dataset meta file: changed unexpectly")
|
||||||
|
@ -443,7 +457,11 @@ class DistributedDataset:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if self._world_size > 1:
|
if self._world_size > 1:
|
||||||
gpu_num_unused_block = torch.tensor([num_unused_block], dtype=torch.long).cuda()
|
gpu_num_unused_block = torch.tensor([num_unused_block], dtype=torch.long).cuda()
|
||||||
max_unused_blocks = bmt.distributed.all_reduce(gpu_num_unused_block, op="max").cpu().item()
|
max_unused_blocks = (
|
||||||
|
bmt.distributed.all_reduce(gpu_num_unused_block, op="max", comm=bmt.config["tp_zero_comm"])
|
||||||
|
.cpu()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
gpu_states = torch.full((max_unused_blocks,), -1, dtype=torch.long).cuda()
|
gpu_states = torch.full((max_unused_blocks,), -1, dtype=torch.long).cuda()
|
||||||
gpu_states[:num_unused_block] = torch.tensor(self._unused_block, dtype=torch.long).cuda()
|
gpu_states[:num_unused_block] = torch.tensor(self._unused_block, dtype=torch.long).cuda()
|
||||||
gpu_offset = torch.full((max_unused_blocks,), 0, dtype=torch.long).cuda()
|
gpu_offset = torch.full((max_unused_blocks,), 0, dtype=torch.long).cuda()
|
||||||
|
@ -452,9 +470,15 @@ class DistributedDataset:
|
||||||
[curr_block, inblock_offset, num_unused_block, self._repeat_times],
|
[curr_block, inblock_offset, num_unused_block, self._repeat_times],
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
).cuda()
|
).cuda()
|
||||||
global_states = bmt.distributed.all_gather(gpu_states).cpu() # (world_size, max_unused_blocks)
|
global_states = bmt.distributed.all_gather(
|
||||||
global_offset = bmt.distributed.all_gather(gpu_offset).cpu() # (world_size, max_unused_blocks)
|
gpu_states, comm=bmt.config["tp_zero_comm"]
|
||||||
global_block = bmt.distributed.all_gather(gpu_block).cpu() # (world_size, 4)
|
).cpu() # (world_size, max_unused_blocks)
|
||||||
|
global_offset = bmt.distributed.all_gather(
|
||||||
|
gpu_offset, comm=bmt.config["tp_zero_comm"]
|
||||||
|
).cpu() # (world_size, max_unused_blocks)
|
||||||
|
global_block = bmt.distributed.all_gather(
|
||||||
|
gpu_block, comm=bmt.config["tp_zero_comm"]
|
||||||
|
).cpu() # (world_size, 4)
|
||||||
return {"states": global_states, "offset": global_offset, "block": global_block}
|
return {"states": global_states, "offset": global_offset, "block": global_block}
|
||||||
else:
|
else:
|
||||||
return {
|
return {
|
|
@ -1,13 +1,41 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
||||||
|
#
|
||||||
|
# @author: ouzebin <ouzebin@zhihu.com>
|
||||||
|
# @date: 2023/09/27
|
||||||
|
|
||||||
|
"""
|
||||||
|
使用 IndexedDataset 前需按指定格式构建或者转换已有数据集
|
||||||
|
数据集文件结构:
|
||||||
|
- <dataset name>
|
||||||
|
- data.jsonl # jsonl 格式的数据,每一行一条样本
|
||||||
|
- index # 记录每一行 json 数据的起始 byte-offset
|
||||||
|
|
||||||
|
从头构建:直接使用 IndexedDatasetBuilder 这个 context manager
|
||||||
|
>>> with IndexedDatasetBuilder("swear", overwrite=True) as builder:
|
||||||
|
>>> for data in [{"input": f"screw it {i}", "output": f"for god's sake {i}"} for i in range(100)]:
|
||||||
|
>>> builder.put(data)
|
||||||
|
转换:
|
||||||
|
从 fm9g distributed_dataset 转换:使用 `fm9g.dataset.tools.distributed_to_indexed`
|
||||||
|
$ python -m fm9g.dataset.tools.distributed_to_indexed -i <原数据集文件夹> -o <新数据集文件夹>
|
||||||
|
已有 jsonl 数据:使用 `fm9g.dataset.tools.jsonl_to_index` 构建 index 文件。需提前先把 jsonl 文件命名为
|
||||||
|
$ python -m fm9g.dataset.tools.jsonl_to_index -p <数据集文件夹路径>
|
||||||
|
"""
|
||||||
import itertools
|
import itertools
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import pickle
|
|
||||||
import queue
|
import queue
|
||||||
import random
|
import random
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import bmtrain as bmt
|
import bmtrain as bmt
|
||||||
|
import h5py
|
||||||
|
import numpy
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import msgspec
|
import msgspec
|
||||||
|
@ -22,13 +50,83 @@ except ModuleNotFoundError:
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from .utils import Range
|
from fm9g.utils.bitset import BitSet
|
||||||
|
from fm9g.utils.bitset import bitset_diff
|
||||||
|
|
||||||
print_lock = threading.Lock()
|
print_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def random_range(start, stop=None, step=None):
|
||||||
|
"""
|
||||||
|
Generator of non-repeated random permutation with the same inteface of python
|
||||||
|
`range`. Obtained from https://stackoverflow.com/a/53551417
|
||||||
|
The random.shuffle(list) and random.sample(list, len(list)) require
|
||||||
|
materialize the lists, which result in a long initalization period.
|
||||||
|
"""
|
||||||
|
if stop is None:
|
||||||
|
start, stop = 0, start
|
||||||
|
if step is None:
|
||||||
|
step = 1
|
||||||
|
# Use a mapping to convert a standard range into the desired range.
|
||||||
|
mapping = lambda i: (i * step) + start
|
||||||
|
# Compute the number of numbers in this range.
|
||||||
|
maximum = int(math.ceil((stop - start) / step))
|
||||||
|
if maximum == 0:
|
||||||
|
# early return with empty range
|
||||||
|
yield from ()
|
||||||
|
return
|
||||||
|
# Seed range with a random integer.
|
||||||
|
value = random.randint(0, maximum)
|
||||||
|
# Construct an offset, multiplier, and modulus for a linear
|
||||||
|
# congruential generator. These generators are cyclic and
|
||||||
|
# non-repeating when they maintain the properties:
|
||||||
|
#
|
||||||
|
# 1) "modulus" and "offset" are relatively prime.
|
||||||
|
# 2) ["multiplier" - 1] is divisible by all prime factors of "modulus".
|
||||||
|
# 3) ["multiplier" - 1] is divisible by 4 if "modulus" is divisible by 4.
|
||||||
|
|
||||||
|
# Pick a random odd-valued offset.
|
||||||
|
offset = random.randint(0, maximum) * 2 + 1
|
||||||
|
# Pick a multiplier 1 greater than a multiple of 4.
|
||||||
|
multiplier = 4 * (maximum // 4) + 1
|
||||||
|
# Pick a modulus just big enough to generate all numbers (power of 2).
|
||||||
|
modulus = int(2 ** math.ceil(math.log2(maximum)))
|
||||||
|
# Track how many random numbers have been returned.
|
||||||
|
found = 0
|
||||||
|
while found < maximum:
|
||||||
|
# If this is a valid value, yield it in generator fashion.
|
||||||
|
if value < maximum:
|
||||||
|
found += 1
|
||||||
|
yield mapping(value)
|
||||||
|
# Calculate the next value in the sequence.
|
||||||
|
value = (value * multiplier + offset) % modulus
|
||||||
|
|
||||||
|
|
||||||
|
class Range(object):
|
||||||
|
def __init__(self, start, stop, step):
|
||||||
|
self.start = start
|
||||||
|
self.stop = stop
|
||||||
|
self.step = step
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"Range({self.start}, {self.stop}, {self.step})"
|
||||||
|
|
||||||
|
def iterate(self):
|
||||||
|
yield from range(self.start, self.stop, self.step)
|
||||||
|
|
||||||
|
def list(self):
|
||||||
|
return list(range(self.start, self.stop, self.step))
|
||||||
|
|
||||||
|
def subrange(self, split, nsplits):
|
||||||
|
# strided spliting range params
|
||||||
|
# e.g., [0, 3, 5, 7, 9] can be split into [0, 5, 9] and [3, 7]
|
||||||
|
return Range(self.start + self.step * split, self.stop, self.step * nsplits)
|
||||||
|
|
||||||
|
def random_iterate(self):
|
||||||
|
yield from random_range(self.start, self.stop, self.step)
|
||||||
|
|
||||||
|
|
||||||
def safe_print(*args, **kargs):
|
def safe_print(*args, **kargs):
|
||||||
if "flush" in kargs:
|
if "flush" in kargs:
|
||||||
flush = kargs["flush"]
|
flush = kargs["flush"]
|
||||||
|
@ -40,12 +138,15 @@ def safe_print(*args, **kargs):
|
||||||
|
|
||||||
|
|
||||||
def concurrent_info():
|
def concurrent_info():
|
||||||
world_size, rank = bmt.world_size(), bmt.rank()
|
# world_size, rank = bmt.world_size(), bmt.rank()
|
||||||
|
world_size = bmt.config["world_size"] // bmt.config["tp_size"]
|
||||||
|
rank = bmt.config["topology"].tp_idx
|
||||||
worker_info = torch.utils.data.get_worker_info()
|
worker_info = torch.utils.data.get_worker_info()
|
||||||
if worker_info is None:
|
if worker_info is None:
|
||||||
nworkers, worker_id = 1, 1
|
nworkers, worker_id = 1, 1
|
||||||
else:
|
else:
|
||||||
nworkers, worker_id = worker_info.num_workers, worker_info.id
|
nworkers, worker_id = worker_info.num_workers, worker_info.id
|
||||||
|
# print("concurrent_info: (world_size, rank, nworkers, worker_id): {}".format((world_size, rank, nworkers, worker_id)))
|
||||||
return world_size, rank, nworkers, worker_id
|
return world_size, rank, nworkers, worker_id
|
||||||
|
|
||||||
|
|
||||||
|
@ -56,14 +157,45 @@ class IndexedDataset(Dataset):
|
||||||
self.max_retry = max_retry
|
self.max_retry = max_retry
|
||||||
self.retry_sleep = retry_sleep
|
self.retry_sleep = retry_sleep
|
||||||
self.bounds = None
|
self.bounds = None
|
||||||
|
self.h5file = None
|
||||||
self.build_index()
|
self.build_index()
|
||||||
|
|
||||||
def size(self):
|
def size(self):
|
||||||
return self.bounds[-1]
|
return self.bounds[-1]
|
||||||
|
|
||||||
|
def _build_index_h5(self):
|
||||||
|
index_path = os.path.join(self.path, "index.h5")
|
||||||
|
if os.path.getsize(index_path) > 104857600:
|
||||||
|
self.h5file = h5py.File(os.path.join(self.path, "index.h5"), "r")
|
||||||
|
self.bounds = self.h5file["index"]
|
||||||
|
else:
|
||||||
|
# only load index into memory when it is small (< 100 Mb)
|
||||||
|
# to avoid keeping to many file handlers
|
||||||
|
self.h5file = None
|
||||||
|
with h5py.File(index_path, "r") as hf:
|
||||||
|
self.bounds = np.array(hf["index"])
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if self.h5file is not None:
|
||||||
|
self.h5file.close()
|
||||||
|
|
||||||
def build_index(self):
|
def build_index(self):
|
||||||
|
s = time.time()
|
||||||
|
|
||||||
|
txt_size = os.path.getsize(os.path.join(self.path, "index"))
|
||||||
|
if txt_size > 0.5 * 1024**3 and os.path.exists(os.path.join(self.path, "index.h5")):
|
||||||
|
source = "h5"
|
||||||
|
self._build_index_h5()
|
||||||
|
else:
|
||||||
|
source = "txt"
|
||||||
|
self._build_index_txt()
|
||||||
|
e = time.time()
|
||||||
|
bmt.print_rank("build_index_{} from {}, using {:.2f}s".format(source, self.path, e - s))
|
||||||
|
|
||||||
|
def _build_index_txt(self):
|
||||||
with open(os.path.join(self.path, "index"), "r") as fin:
|
with open(os.path.join(self.path, "index"), "r") as fin:
|
||||||
self.bounds = [int(line) for line in fin]
|
self.bounds = [int(line) for line in fin]
|
||||||
|
self.nlines = len(self.bounds)
|
||||||
|
|
||||||
def safe_read(self, i_or_s, offset, size):
|
def safe_read(self, i_or_s, offset, size):
|
||||||
for retry in itertools.count():
|
for retry in itertools.count():
|
||||||
|
@ -138,39 +270,10 @@ class IndexedDataset(Dataset):
|
||||||
class PrefetchDecodeDataset(IndexedDataset):
|
class PrefetchDecodeDataset(IndexedDataset):
|
||||||
# Add prefetched sampled iterator and state_dict tracking upon the simple IndexedDataset
|
# Add prefetched sampled iterator and state_dict tracking upon the simple IndexedDataset
|
||||||
# Add safe decoding in iterator
|
# Add safe decoding in iterator
|
||||||
def __init__(self, *args, decode=json_decode, **kargs):
|
def __init__(self, *args, decode=json_decode, allow_repeat=False, **kargs):
|
||||||
super().__init__(*args, **kargs)
|
super().__init__(*args, **kargs)
|
||||||
self.decode = decode
|
self.decode = decode
|
||||||
self.lock = threading.Lock()
|
self.allow_repeat = allow_repeat
|
||||||
self.prev_used = set() # store previously used index in the checkpoint
|
|
||||||
self.used = set() # track locally used index
|
|
||||||
|
|
||||||
def state_dict(self, gathered=True):
|
|
||||||
if not self.prev_used and not self.used:
|
|
||||||
return {"prev_used": set()}
|
|
||||||
if gathered:
|
|
||||||
used = torch.tensor(list(self.used)).cuda()
|
|
||||||
size = torch.tensor(used.numel()).cuda()
|
|
||||||
max_size = bmt.distributed.all_reduce(size, op="max")
|
|
||||||
# allgather requires tensors having the same size
|
|
||||||
used = torch.cat([used, torch.full((max_size - size,), -100, device=used.device)], dim=-1)
|
|
||||||
all_used = bmt.distributed.all_gather(used).unique()
|
|
||||||
all_used = set(all_used.tolist())
|
|
||||||
if -100 in all_used:
|
|
||||||
all_used.remove(-100) # remove the padding value
|
|
||||||
all_used.union(self.prev_used)
|
|
||||||
return {"prev_used": all_used}
|
|
||||||
else:
|
|
||||||
return {"prev_used": self.prev_used.union(self.used)}
|
|
||||||
|
|
||||||
def load_state_dict(self, state):
|
|
||||||
with self.lock:
|
|
||||||
self.used = state.get("prev_used", set())
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
with self.lock:
|
|
||||||
self.used = set()
|
|
||||||
self.prev_used = set()
|
|
||||||
|
|
||||||
def safe_decode(self, i, raw):
|
def safe_decode(self, i, raw):
|
||||||
if raw is None:
|
if raw is None:
|
||||||
|
@ -191,19 +294,23 @@ class PrefetchDecodeDataset(IndexedDataset):
|
||||||
else:
|
else:
|
||||||
return self.safe_decode(key, raw)
|
return self.safe_decode(key, raw)
|
||||||
|
|
||||||
def loader(self, q, lid, keys, stop):
|
def loader(self, q, lid, keys, stop, used=None):
|
||||||
# concurrent prefetching worker
|
# concurrent prefetching worker
|
||||||
|
if used is None:
|
||||||
|
used = BitSet()
|
||||||
try:
|
try:
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if stop.is_set():
|
if stop.is_set():
|
||||||
break
|
break
|
||||||
# key is either a slice or an integer index
|
# key is either a slice or an integer index
|
||||||
index = range(key.start, key.stop) if isinstance(key, slice) else [key]
|
index = range(key.start, key.stop) if isinstance(key, slice) else [key]
|
||||||
with self.lock:
|
unused = bitset_diff(set(index), used)
|
||||||
unused = set(index) - self.used - self.prev_used
|
|
||||||
if not unused:
|
if not unused:
|
||||||
# skip used slice / item
|
# skip used slice / item
|
||||||
continue
|
continue
|
||||||
|
if not q.empty():
|
||||||
|
# avoid breaking the distributed file system with large io load
|
||||||
|
time.sleep(random.random() * 2)
|
||||||
# read raw data with IndexedDataset.__getitem__, suspend decoding util we really need it
|
# read raw data with IndexedDataset.__getitem__, suspend decoding util we really need it
|
||||||
raw = super().__getitem__(key)
|
raw = super().__getitem__(key)
|
||||||
if raw is None:
|
if raw is None:
|
||||||
|
@ -217,14 +324,14 @@ class PrefetchDecodeDataset(IndexedDataset):
|
||||||
# signaling the end of iteration to the main thread
|
# signaling the end of iteration to the main thread
|
||||||
q.put(StopIteration(lid))
|
q.put(StopIteration(lid))
|
||||||
|
|
||||||
def _iterate(self, key_groups, nprefetch=1000):
|
def _iterate(self, key_groups, nprefetch=1000, used=None):
|
||||||
# helper function for concurrent prefetching
|
# helper function for concurrent prefetching
|
||||||
q = queue.Queue(maxsize=nprefetch)
|
q = queue.Queue(maxsize=nprefetch)
|
||||||
stop = threading.Event()
|
stop = threading.Event()
|
||||||
alive = set()
|
alive = set()
|
||||||
try:
|
try:
|
||||||
for lid, keys in enumerate(key_groups):
|
for lid, keys in enumerate(key_groups):
|
||||||
loader = threading.Thread(target=self.loader, args=(q, lid, keys, stop), daemon=True)
|
loader = threading.Thread(target=self.loader, args=(q, lid, keys, stop, used), daemon=True)
|
||||||
loader.start()
|
loader.start()
|
||||||
alive.add(lid)
|
alive.add(lid)
|
||||||
while True:
|
while True:
|
||||||
|
@ -236,7 +343,7 @@ class PrefetchDecodeDataset(IndexedDataset):
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# new item will be put later, wait for a while
|
# new item will be put later, wait for a while
|
||||||
time.sleep(0.3)
|
time.sleep(0.1)
|
||||||
continue
|
continue
|
||||||
if isinstance(item, StopIteration):
|
if isinstance(item, StopIteration):
|
||||||
alive.remove(item.value)
|
alive.remove(item.value)
|
||||||
|
@ -245,16 +352,13 @@ class PrefetchDecodeDataset(IndexedDataset):
|
||||||
data = self.safe_decode(i, raw)
|
data = self.safe_decode(i, raw)
|
||||||
if data is None:
|
if data is None:
|
||||||
continue
|
continue
|
||||||
self.used.add(i)
|
yield i, data
|
||||||
yield data
|
|
||||||
# automatically reset states with graceful ends.
|
|
||||||
self.reset()
|
|
||||||
finally:
|
finally:
|
||||||
# ask daemon loaders to stop
|
# ask daemon loaders to stop
|
||||||
stop.set()
|
stop.set()
|
||||||
|
|
||||||
def iterate(self, nthreads=3, prefetch_sample=100):
|
def iterate(self, nthreads=3, prefetch_sample=100, used=None, process_group=None):
|
||||||
world_size, rank, nworkers, worker_id = concurrent_info()
|
world_size, rank, nworkers, worker_id = concurrent_info(process_group)
|
||||||
nloaders = world_size * nworkers * nthreads
|
nloaders = world_size * nworkers * nthreads
|
||||||
if len(self) < nloaders:
|
if len(self) < nloaders:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -269,18 +373,27 @@ class PrefetchDecodeDataset(IndexedDataset):
|
||||||
r = r.subrange(split=worker_id, nsplits=nworkers)
|
r = r.subrange(split=worker_id, nsplits=nworkers)
|
||||||
# split index among multi-threaded loaders
|
# split index among multi-threaded loaders
|
||||||
id_groups = [r.subrange(split=tid, nsplits=nthreads).random_iterate() for tid in range(nthreads)]
|
id_groups = [r.subrange(split=tid, nsplits=nthreads).random_iterate() for tid in range(nthreads)]
|
||||||
for data in self._iterate(id_groups, nprefetch=prefetch_sample):
|
return self._iterate(id_groups, nprefetch=prefetch_sample, used=used)
|
||||||
yield data
|
|
||||||
|
|
||||||
def sliced_iterate(self, nthreads=1, prefetch_slice=3, slice_size=1000):
|
def sliced_iterate(self, nthreads=1, prefetch_slice=3, slice_size=500, used=None):
|
||||||
world_size, rank, nworkers, worker_id = concurrent_info()
|
world_size, rank, nworkers, worker_id = concurrent_info()
|
||||||
nloaders = world_size * nworkers * nthreads
|
nloaders = world_size * nworkers * nthreads
|
||||||
if len(self) < nloaders:
|
if len(self) < nloaders:
|
||||||
|
if not self.allow_repeat:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"more concurrent loaders ({nloaders}) than data entries ({len(self)}) in '{self.path}', "
|
f"more concurrent loaders ({nloaders}) than data entries ({len(self)}) in '{self.path}', "
|
||||||
f"please constrain either "
|
f"please constrain either "
|
||||||
f"world_size={world_size}, num_workers={nworkers} or num_threads={nthreads}."
|
f"world_size={world_size}, num_workers={nworkers} or num_threads={nthreads}."
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
duplicated_factor = math.ceil(nloaders / len(self))
|
||||||
|
# In this case, slice size is 1
|
||||||
|
r = Range(0, len(self), 1)
|
||||||
|
# split index among grouped multi-gpu workers
|
||||||
|
r = r.subrange(split=rank // duplicated_factor, nsplits=math.ceil(world_size / duplicated_factor))
|
||||||
|
# # split index among multi-threaded loaders
|
||||||
|
r = r.subrange(split=worker_id, nsplits=nworkers)
|
||||||
|
else:
|
||||||
nslices = int(math.ceil(len(self) / slice_size))
|
nslices = int(math.ceil(len(self) / slice_size))
|
||||||
|
|
||||||
if nslices < nloaders:
|
if nslices < nloaders:
|
||||||
|
@ -300,20 +413,21 @@ class PrefetchDecodeDataset(IndexedDataset):
|
||||||
slice_groups = [
|
slice_groups = [
|
||||||
(slice(s, s + slice_size) for s in r.subrange(tid, nthreads).random_iterate()) for tid in range(nthreads)
|
(slice(s, s + slice_size) for s in r.subrange(tid, nthreads).random_iterate()) for tid in range(nthreads)
|
||||||
]
|
]
|
||||||
for data in self._iterate(slice_groups, nprefetch=prefetch_slice * slice_size):
|
return self._iterate(slice_groups, nprefetch=prefetch_slice * slice_size, used=used)
|
||||||
yield data
|
|
||||||
|
|
||||||
|
|
||||||
class IndexedDatasetBuilder:
|
class IndexedDatasetBuilder:
|
||||||
def __init__(self, path, overwrite=False):
|
def __init__(self, path, overwrite=False):
|
||||||
self.path = path
|
self.path = path
|
||||||
self.index_path = os.path.join(self.path, "index")
|
self.index_path = os.path.join(self.path, "index.h5")
|
||||||
|
self.index_path_txt = os.path.join(self.path, "index")
|
||||||
self.data_path = os.path.join(self.path, "data.jsonl")
|
self.data_path = os.path.join(self.path, "data.jsonl")
|
||||||
if not overwrite:
|
if not overwrite:
|
||||||
assert not os.path.exists(self.data_path)
|
assert not os.path.exists(self.data_path)
|
||||||
assert not os.path.exists(self.index_path)
|
assert not os.path.exists(self.index_path)
|
||||||
|
assert not os.path.exists(self.index_path_txt)
|
||||||
self.fout = None
|
self.fout = None
|
||||||
self.starts = []
|
self.bounds = []
|
||||||
self.offset = 0
|
self.offset = 0
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
|
@ -322,15 +436,17 @@ class IndexedDatasetBuilder:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
self.starts.append(self.offset)
|
self.bounds.append(self.offset)
|
||||||
with open(self.index_path, "w") as fout:
|
with h5py.File(os.path.join(self.index_path), "w") as hf:
|
||||||
for s in self.starts:
|
hf.create_dataset("index", data=self.bounds)
|
||||||
fout.write(f"{s}\n")
|
with open(self.index_path_txt, "w") as fout_txt:
|
||||||
|
for s in self.bounds:
|
||||||
|
fout_txt.write(f"{s}\n")
|
||||||
self.fout.close()
|
self.fout.close()
|
||||||
|
|
||||||
def put(self, data: dict):
|
def put(self, data: dict):
|
||||||
s = json_encode(data) + b"\n"
|
s = json_encode(data) + b"\n"
|
||||||
self.starts.append(self.offset)
|
self.bounds.append(self.offset)
|
||||||
self.offset += len(s)
|
self.offset += len(s)
|
||||||
self.fout.write(s)
|
self.fout.write(s)
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2020 The OpenBMB team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import pickle
|
import pickle
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
||||||
|
#
|
||||||
|
# @author: ouzebin <ouzebin@zhihu.com>
|
||||||
|
# @date: 2023/08/07
|
|
@ -1,14 +1,23 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
||||||
|
#
|
||||||
|
# @author: ouzebin <ouzebin@zhihu.com>
|
||||||
|
# @date: 2023/07/27
|
||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from cpm.dataset import SimpleDataset
|
from fm9g.dataset import SimpleDataset
|
||||||
from cpm.dataset.indexed_dataset import IndexedDatasetBuilder
|
from fm9g.dataset.indexed_dataset import IndexedDatasetBuilder
|
||||||
|
|
||||||
|
|
||||||
def convert_cpm_data(cpm_path, out_path):
|
def convert_fm9g_data(fm9g_path, out_path):
|
||||||
dataset = SimpleDataset(cpm_path, shuffle=False)
|
dataset = SimpleDataset(fm9g_path, shuffle=False)
|
||||||
with IndexedDatasetBuilder(out_path, overwrite=True) as builder:
|
with IndexedDatasetBuilder(out_path, overwrite=True) as builder:
|
||||||
for _ in tqdm(range(dataset._nlines), total=dataset._nlines):
|
for _ in tqdm(range(dataset._nlines), total=dataset._nlines):
|
||||||
builder.put(dataset.read())
|
builder.put(dataset.read())
|
||||||
|
@ -16,7 +25,7 @@ def convert_cpm_data(cpm_path, out_path):
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--input", "-i", required=True, help="Data path in CPM format.")
|
parser.add_argument("--input", "-i", required=True, help="Data path in fm9g format.")
|
||||||
parser.add_argument("--output", "-o", required=True, help="Output data path in indexed jsonline format.")
|
parser.add_argument("--output", "-o", required=True, help="Output data path in indexed jsonline format.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
convert_cpm_data(args.input, args.output)
|
convert_fm9g_data(args.input, args.output)
|
|
@ -1,3 +1,10 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
||||||
|
#
|
||||||
|
# @author: ouzebin <ouzebin@zhihu.com>
|
||||||
|
# @date: 2023/08/07
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2020 The OpenBMB team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
@ -251,7 +266,7 @@ def merge_dataset(dst: str, src: str):
|
||||||
_write_info_list(meta_path_dst, nw_info)
|
_write_info_list(meta_path_dst, nw_info)
|
||||||
|
|
||||||
|
|
||||||
def to_cpm(src_data, dst_path, dst_name):
|
def to_fm9g(src_data, dst_path, dst_name):
|
||||||
if not os.path.exists(dst_path):
|
if not os.path.exists(dst_path):
|
||||||
os.makedirs(dst_path)
|
os.makedirs(dst_path)
|
||||||
|
|
|
@ -0,0 +1,8 @@
|
||||||
|
{
|
||||||
|
"folders": [
|
||||||
|
{
|
||||||
|
"path": "../.."
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"settings": {}
|
||||||
|
}
|
|
@ -0,0 +1,105 @@
|
||||||
|
import torch
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
class DragonflyConfig(PretrainedConfig):
|
||||||
|
model_type = "fm9g_dragonfly"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
attribute_map = {
|
||||||
|
"num_key_value_heads": "num_kv_heads",
|
||||||
|
"hidden_act": "activate_fn",
|
||||||
|
"hidden_size": "dim_model",
|
||||||
|
"num_attention_heads": "num_heads",
|
||||||
|
"intermediate_size": "dim_ff",
|
||||||
|
"num_hidden_layers": "num_layers",
|
||||||
|
"vocab_size": "vocab_size",
|
||||||
|
"rms_norm_eps": "eps",
|
||||||
|
"scale_emb": "scale_emb",
|
||||||
|
"scale_depth": "scale_depth",
|
||||||
|
"scale": "scale",
|
||||||
|
"attention_scale": "attention_scale",
|
||||||
|
"qk_norm": "qk_norm",
|
||||||
|
"ffn_gated": "ffn_gated",
|
||||||
|
} # model specific to common
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=122753, # TODO: do we need to change to 122880 = 960 * 128?
|
||||||
|
dim_model=4096,
|
||||||
|
num_heads=32,
|
||||||
|
num_kv_heads=32,
|
||||||
|
dim_head=128,
|
||||||
|
dim_ff=11008,
|
||||||
|
num_layers=32,
|
||||||
|
dropout_p=0.0,
|
||||||
|
activate_fn="silu",
|
||||||
|
scale=False,
|
||||||
|
scale_emb: float = 1.0,
|
||||||
|
scale_depth: float = -1,
|
||||||
|
dim_model_base: int = 256,
|
||||||
|
eps=1e-5,
|
||||||
|
init_std=0.02,
|
||||||
|
dtype="bf16",
|
||||||
|
base=10000,
|
||||||
|
qk_norm=False,
|
||||||
|
tie_lm_head=False,
|
||||||
|
max_length=8192,
|
||||||
|
pose_prob=0.0,
|
||||||
|
pose_scaling_factor=1,
|
||||||
|
rope_scaling_type="",
|
||||||
|
rope_scaling_factor=1,
|
||||||
|
orig_max_length=8192,
|
||||||
|
tp=0,
|
||||||
|
use_checkpoint=True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.dim_model = dim_model
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_kv_heads = num_kv_heads
|
||||||
|
self.dim_head = dim_head
|
||||||
|
self.dim_ff = dim_ff
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.dropout_p = dropout_p
|
||||||
|
self.activate_fn = activate_fn
|
||||||
|
self.scale = scale
|
||||||
|
self.scale_emb = scale_emb
|
||||||
|
self._dtype = dtype
|
||||||
|
self.dim_model_base = dim_model_base
|
||||||
|
self.scale_depth = scale_depth
|
||||||
|
self.eps = eps
|
||||||
|
self.init_std = init_std
|
||||||
|
self.base = base
|
||||||
|
self.qk_norm = qk_norm
|
||||||
|
self.tie_lm_head = tie_lm_head
|
||||||
|
self.use_bfloat16 = True if self._dtype == "bf16" else False
|
||||||
|
self.pose_prob = pose_prob
|
||||||
|
self.pose_scaling_factor = pose_scaling_factor
|
||||||
|
self.rope_scaling_type = rope_scaling_type
|
||||||
|
self.rope_scaling_factor = rope_scaling_factor
|
||||||
|
self.max_length = max_length
|
||||||
|
self.orig_max_length = orig_max_length
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
print("use_checkpoint", self.use_checkpoint)
|
||||||
|
self.tp = tp
|
||||||
|
super().__init__(architectures=["fm9gDragonflyForCausalLM"])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scale_width(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
if self.scale:
|
||||||
|
return self.dim_model / self.dim_model_base
|
||||||
|
else:
|
||||||
|
return 1.0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(
|
||||||
|
self,
|
||||||
|
): # -> Any | None:
|
||||||
|
if self._dtype == "bf16":
|
||||||
|
return torch.bfloat16
|
||||||
|
elif self._dtype == "fp16":
|
||||||
|
return torch.half
|
||||||
|
elif self._dtype == "float32":
|
||||||
|
return torch.float
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1 @@
|
||||||
|
from .pretrain_indexed import MixedIndexedDataset
|
|
@ -0,0 +1,74 @@
|
||||||
|
import logging
|
||||||
|
from multiprocessing import Lock
|
||||||
|
|
||||||
|
from flask import Flask
|
||||||
|
from flask import jsonify
|
||||||
|
from flask import request
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
# 获取 Werkzeug 日志记录器并设置日志级别
|
||||||
|
log = logging.getLogger("werkzeug")
|
||||||
|
log.setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalAvgTokensStat(object):
|
||||||
|
def __init__(self, decay_factor: float = 0.98):
|
||||||
|
self._avg_tokens = {}
|
||||||
|
self.decay_factor = decay_factor
|
||||||
|
self.lock = Lock()
|
||||||
|
self.task_locks = {}
|
||||||
|
|
||||||
|
def set_avg_tokens(self, task_name, avg_tokens):
|
||||||
|
self._register_task_lock_helper(task_name)
|
||||||
|
with self.task_locks[task_name]:
|
||||||
|
self._avg_tokens[task_name] = avg_tokens
|
||||||
|
|
||||||
|
def update_avg_tokens_by_ema(self, task_name, length):
|
||||||
|
self._register_task_lock_helper(task_name)
|
||||||
|
with self.task_locks[task_name]:
|
||||||
|
if task_name in self._avg_tokens and self._avg_tokens[task_name] > 0:
|
||||||
|
self._avg_tokens[task_name] = self._avg_tokens[task_name] * self.decay_factor + length * (
|
||||||
|
1 - self.decay_factor
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._avg_tokens[task_name] = length
|
||||||
|
|
||||||
|
def get_avg_tokens(self, task_name):
|
||||||
|
self._register_task_lock_helper(task_name)
|
||||||
|
with self.task_locks[task_name]:
|
||||||
|
return self._avg_tokens.get(task_name, -1)
|
||||||
|
|
||||||
|
def _register_task_lock_helper(self, task_name):
|
||||||
|
if task_name not in self.task_locks:
|
||||||
|
with self.lock:
|
||||||
|
if task_name not in self.task_locks:
|
||||||
|
self.task_locks[task_name] = Lock()
|
||||||
|
|
||||||
|
|
||||||
|
global_avg_tokens_stat = GlobalAvgTokensStat()
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/avg_tokens/<path:task_name>", methods=["GET"])
|
||||||
|
def get_avg_tokens(task_name):
|
||||||
|
global global_avg_tokens_stat
|
||||||
|
avg_tokens = global_avg_tokens_stat.get_avg_tokens(task_name)
|
||||||
|
return jsonify({"avg_tokens": avg_tokens})
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/avg_tokens/<path:task_name>", methods=["POST"])
|
||||||
|
def set_avg_tokens(task_name):
|
||||||
|
global global_avg_tokens_stat
|
||||||
|
action = request.args.get("action", "update", type=str)
|
||||||
|
length = request.args.get("length", -1, type=int)
|
||||||
|
if action == "set":
|
||||||
|
global_avg_tokens_stat.set_avg_tokens(task_name, length)
|
||||||
|
elif action == "update":
|
||||||
|
global_avg_tokens_stat.update_avg_tokens_by_ema(task_name, length)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown action: {action}")
|
||||||
|
return jsonify({"status": "ok"})
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app.run(port=5000, debug=True)
|
|
@ -0,0 +1,826 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
||||||
|
#
|
||||||
|
# @author: ouzebin <ouzebin@zhihu.com>
|
||||||
|
# @date: 2023/09/27
|
||||||
|
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import ctypes
|
||||||
|
import functools
|
||||||
|
import importlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from collections import defaultdict
|
||||||
|
from collections import OrderedDict
|
||||||
|
from multiprocessing import Lock
|
||||||
|
from multiprocessing import Process
|
||||||
|
from multiprocessing.shared_memory import SharedMemory
|
||||||
|
from typing import Any
|
||||||
|
from typing import Callable
|
||||||
|
from typing import Dict
|
||||||
|
from typing import Iterable
|
||||||
|
from typing import Iterator
|
||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Set
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import bmtrain as bmt
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
from fm9g.dataset import PrefetchDecodeDataset
|
||||||
|
from fm9g.utils.bitset import BitSet
|
||||||
|
from fm9g.utils.vdc_sampling import van_der_corput
|
||||||
|
from fm9g.utils.vdc_sampling import van_der_corput_sampling_gen
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
IGNORE_TGT = -100
|
||||||
|
|
||||||
|
|
||||||
|
def load_dataset_cfgs(cfg_path, cfg_json_str=None):
|
||||||
|
if cfg_json_str is not None:
|
||||||
|
cfgs = json.loads(cfg_json_str)
|
||||||
|
else:
|
||||||
|
with open(cfg_path, "r", encoding="utf-8") as fin:
|
||||||
|
cfgs = json.load(fin)
|
||||||
|
transform_basedir = os.path.dirname(os.path.abspath(cfg_path))
|
||||||
|
|
||||||
|
path_dict = None
|
||||||
|
platform_config_path = os.getenv("PLATFORM_CONFIG_PATH")
|
||||||
|
try:
|
||||||
|
with open(platform_config_path, "r") as f:
|
||||||
|
platform_cfg = json.load(f)
|
||||||
|
path_dict = platform_cfg["dataset_map"]
|
||||||
|
if bmt.rank() == 0:
|
||||||
|
logger.info(f"Loaded jeeves platform config from '{platform_config_path}', update dataset paths...")
|
||||||
|
except Exception as e:
|
||||||
|
if bmt.rank() == 0:
|
||||||
|
logger.info(f"Failing to load jeeves platform config '{platform_config_path}', error message:\n{str(e)}")
|
||||||
|
|
||||||
|
task_name2dataset_name = dict()
|
||||||
|
for idx, cfg in enumerate(cfgs):
|
||||||
|
assert "dataset_name" in cfg and isinstance(cfg["dataset_name"], str)
|
||||||
|
assert "task_name" in cfg and isinstance(cfg["task_name"], str)
|
||||||
|
# to be delibrately annoying :)
|
||||||
|
if cfg["task_name"] in task_name2dataset_name:
|
||||||
|
raise ValueError(
|
||||||
|
f"task_name '{cfg['task_name']}' in dataset '{cfg['dataset_name']}'"
|
||||||
|
f"has already been used in '{task_name2dataset_name[cfg['task_name']]}'."
|
||||||
|
)
|
||||||
|
task_name2dataset_name[cfg["task_name"]] = cfg["dataset_name"]
|
||||||
|
|
||||||
|
assert "path" in cfg and isinstance(cfg["path"], str)
|
||||||
|
# if path_dict is not None:
|
||||||
|
# cfg["path"] = os.path.join(path_dict[cfg["dataset_name"]], cfg["path"])
|
||||||
|
|
||||||
|
# dealing with optional configs
|
||||||
|
if "weight" in cfg:
|
||||||
|
assert isinstance(cfg["weight"], (float, int))
|
||||||
|
else:
|
||||||
|
cfg["weight"] = 1.0
|
||||||
|
|
||||||
|
if "oversize_rule" in cfg:
|
||||||
|
assert cfg["oversize_rule"] in ("drop", "head", "segment")
|
||||||
|
else:
|
||||||
|
cfg["oversize_rule"] = "segment"
|
||||||
|
|
||||||
|
if "transforms" in cfg:
|
||||||
|
assert isinstance(cfg["transforms"], str)
|
||||||
|
# dealing with relative path
|
||||||
|
if not cfg["transforms"].startswith("/"):
|
||||||
|
cfg["transforms"] = os.path.join(transform_basedir, cfg["transforms"])
|
||||||
|
if not cfg["transforms"]:
|
||||||
|
cfg["transforms"] = None
|
||||||
|
else:
|
||||||
|
cfg["transforms"] = None
|
||||||
|
|
||||||
|
if "incontext_weight" in cfg:
|
||||||
|
assert isinstance(cfg["incontext_weight"], (list, tuple))
|
||||||
|
else:
|
||||||
|
cfg["incontext_weight"] = [1.0]
|
||||||
|
cfg["id"] = idx
|
||||||
|
# dataset and iterator will be built
|
||||||
|
return cfgs
|
||||||
|
|
||||||
|
|
||||||
|
def data2ids(data, tokenizer, max_length):
|
||||||
|
text = "\n".join(
|
||||||
|
[
|
||||||
|
data.get("title", "").strip(),
|
||||||
|
data.get("question", "").strip(),
|
||||||
|
data.get("answer", "").strip(),
|
||||||
|
data.get("abstract", "").strip(),
|
||||||
|
data.get("text", "").strip(),
|
||||||
|
data.get("code", "").strip(),
|
||||||
|
]
|
||||||
|
).strip()
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
logger.warning(f"Warning: skip invalid sample without valid fields: {data}")
|
||||||
|
yield from ()
|
||||||
|
return
|
||||||
|
# suppress the annoying warning from tokenizer
|
||||||
|
ids = (
|
||||||
|
[tokenizer.bos_token_id]
|
||||||
|
+ tokenizer.encode(text, max_length=int(1e12), truncation=True)
|
||||||
|
+ [tokenizer.eos_token_id]
|
||||||
|
)
|
||||||
|
src_ids = ids[0:-1]
|
||||||
|
tgt_ids = ids[0:-1] # do not shift because it'll be shifted during loss calculation.
|
||||||
|
|
||||||
|
if len(src_ids) > max_length:
|
||||||
|
for st in range(0, len(src_ids), max_length):
|
||||||
|
yield src_ids[st : st + max_length], tgt_ids[st : st + max_length]
|
||||||
|
else:
|
||||||
|
yield src_ids, tgt_ids
|
||||||
|
|
||||||
|
|
||||||
|
def cricket_data2ids(data, tokenizer, max_length: int, oversize_rule="segment", do_compact=False):
|
||||||
|
assert oversize_rule in ("drop", "head", "segment")
|
||||||
|
if data is None:
|
||||||
|
yield from ()
|
||||||
|
return
|
||||||
|
if "output" not in data or not data["output"]:
|
||||||
|
yield from ()
|
||||||
|
return
|
||||||
|
if "input" not in data or data["input"] is None:
|
||||||
|
data["input"] = ""
|
||||||
|
|
||||||
|
src_ids = [tokenizer.bos_token_id]
|
||||||
|
tgt_ids = []
|
||||||
|
has_input = False
|
||||||
|
is_segment_reenter = False
|
||||||
|
|
||||||
|
# Use incremental tokenization to avoid waiting for a long document
|
||||||
|
MAX_CHUNK_LENGTH = max_length * 10
|
||||||
|
for part in ("input", "output"):
|
||||||
|
l, r = 0, min(MAX_CHUNK_LENGTH, len(data[part]))
|
||||||
|
while l < len(data[part]):
|
||||||
|
try:
|
||||||
|
current_slice = data[part][l:r]
|
||||||
|
if not current_slice:
|
||||||
|
break
|
||||||
|
token_ids = tokenizer.encode(current_slice, add_special_tokens=False)
|
||||||
|
except:
|
||||||
|
print("Error in data[part][l:r] {}".format(data))
|
||||||
|
yield from ()
|
||||||
|
return
|
||||||
|
|
||||||
|
if part == "input":
|
||||||
|
if len(token_ids) > 0:
|
||||||
|
has_input = True
|
||||||
|
if len(token_ids) >= max_length - 2: # input len must < max_length
|
||||||
|
yield from ()
|
||||||
|
return
|
||||||
|
src_ids.extend(token_ids)
|
||||||
|
tgt_ids.extend([IGNORE_TGT] * len(token_ids))
|
||||||
|
l = r
|
||||||
|
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
|
||||||
|
else:
|
||||||
|
if len(token_ids) + len(tgt_ids) >= max_length:
|
||||||
|
if oversize_rule == "drop":
|
||||||
|
yield from ()
|
||||||
|
return
|
||||||
|
elif oversize_rule == "head":
|
||||||
|
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||||
|
src_ids.extend(selected_token_ids[:-1])
|
||||||
|
tgt_ids.extend(selected_token_ids)
|
||||||
|
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||||
|
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||||
|
return
|
||||||
|
elif oversize_rule == "segment":
|
||||||
|
instruction_rest_space = max_length - 1 - len(token_ids)
|
||||||
|
if has_input: # is instruction data
|
||||||
|
if (
|
||||||
|
do_compact
|
||||||
|
and len(src_ids) >= 128 # avoid too short instruction info lost
|
||||||
|
and instruction_rest_space / len(src_ids) > 0.8
|
||||||
|
): # can be squeezed into max length
|
||||||
|
inputs_len = len(src_ids)
|
||||||
|
keep_len = instruction_rest_space // 2
|
||||||
|
src_ids = src_ids[:keep_len] + src_ids[inputs_len - keep_len :]
|
||||||
|
tgt_ids = [IGNORE_TGT] * (len(src_ids) - 1)
|
||||||
|
src_ids.extend(token_ids)
|
||||||
|
tgt_ids.extend(token_ids)
|
||||||
|
tgt_ids.append(tokenizer.eos_token_id)
|
||||||
|
assert len(src_ids) < max_length, f"len src_ids: {len(src_ids)}"
|
||||||
|
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||||
|
yield src_ids, tgt_ids
|
||||||
|
else: # else use head rule
|
||||||
|
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||||
|
src_ids.extend(selected_token_ids[:-1])
|
||||||
|
tgt_ids.extend(selected_token_ids)
|
||||||
|
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||||
|
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||||
|
return
|
||||||
|
else: # normal segment
|
||||||
|
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||||
|
src_ids.extend(selected_token_ids)
|
||||||
|
tgt_ids.extend(selected_token_ids)
|
||||||
|
assert len(src_ids) == max_length + 1, f"len src_ids: {len(src_ids)}"
|
||||||
|
assert len(tgt_ids) == max_length, f"len tgt_ids: {len(tgt_ids)}"
|
||||||
|
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||||
|
src_ids = src_ids[max_length:]
|
||||||
|
tgt_ids = tgt_ids[max_length:]
|
||||||
|
# sliding input str window
|
||||||
|
consumed_str = tokenizer.decode(selected_token_ids)
|
||||||
|
l += len(consumed_str)
|
||||||
|
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
|
||||||
|
is_segment_reenter = True
|
||||||
|
else:
|
||||||
|
if (is_segment_reenter and len(token_ids) > 8) or (
|
||||||
|
not is_segment_reenter and len(token_ids) > 0
|
||||||
|
): # is segmented LM data
|
||||||
|
src_ids.extend(token_ids)
|
||||||
|
tgt_ids.extend(token_ids)
|
||||||
|
tgt_ids.append(tokenizer.eos_token_id)
|
||||||
|
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||||
|
yield src_ids, tgt_ids
|
||||||
|
else:
|
||||||
|
yield from ()
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentedDataset(torch.utils.data.IterableDataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg,
|
||||||
|
tokenizer,
|
||||||
|
max_length=1024,
|
||||||
|
transform_func=None,
|
||||||
|
nthreads=1,
|
||||||
|
prefetch_slice=3,
|
||||||
|
slice_size=500,
|
||||||
|
do_compact=False,
|
||||||
|
):
|
||||||
|
super(SegmentedDataset, self).__init__()
|
||||||
|
self.segment = functools.partial(
|
||||||
|
cricket_data2ids, tokenizer=tokenizer, max_length=max_length, do_compact=do_compact
|
||||||
|
)
|
||||||
|
self.cfg = cfg
|
||||||
|
self.max_length = max_length
|
||||||
|
self.nthreads = nthreads
|
||||||
|
self.transform_func = transform_func
|
||||||
|
self.prefetch_slice = prefetch_slice
|
||||||
|
self.slice_size = slice_size
|
||||||
|
self.abs_weight = cfg.get("abs_weight", None)
|
||||||
|
self.task_name = cfg["task_name"]
|
||||||
|
self.dataset_name = cfg["dataset_name"]
|
||||||
|
self.oversize_rule = cfg["oversize_rule"]
|
||||||
|
self.dataset = PrefetchDecodeDataset(path=cfg["path"], allow_repeat=cfg.get("allow_repeat", True))
|
||||||
|
self.exhausted = False
|
||||||
|
self.iterator = None
|
||||||
|
|
||||||
|
self.counter = 0
|
||||||
|
self.allow_repeat = cfg.get("allow_repeat", True)
|
||||||
|
self.used = BitSet()
|
||||||
|
self.init_ave_tokens()
|
||||||
|
|
||||||
|
def init_ave_tokens(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
shm = SharedMemory(name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
|
||||||
|
except FileNotFoundError:
|
||||||
|
bmt.print_rank(
|
||||||
|
"Create Shared Memory {}".format(f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
|
||||||
|
)
|
||||||
|
shm = SharedMemory(
|
||||||
|
create=True,
|
||||||
|
size=ctypes.sizeof(ctypes.c_float),
|
||||||
|
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}',
|
||||||
|
)
|
||||||
|
|
||||||
|
# 使用共享内存
|
||||||
|
shared_value = ctypes.c_float.from_buffer(shm.buf)
|
||||||
|
_ave_tokens = self.cfg.get(
|
||||||
|
"avg_tokens", self.cfg.get("ave_tokens", self.cfg.get("ave_tokens_per_line", -1))
|
||||||
|
)
|
||||||
|
|
||||||
|
if _ave_tokens > self.max_length:
|
||||||
|
_ave_tokens = self.max_length
|
||||||
|
bmt.print_rank(
|
||||||
|
"Warning: avg_tokens {} is larger than max_length {}, set to max_length".format(
|
||||||
|
_ave_tokens, self.max_length
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
shared_value.value = _ave_tokens
|
||||||
|
# 不再需要 shared_value 时,删除引用
|
||||||
|
del shared_value
|
||||||
|
|
||||||
|
# 现在可以安全地关闭共享内存
|
||||||
|
shm.close()
|
||||||
|
bmt.print_rank("Init ave_tokens for task {}: {}".format(self.task_name, self.ave_tokens))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ave_tokens(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
existing_shm = SharedMemory(
|
||||||
|
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}'
|
||||||
|
) # -1 # default length
|
||||||
|
shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
|
||||||
|
tmp = shared_value.value
|
||||||
|
del shared_value
|
||||||
|
existing_shm.close()
|
||||||
|
return tmp
|
||||||
|
|
||||||
|
def ave_tokens_update(self, length):
|
||||||
|
existing_shm = SharedMemory(
|
||||||
|
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}'
|
||||||
|
) # -1 # default length
|
||||||
|
shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
|
||||||
|
if shared_value.value < 0:
|
||||||
|
shared_value.value = float(length)
|
||||||
|
else:
|
||||||
|
shared_value.value = 0.98 * shared_value.value + 0.02 * length
|
||||||
|
del shared_value
|
||||||
|
existing_shm.close()
|
||||||
|
|
||||||
|
def size(self):
|
||||||
|
return self.dataset.size()
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
self.iterate()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.exhausted = False
|
||||||
|
if self.iterator is not None:
|
||||||
|
self.iterator.close()
|
||||||
|
self.iterator = None
|
||||||
|
self.used = BitSet()
|
||||||
|
print("Rank {}, Reset dataset:{} done.".format(bmt.rank(), self.dataset_name))
|
||||||
|
|
||||||
|
def transform(self, data: dict) -> dict:
|
||||||
|
weight = np.array(self.cfg["incontext_weight"], dtype=np.float32)
|
||||||
|
weight = weight / weight.sum()
|
||||||
|
num_incontext = np.random.choice(weight.shape[0], p=weight)
|
||||||
|
return self.transform_func(data, num_incontext, random.Random())
|
||||||
|
|
||||||
|
def segment_iterate(self, sample_iter):
|
||||||
|
for index, data in self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used):
|
||||||
|
for src_ids, tgt_ids in self.segment(self.transform(data)):
|
||||||
|
self.ave_tokens_update(len(src_ids)) # 0 for input ids
|
||||||
|
yield src_ids, tgt_ids, index
|
||||||
|
|
||||||
|
def iterate(self):
|
||||||
|
# make the dataset itself an iterator
|
||||||
|
sample_iter = self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used)
|
||||||
|
self.iterator = self.segment_iterate(sample_iter)
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
# advance the task iterator
|
||||||
|
if self.iterator is None:
|
||||||
|
self.iterate()
|
||||||
|
try:
|
||||||
|
return next(self.iterator)
|
||||||
|
except StopIteration:
|
||||||
|
self.exhausted = True
|
||||||
|
return None
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
if state_dict.get("exhausted", False):
|
||||||
|
self.exhausted = True
|
||||||
|
self.used = BitSet()
|
||||||
|
else:
|
||||||
|
used = state_dict.get("used", BitSet())
|
||||||
|
if len(used) == len(self.dataset):
|
||||||
|
self.exhausted = True
|
||||||
|
self.used = BitSet()
|
||||||
|
else:
|
||||||
|
self.exhausted = False
|
||||||
|
self.used = used
|
||||||
|
self.ave_tokens_update(state_dict.get("ave_tokens", -1))
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
if len(self.used) == len(self.dataset):
|
||||||
|
return dict(exhausted=True, used=BitSet(), ave_tokens=self.ave_tokens)
|
||||||
|
else:
|
||||||
|
return dict(exhausted=False, used=self.used, ave_tokens=self.ave_tokens)
|
||||||
|
|
||||||
|
def update_state(self, indice):
|
||||||
|
self.used.update(indice)
|
||||||
|
|
||||||
|
|
||||||
|
class MixedIndexedDataset(torch.utils.data.IterableDataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg_path: str,
|
||||||
|
cfg_json_str,
|
||||||
|
tokenizer,
|
||||||
|
max_length,
|
||||||
|
weight_by_size: bool = True,
|
||||||
|
nthreads=5,
|
||||||
|
prefetch_slice=100,
|
||||||
|
parallel_loading=False,
|
||||||
|
vdc_sampling=False,
|
||||||
|
update_weights_frequency=1,
|
||||||
|
seed=42,
|
||||||
|
):
|
||||||
|
super(MixedIndexedDataset, self).__init__()
|
||||||
|
self.set_seed(seed + bmt.rank())
|
||||||
|
self.weight_by_size = weight_by_size
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.eos_token_id = self.tokenizer.eos_token_id
|
||||||
|
self.bos_token_id = self.tokenizer.bos_token_id
|
||||||
|
self.path2transform = dict()
|
||||||
|
self.task_dict = OrderedDict()
|
||||||
|
self.nthreads = nthreads
|
||||||
|
self.prefetch_slice = prefetch_slice
|
||||||
|
# useful for indexing
|
||||||
|
self.tasks = []
|
||||||
|
self.names = []
|
||||||
|
# ending of iteration
|
||||||
|
self.remain = 0
|
||||||
|
self.max_length = max_length
|
||||||
|
self.vdc_sampling = vdc_sampling
|
||||||
|
if self.vdc_sampling:
|
||||||
|
self._vdc_values = [van_der_corput(i) for i in range(10**6)] # 精度提高 10^{-6}
|
||||||
|
self.vdc_gen = van_der_corput_sampling_gen(self._vdc_values)
|
||||||
|
|
||||||
|
self.update_weights_frequency = update_weights_frequency
|
||||||
|
|
||||||
|
self.path2transform = dict()
|
||||||
|
|
||||||
|
cfgs = load_dataset_cfgs(cfg_path, cfg_json_str)
|
||||||
|
_sum_weight = sum([cfg["abs_weight"] for cfg in cfgs])
|
||||||
|
_weights = {cfg["task_name"]: cfg["abs_weight"] / _sum_weight for cfg in cfgs}
|
||||||
|
bmt.print_rank("Absolute Weight of DataSet {}".format(_weights))
|
||||||
|
|
||||||
|
if parallel_loading:
|
||||||
|
self.parallel_load(cfgs, max_workers=None)
|
||||||
|
else:
|
||||||
|
self.sequential_load(cfgs)
|
||||||
|
|
||||||
|
self.weights = None
|
||||||
|
self.update_weights()
|
||||||
|
|
||||||
|
def set_seed(self, seed):
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
|
||||||
|
def load_task(self, cfg):
|
||||||
|
logger.info(f"Loading {cfg['path']}")
|
||||||
|
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
|
||||||
|
task = SegmentedDataset(
|
||||||
|
cfg,
|
||||||
|
self.tokenizer,
|
||||||
|
self.max_length,
|
||||||
|
transform_func=transform_func,
|
||||||
|
nthreads=self.nthreads,
|
||||||
|
prefetch_slice=self.prefetch_slice,
|
||||||
|
do_compact=cfg.get("do_compact", False), # dataset level do_compact
|
||||||
|
)
|
||||||
|
return task
|
||||||
|
|
||||||
|
def sequential_load(self, cfgs):
|
||||||
|
self.cfgs = cfgs
|
||||||
|
for cfg in cfgs:
|
||||||
|
# python3.7 and later preserves insertion order to dictionary
|
||||||
|
logger.info(f"Loading {cfg['path']}")
|
||||||
|
|
||||||
|
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
|
||||||
|
task = SegmentedDataset(
|
||||||
|
cfg,
|
||||||
|
self.tokenizer,
|
||||||
|
self.max_length,
|
||||||
|
transform_func=transform_func,
|
||||||
|
nthreads=self.nthreads,
|
||||||
|
prefetch_slice=self.prefetch_slice,
|
||||||
|
do_compact=cfg.get("do_compact", False), # dataset level do_compact
|
||||||
|
)
|
||||||
|
self.task_dict[task.task_name] = task
|
||||||
|
self.tasks.append(task)
|
||||||
|
self.names.append(task.task_name)
|
||||||
|
self.remain += 1
|
||||||
|
self.weights = None
|
||||||
|
self.update_weights()
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
missing_keys = []
|
||||||
|
for name, task in self.task_dict.items():
|
||||||
|
if name in state_dict:
|
||||||
|
task.load_state_dict(state_dict[name])
|
||||||
|
else:
|
||||||
|
missing_keys.append(name)
|
||||||
|
self.update_weights()
|
||||||
|
return missing_keys
|
||||||
|
|
||||||
|
def save_state_dict(self, path):
|
||||||
|
state_dict = {}
|
||||||
|
for name, task in self.task_dict.items():
|
||||||
|
_state_dict = task.state_dict()
|
||||||
|
if isinstance(_state_dict["used"], BitSet):
|
||||||
|
bitset = _state_dict["used"]
|
||||||
|
_file_name = bitset.save(path)
|
||||||
|
_state_dict["used"] = _file_name
|
||||||
|
state_dict[name] = _state_dict
|
||||||
|
else:
|
||||||
|
state_dict[name] = task.state_dict()
|
||||||
|
torch.save(state_dict, path)
|
||||||
|
logger.info("Dataset state saved")
|
||||||
|
|
||||||
|
def update_states(self, task_ids, indice):
|
||||||
|
is_dict = isinstance(indice, dict)
|
||||||
|
uniq = torch.unique(task_ids)
|
||||||
|
for idx in uniq:
|
||||||
|
idx = idx.item()
|
||||||
|
indexes = indice[idx] if is_dict else indice[task_ids == idx].tolist()
|
||||||
|
self.tasks[idx].update_state(indexes)
|
||||||
|
|
||||||
|
def get_transform_func(self, module_name: str, transform_script_path):
|
||||||
|
if transform_script_path is None:
|
||||||
|
# allow null transform
|
||||||
|
return lambda data, num_incontext, rand: data
|
||||||
|
module_name = "fm9g_live.transforms.{}".format(module_name)
|
||||||
|
if transform_script_path not in self.path2transform:
|
||||||
|
loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
|
||||||
|
spec = importlib.util.spec_from_loader(loader.name, loader)
|
||||||
|
if spec is None:
|
||||||
|
raise RuntimeError("Spec is none! {}".format(module_name))
|
||||||
|
mod = importlib.util.module_from_spec(spec)
|
||||||
|
self.path2transform[transform_script_path] = {
|
||||||
|
"loader": loader,
|
||||||
|
"module": mod,
|
||||||
|
"last_mtime": 0,
|
||||||
|
}
|
||||||
|
transform_script_info = self.path2transform[transform_script_path]
|
||||||
|
curr_mtime = float(transform_script_info["loader"].path_stats(transform_script_path)["mtime"])
|
||||||
|
if curr_mtime > transform_script_info["last_mtime"]:
|
||||||
|
transform_script_info["last_mtime"] = curr_mtime
|
||||||
|
transform_script_info["loader"].exec_module(transform_script_info["module"])
|
||||||
|
transform_func = getattr(transform_script_info["module"], "transform", None)
|
||||||
|
if transform_func is None:
|
||||||
|
raise NotImplementedError("Find no transform funcion in script '{}'".format(transform_script_path))
|
||||||
|
return transform_func
|
||||||
|
|
||||||
|
def update_weights(self):
|
||||||
|
task0 = self.tasks[0]
|
||||||
|
if task0.abs_weight is not None: # 这一份config是指定绝对比例的
|
||||||
|
weights = []
|
||||||
|
for task in self.tasks:
|
||||||
|
if task.exhausted:
|
||||||
|
weights.append(0)
|
||||||
|
else:
|
||||||
|
if task.ave_tokens == -1:
|
||||||
|
weights.append(task.abs_weight / self.max_length)
|
||||||
|
else:
|
||||||
|
weights.append(task.abs_weight / task.ave_tokens)
|
||||||
|
weights = np.array(weights)
|
||||||
|
else:
|
||||||
|
weights = np.array([0 if task.exhausted else task.weight for task in self.tasks])
|
||||||
|
if self.weight_by_size:
|
||||||
|
sizes = np.array([task.size() for task in self.tasks], dtype=np.float32)
|
||||||
|
weights *= sizes
|
||||||
|
self.weights = weights / weights.sum()
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for task in self.tasks:
|
||||||
|
task.iterate()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
step = 1
|
||||||
|
while True:
|
||||||
|
if self.remain == 0:
|
||||||
|
print("Rank {}, All task exhaust !!!!".format(bmt.rank()))
|
||||||
|
raise StopIteration
|
||||||
|
if self.vdc_sampling:
|
||||||
|
idx = next(self.vdc_gen)(self.weights)
|
||||||
|
else:
|
||||||
|
idx = np.random.choice(len(self.weights), p=self.weights)
|
||||||
|
|
||||||
|
data = next(self.tasks[idx])
|
||||||
|
if step % self.update_weights_frequency == 0:
|
||||||
|
self.update_weights()
|
||||||
|
if data is None:
|
||||||
|
if self.tasks[idx].allow_repeat:
|
||||||
|
# _runtime_ave = self.tasks[idx].ave_tokens
|
||||||
|
print("Rank {}, dataset {} exhaust, repeat...".format(bmt.rank(), self.tasks[idx].dataset_name))
|
||||||
|
# self.tasks[idx] = SegmentedDataset(
|
||||||
|
# self.tasks[idx].cfg, self.tokenizer, self.max_length, transform_func=self.tasks[idx].transform_func, nthreads=self.nthreads, prefetch_slice=self.prefetch_slice
|
||||||
|
# )
|
||||||
|
# self.tasks[idx].ave_tokens_update(_runtime_ave)
|
||||||
|
self.tasks[idx].reset()
|
||||||
|
else:
|
||||||
|
print("Rank {}, dataset {} exhaust, not repeat.".format(bmt.rank(), self.tasks[idx].dataset_name))
|
||||||
|
self.tasks[idx].exhaust = True
|
||||||
|
self.remain -= 1
|
||||||
|
continue
|
||||||
|
step += 1
|
||||||
|
return dict(
|
||||||
|
task_id=idx,
|
||||||
|
input=data[0],
|
||||||
|
target=data[1],
|
||||||
|
index=data[2],
|
||||||
|
is_long=self.tasks[idx].cfg.get("is_long", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UnpadBatchedMixedDataset(torch.utils.data.IterableDataset):
|
||||||
|
def __init__(self, mixed_dataset, batch_size, max_length, pose_prob=0.0, pose_scaling_factor=1.0, compact=False):
|
||||||
|
self.max_total_length = batch_size * max_length
|
||||||
|
self.batch_size = 1
|
||||||
|
# setting compact=True concats segments orignated from the same input
|
||||||
|
# into a long sequence. the relative order of segments should be preserved
|
||||||
|
# in mixed_dataset, e.g.,
|
||||||
|
# - ok: task1_seg1, task2_seg1, task1_seg2, task1_seg3
|
||||||
|
# - not_ok: task1_seg1, task1_seg3, task2_seg1, task1_seg2
|
||||||
|
self.compact = compact
|
||||||
|
|
||||||
|
self.total_length = 0
|
||||||
|
self.task2seqs = defaultdict(list)
|
||||||
|
self.mixed_dataset = mixed_dataset
|
||||||
|
self._max_length = max_length
|
||||||
|
self._pose_prob = pose_prob
|
||||||
|
self._pose_scaling_factor = pose_scaling_factor
|
||||||
|
if self._pose_prob > 0.0:
|
||||||
|
self._scaled_max_length = int(self.max_total_length * self._pose_scaling_factor)
|
||||||
|
else:
|
||||||
|
self._scaled_max_length = max_length
|
||||||
|
|
||||||
|
def put(self, sample):
|
||||||
|
self.total_length += len(sample["target"])
|
||||||
|
task_id = sample["task_id"]
|
||||||
|
if self.compact and self.task2seqs[task_id]:
|
||||||
|
last = self.task2seqs[task_id][-1]
|
||||||
|
if last["target"][-1] != self.mixed_dataset.eos_token_id:
|
||||||
|
# concatenate sequantial segments for longer context modeling: why not?
|
||||||
|
last["input"].extend(sample["input"])
|
||||||
|
last["target"].extend(sample["target"])
|
||||||
|
return
|
||||||
|
self.task2seqs[task_id].append(sample)
|
||||||
|
|
||||||
|
def _pose_preprocess(
|
||||||
|
self,
|
||||||
|
input_ids: NDArray[np.int32],
|
||||||
|
) -> NDArray[np.int32]:
|
||||||
|
"""[PoSE](https://arxiv.org/abs/2309.10400v2)
|
||||||
|
GitHub implementation: https://github.com/dwzhu-pku/PoSE/blob/master/src/train_pose.py#L156
|
||||||
|
"""
|
||||||
|
len_chunk = min(len(input_ids), self._max_length)
|
||||||
|
len_input = len(input_ids)
|
||||||
|
# Chunk input randomly to fit max_length if needed
|
||||||
|
lt1 = 0
|
||||||
|
rt1 = random.randint(0, (len_chunk + 1) // 2) # Fist chunk only contains 1/2 tokens at most
|
||||||
|
rt2 = random.randint(lt1 + len_chunk, len_input) # Second chunk can randomly shift when not filled max_length
|
||||||
|
lt2 = rt2 - (len_chunk - (rt1 - lt1)) # assure all tokens are used
|
||||||
|
chunked_input_ids = np.concatenate([input_ids[lt1:rt1], input_ids[lt2:rt2]], axis=-1)
|
||||||
|
# Generate PoSE position ids
|
||||||
|
position_ids = np.arange(len(chunked_input_ids), dtype=np.int32)
|
||||||
|
len_position_ids = len(position_ids)
|
||||||
|
lt = 0
|
||||||
|
rt = random.randint(lt, self._scaled_max_length - len_position_ids)
|
||||||
|
position_ids[: rt1 - lt1] += lt
|
||||||
|
position_ids[rt1 - lt1 :] += rt
|
||||||
|
return position_ids
|
||||||
|
|
||||||
|
def pop(self):
|
||||||
|
indexes = defaultdict(list)
|
||||||
|
lengths = []
|
||||||
|
|
||||||
|
inputs = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
|
||||||
|
targets = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=IGNORE_TGT)
|
||||||
|
task_ids = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=-1)
|
||||||
|
position_ids = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
|
||||||
|
|
||||||
|
span_begin = 0
|
||||||
|
for samples in self.task2seqs.values():
|
||||||
|
while samples:
|
||||||
|
sample = samples.pop()
|
||||||
|
span_end = span_begin + len(sample["input"])
|
||||||
|
inputs[0, span_begin:span_end] = torch.tensor(sample["input"], dtype=torch.int32)
|
||||||
|
targets[0, span_begin:span_end] = torch.tensor(sample["target"], dtype=torch.int32)
|
||||||
|
task_ids[0, span_begin:span_end] = torch.tensor(sample["task_id"], dtype=torch.int32)
|
||||||
|
if not sample["is_long"] and self._pose_prob > 0.0 and random.uniform(0, 1) < self._pose_prob:
|
||||||
|
_span_position_ids = self._pose_preprocess(sample["input"])
|
||||||
|
else:
|
||||||
|
_span_position_ids = np.arange(len(sample["input"]), dtype=np.int32)
|
||||||
|
position_ids[0, span_begin:span_end] = torch.from_numpy(_span_position_ids)
|
||||||
|
# position_ids[0, span_begin:span_end] = torch.arange(len(sample["input"]), dtype=torch.int32)
|
||||||
|
lengths.append(len(sample["target"]))
|
||||||
|
indexes[int(sample["task_id"])].append(sample["index"])
|
||||||
|
self.total_length -= len(sample["target"])
|
||||||
|
span_begin = span_end
|
||||||
|
|
||||||
|
cu_seqlens = torch.cat(
|
||||||
|
[torch.tensor([0] + lengths).cumsum(dim=-1), torch.tensor([self.max_total_length], dtype=torch.int32)],
|
||||||
|
dim=0,
|
||||||
|
).int()
|
||||||
|
batch = {
|
||||||
|
"inputs": inputs,
|
||||||
|
"targets": targets,
|
||||||
|
"task_ids": task_ids,
|
||||||
|
"indexes": indexes,
|
||||||
|
# adhere to flash attention interface
|
||||||
|
"cu_seqlens": cu_seqlens,
|
||||||
|
"max_seqlen": int(torch.max(cu_seqlens[1:] - cu_seqlens[:-1])),
|
||||||
|
"lengths": torch.tensor(sum(lengths)).int(),
|
||||||
|
"task_names": self.mixed_dataset.names,
|
||||||
|
"position_ids": position_ids,
|
||||||
|
}
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def will_be_full(self, sample):
|
||||||
|
return self.total_length + len(sample["target"]) > self.max_total_length
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for sample in self.mixed_dataset:
|
||||||
|
if self.will_be_full(sample):
|
||||||
|
yield self.pop()
|
||||||
|
self.put(sample)
|
||||||
|
|
||||||
|
|
||||||
|
class CudaPrefetcher(Iterable):
|
||||||
|
"""
|
||||||
|
Wrap around a batch iterator for asynchornously copying data to gpu to shield memcpy latency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, loader, tp_size=1, tp_rank=0):
|
||||||
|
self.loader = iter(loader)
|
||||||
|
self.tp_size = tp_size
|
||||||
|
self.tp_rank = tp_rank
|
||||||
|
self.stream = torch.cuda.Stream()
|
||||||
|
self.preload()
|
||||||
|
|
||||||
|
def preload(self):
|
||||||
|
try:
|
||||||
|
if self.tp_size > 1:
|
||||||
|
if self.tp_rank == 0:
|
||||||
|
data = next(self.loader)
|
||||||
|
print("Rank {}, Preload data done.".format(bmt.rank()))
|
||||||
|
d = {}
|
||||||
|
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "wb") as fb:
|
||||||
|
for key in data.keys():
|
||||||
|
if isinstance(data[key], torch.Tensor):
|
||||||
|
np_cur_data = data[key].cpu().numpy()
|
||||||
|
bs = np_cur_data.tobytes()
|
||||||
|
fb.write(bs)
|
||||||
|
d[key] = ["TORCH", str(np_cur_data.dtype), len(bs)] + list(np_cur_data.shape)
|
||||||
|
elif isinstance(data[key], np.ndarray):
|
||||||
|
bs = data[key].tobytes()
|
||||||
|
fb.write(bs)
|
||||||
|
d[key] = ["NUMPY", str(data[key].dtype), len(bs)] + list(data[key].shape)
|
||||||
|
else:
|
||||||
|
d[key] = data[key]
|
||||||
|
try:
|
||||||
|
_ = json.dumps(d)
|
||||||
|
except TypeError:
|
||||||
|
print(d)
|
||||||
|
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "w") as f:
|
||||||
|
json.dump(d, f)
|
||||||
|
bmt.synchronize()
|
||||||
|
if self.tp_rank != 0:
|
||||||
|
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "rb") as fb:
|
||||||
|
bs = fb.read()
|
||||||
|
offset = 0
|
||||||
|
for key in data.keys():
|
||||||
|
if isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "NUMPY":
|
||||||
|
nw_offset = offset + data[key][2]
|
||||||
|
data[key] = np.frombuffer(bs[offset:nw_offset], dtype=data[key][1]).reshape(
|
||||||
|
data[key][3:]
|
||||||
|
)
|
||||||
|
offset = nw_offset
|
||||||
|
elif isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "TORCH":
|
||||||
|
nw_offset = offset + data[key][2]
|
||||||
|
data[key] = torch.from_numpy(
|
||||||
|
np.frombuffer(bs[offset:nw_offset], dtype=data[key][1])
|
||||||
|
.reshape(data[key][3:])
|
||||||
|
.copy()
|
||||||
|
)
|
||||||
|
offset = nw_offset
|
||||||
|
self.data = data
|
||||||
|
else:
|
||||||
|
self.data = next(self.loader)
|
||||||
|
except StopIteration:
|
||||||
|
self.data = None
|
||||||
|
return
|
||||||
|
with torch.cuda.stream(self.stream):
|
||||||
|
for key in self.data.keys():
|
||||||
|
if isinstance(self.data[key], torch.Tensor):
|
||||||
|
self.data[key] = self.data[key].cuda(non_blocking=True)
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
torch.cuda.current_stream().wait_stream(self.stream)
|
||||||
|
for key in self.data.keys():
|
||||||
|
if isinstance(self.data[key], torch.Tensor):
|
||||||
|
self.data[key].record_stream(torch.cuda.current_stream())
|
||||||
|
data = copy.deepcopy(self.data)
|
||||||
|
self.preload()
|
||||||
|
return data
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
|
@ -0,0 +1,828 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
||||||
|
#
|
||||||
|
# @author: ouzebin <ouzebin@zhihu.com>
|
||||||
|
# @date: 2023/09/27
|
||||||
|
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import ctypes
|
||||||
|
import functools
|
||||||
|
import importlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from collections import defaultdict
|
||||||
|
from collections import OrderedDict
|
||||||
|
from multiprocessing import Lock
|
||||||
|
from multiprocessing import Process
|
||||||
|
from multiprocessing.shared_memory import SharedMemory
|
||||||
|
from typing import Any
|
||||||
|
from typing import Callable
|
||||||
|
from typing import Dict
|
||||||
|
from typing import Iterable
|
||||||
|
from typing import Iterator
|
||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Set
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import bmtrain as bmt
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
from fm9g.dataset import PrefetchDecodeDataset
|
||||||
|
from fm9g.utils.bitset import BitSet
|
||||||
|
from fm9g.utils.vdc_sampling import van_der_corput
|
||||||
|
from fm9g.utils.vdc_sampling import van_der_corput_sampling_gen
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
IGNORE_TGT = -100
|
||||||
|
|
||||||
|
|
||||||
|
def load_dataset_cfgs(cfg_path, cfg_json_str=None):
|
||||||
|
if cfg_json_str is not None:
|
||||||
|
cfgs = json.loads(cfg_json_str)
|
||||||
|
else:
|
||||||
|
with open(cfg_path, "r", encoding="utf-8") as fin:
|
||||||
|
cfgs = json.load(fin)
|
||||||
|
transform_basedir = os.path.dirname(os.path.abspath(cfg_path))
|
||||||
|
|
||||||
|
path_dict = None
|
||||||
|
platform_config_path = os.getenv("PLATFORM_CONFIG_PATH")
|
||||||
|
try:
|
||||||
|
with open(platform_config_path, "r") as f:
|
||||||
|
platform_cfg = json.load(f)
|
||||||
|
path_dict = platform_cfg["dataset_map"]
|
||||||
|
if bmt.rank() == 0:
|
||||||
|
logger.info(f"Loaded jeeves platform config from '{platform_config_path}', update dataset paths...")
|
||||||
|
except Exception as e:
|
||||||
|
if bmt.rank() == 0:
|
||||||
|
logger.info(f"Failing to load jeeves platform config '{platform_config_path}', error message:\n{str(e)}")
|
||||||
|
|
||||||
|
task_name2dataset_name = dict()
|
||||||
|
for idx, cfg in enumerate(cfgs):
|
||||||
|
assert "dataset_name" in cfg and isinstance(cfg["dataset_name"], str)
|
||||||
|
assert "task_name" in cfg and isinstance(cfg["task_name"], str)
|
||||||
|
# to be delibrately annoying :)
|
||||||
|
if cfg["task_name"] in task_name2dataset_name:
|
||||||
|
raise ValueError(
|
||||||
|
f"task_name '{cfg['task_name']}' in dataset '{cfg['dataset_name']}'"
|
||||||
|
f"has already been used in '{task_name2dataset_name[cfg['task_name']]}'."
|
||||||
|
)
|
||||||
|
task_name2dataset_name[cfg["task_name"]] = cfg["dataset_name"]
|
||||||
|
|
||||||
|
assert "path" in cfg and isinstance(cfg["path"], str)
|
||||||
|
# if path_dict is not None:
|
||||||
|
# cfg["path"] = os.path.join(path_dict[cfg["dataset_name"]], cfg["path"])
|
||||||
|
|
||||||
|
# dealing with optional configs
|
||||||
|
if "weight" in cfg:
|
||||||
|
assert isinstance(cfg["weight"], (float, int))
|
||||||
|
else:
|
||||||
|
cfg["weight"] = 1.0
|
||||||
|
|
||||||
|
if "oversize_rule" in cfg:
|
||||||
|
assert cfg["oversize_rule"] in ("drop", "head", "segment")
|
||||||
|
else:
|
||||||
|
cfg["oversize_rule"] = "segment"
|
||||||
|
|
||||||
|
if "transforms" in cfg:
|
||||||
|
assert isinstance(cfg["transforms"], str)
|
||||||
|
# dealing with relative path
|
||||||
|
if not cfg["transforms"].startswith("/"):
|
||||||
|
cfg["transforms"] = os.path.join(transform_basedir, cfg["transforms"])
|
||||||
|
if not cfg["transforms"]:
|
||||||
|
cfg["transforms"] = None
|
||||||
|
else:
|
||||||
|
cfg["transforms"] = None
|
||||||
|
|
||||||
|
if "incontext_weight" in cfg:
|
||||||
|
assert isinstance(cfg["incontext_weight"], (list, tuple))
|
||||||
|
else:
|
||||||
|
cfg["incontext_weight"] = [1.0]
|
||||||
|
cfg["id"] = idx
|
||||||
|
# dataset and iterator will be built
|
||||||
|
return cfgs
|
||||||
|
|
||||||
|
|
||||||
|
def data2ids(data, tokenizer, max_length):
|
||||||
|
text = "\n".join(
|
||||||
|
[
|
||||||
|
data.get("title", "").strip(),
|
||||||
|
data.get("question", "").strip(),
|
||||||
|
data.get("answer", "").strip(),
|
||||||
|
data.get("abstract", "").strip(),
|
||||||
|
data.get("text", "").strip(),
|
||||||
|
data.get("code", "").strip(),
|
||||||
|
]
|
||||||
|
).strip()
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
logger.warning(f"Warning: skip invalid sample without valid fields: {data}")
|
||||||
|
yield from ()
|
||||||
|
return
|
||||||
|
# suppress the annoying warning from tokenizer
|
||||||
|
ids = (
|
||||||
|
[tokenizer.bos_id]
|
||||||
|
+ tokenizer.encode(text, max_length=int(1e12), truncation=True)
|
||||||
|
+ [tokenizer.eos_id]
|
||||||
|
)
|
||||||
|
src_ids = ids[0:-1]
|
||||||
|
tgt_ids = ids[0:-1] # do not shift because it'll be shifted during loss calculation.
|
||||||
|
|
||||||
|
if len(src_ids) > max_length:
|
||||||
|
for st in range(0, len(src_ids), max_length):
|
||||||
|
yield src_ids[st : st + max_length], tgt_ids[st : st + max_length]
|
||||||
|
else:
|
||||||
|
yield src_ids, tgt_ids
|
||||||
|
|
||||||
|
|
||||||
|
def cricket_data2ids(data, tokenizer, max_length: int, oversize_rule="segment", do_compact=False):
|
||||||
|
assert oversize_rule in ("drop", "head", "segment")
|
||||||
|
if data is None:
|
||||||
|
yield from ()
|
||||||
|
return
|
||||||
|
if "output" not in data or not data["output"]:
|
||||||
|
yield from ()
|
||||||
|
return
|
||||||
|
if "input" not in data or data["input"] is None:
|
||||||
|
data["input"] = ""
|
||||||
|
|
||||||
|
src_ids = [tokenizer.bos_id]
|
||||||
|
tgt_ids = []
|
||||||
|
has_input = False
|
||||||
|
is_segment_reenter = False
|
||||||
|
|
||||||
|
# Use incremental tokenization to avoid waiting for a long document
|
||||||
|
MAX_CHUNK_LENGTH = max_length * 10
|
||||||
|
for part in ("input", "output"):
|
||||||
|
l, r = 0, min(MAX_CHUNK_LENGTH, len(data[part]))
|
||||||
|
while l < len(data[part]):
|
||||||
|
try:
|
||||||
|
current_slice = data[part][l:r]
|
||||||
|
if not current_slice:
|
||||||
|
break
|
||||||
|
#token_ids = tokenizer.encode(current_slice, add_special_tokens=False)
|
||||||
|
token_ids = tokenizer.encode(current_slice)
|
||||||
|
except:
|
||||||
|
#print("Error in data[part][l:r] {}".format(data))
|
||||||
|
yield from ()
|
||||||
|
return
|
||||||
|
|
||||||
|
if part == "input":
|
||||||
|
if len(token_ids) > 0:
|
||||||
|
has_input = True
|
||||||
|
if len(token_ids) >= max_length - 2: # input len must < max_length
|
||||||
|
yield from ()
|
||||||
|
return
|
||||||
|
src_ids.extend(token_ids)
|
||||||
|
tgt_ids.extend([IGNORE_TGT] * len(token_ids))
|
||||||
|
l = r
|
||||||
|
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
|
||||||
|
else:
|
||||||
|
if len(token_ids) + len(tgt_ids) >= max_length:
|
||||||
|
if oversize_rule == "drop":
|
||||||
|
yield from ()
|
||||||
|
return
|
||||||
|
elif oversize_rule == "head":
|
||||||
|
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||||
|
src_ids.extend(selected_token_ids[:-1])
|
||||||
|
tgt_ids.extend(selected_token_ids)
|
||||||
|
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||||
|
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||||
|
return
|
||||||
|
elif oversize_rule == "segment":
|
||||||
|
instruction_rest_space = max_length - 1 - len(token_ids)
|
||||||
|
if has_input: # is instruction data
|
||||||
|
if (
|
||||||
|
do_compact
|
||||||
|
and len(src_ids) >= 128 # avoid too short instruction info lost
|
||||||
|
and instruction_rest_space / len(src_ids) > 0.8
|
||||||
|
): # can be squeezed into max length
|
||||||
|
inputs_len = len(src_ids)
|
||||||
|
keep_len = instruction_rest_space // 2
|
||||||
|
src_ids = src_ids[:keep_len] + src_ids[inputs_len - keep_len :]
|
||||||
|
tgt_ids = [IGNORE_TGT] * (len(src_ids) - 1)
|
||||||
|
src_ids.extend(token_ids)
|
||||||
|
tgt_ids.extend(token_ids)
|
||||||
|
tgt_ids.append(tokenizer.eos_id)
|
||||||
|
assert len(src_ids) < max_length, f"len src_ids: {len(src_ids)}"
|
||||||
|
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||||
|
yield src_ids, tgt_ids
|
||||||
|
else: # else use head rule
|
||||||
|
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||||
|
src_ids.extend(selected_token_ids[:-1])
|
||||||
|
tgt_ids.extend(selected_token_ids)
|
||||||
|
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||||
|
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||||
|
return
|
||||||
|
else: # normal segment
|
||||||
|
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||||
|
src_ids.extend(selected_token_ids)
|
||||||
|
tgt_ids.extend(selected_token_ids)
|
||||||
|
assert len(src_ids) == max_length + 1, f"len src_ids: {len(src_ids)}"
|
||||||
|
assert len(tgt_ids) == max_length, f"len tgt_ids: {len(tgt_ids)}"
|
||||||
|
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||||
|
src_ids = src_ids[max_length:]
|
||||||
|
tgt_ids = tgt_ids[max_length:]
|
||||||
|
# sliding input str window
|
||||||
|
consumed_str = tokenizer.decode(selected_token_ids)
|
||||||
|
l += len(consumed_str)
|
||||||
|
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
|
||||||
|
is_segment_reenter = True
|
||||||
|
else:
|
||||||
|
if (is_segment_reenter and len(token_ids) > 8) or (
|
||||||
|
not is_segment_reenter and len(token_ids) > 0
|
||||||
|
): # is segmented LM data
|
||||||
|
src_ids.extend(token_ids)
|
||||||
|
tgt_ids.extend(token_ids)
|
||||||
|
tgt_ids.append(tokenizer.eos_id)
|
||||||
|
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||||
|
yield src_ids, tgt_ids
|
||||||
|
else:
|
||||||
|
yield from ()
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentedDataset(torch.utils.data.IterableDataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg,
|
||||||
|
tokenizer,
|
||||||
|
max_length=1024,
|
||||||
|
transform_func=None,
|
||||||
|
nthreads=1,
|
||||||
|
prefetch_slice=3,
|
||||||
|
slice_size=500,
|
||||||
|
do_compact=False,
|
||||||
|
):
|
||||||
|
super(SegmentedDataset, self).__init__()
|
||||||
|
self.segment = functools.partial(
|
||||||
|
cricket_data2ids, tokenizer=tokenizer, max_length=max_length, do_compact=do_compact
|
||||||
|
)
|
||||||
|
self.cfg = cfg
|
||||||
|
self.max_length = max_length
|
||||||
|
self.nthreads = nthreads
|
||||||
|
self.transform_func = transform_func
|
||||||
|
self.prefetch_slice = prefetch_slice
|
||||||
|
self.slice_size = slice_size
|
||||||
|
self.abs_weight = cfg.get("abs_weight", None)
|
||||||
|
self.task_name = cfg["task_name"]
|
||||||
|
self.dataset_name = cfg["dataset_name"]
|
||||||
|
self.oversize_rule = cfg["oversize_rule"]
|
||||||
|
self.dataset = PrefetchDecodeDataset(path=cfg["path"], allow_repeat=cfg.get("allow_repeat", True))
|
||||||
|
self.exhausted = False
|
||||||
|
self.iterator = None
|
||||||
|
|
||||||
|
self.counter = 0
|
||||||
|
self.allow_repeat = cfg.get("allow_repeat", True)
|
||||||
|
self.used = BitSet()
|
||||||
|
self.init_ave_tokens()
|
||||||
|
|
||||||
|
def init_ave_tokens(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
shm = SharedMemory(name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
|
||||||
|
except FileNotFoundError:
|
||||||
|
bmt.print_rank(
|
||||||
|
"Create Shared Memory {}".format(f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
|
||||||
|
)
|
||||||
|
shm = SharedMemory(
|
||||||
|
create=True,
|
||||||
|
size=ctypes.sizeof(ctypes.c_float),
|
||||||
|
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}',
|
||||||
|
)
|
||||||
|
|
||||||
|
# 使用共享内存
|
||||||
|
shared_value = ctypes.c_float.from_buffer(shm.buf)
|
||||||
|
_ave_tokens = self.cfg.get(
|
||||||
|
"avg_tokens", self.cfg.get("ave_tokens", self.cfg.get("ave_tokens_per_line", -1))
|
||||||
|
)
|
||||||
|
|
||||||
|
if _ave_tokens > self.max_length:
|
||||||
|
_ave_tokens = self.max_length
|
||||||
|
bmt.print_rank(
|
||||||
|
"Warning: avg_tokens {} is larger than max_length {}, set to max_length".format(
|
||||||
|
_ave_tokens, self.max_length
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
shared_value.value = _ave_tokens
|
||||||
|
# 不再需要 shared_value 时,删除引用
|
||||||
|
del shared_value
|
||||||
|
|
||||||
|
# 现在可以安全地关闭共享内存
|
||||||
|
shm.close()
|
||||||
|
bmt.print_rank("Init ave_tokens for task {}: {}".format(self.task_name, self.ave_tokens))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ave_tokens(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
existing_shm = SharedMemory(
|
||||||
|
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}'
|
||||||
|
) # -1 # default length
|
||||||
|
shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
|
||||||
|
tmp = shared_value.value
|
||||||
|
del shared_value
|
||||||
|
existing_shm.close()
|
||||||
|
return tmp
|
||||||
|
|
||||||
|
def ave_tokens_update(self, length):
|
||||||
|
existing_shm = SharedMemory(
|
||||||
|
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}'
|
||||||
|
) # -1 # default length
|
||||||
|
shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
|
||||||
|
if shared_value.value < 0:
|
||||||
|
shared_value.value = float(length)
|
||||||
|
else:
|
||||||
|
shared_value.value = 0.98 * shared_value.value + 0.02 * length
|
||||||
|
del shared_value
|
||||||
|
existing_shm.close()
|
||||||
|
|
||||||
|
def size(self):
|
||||||
|
return self.dataset.size()
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
self.iterate()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.exhausted = False
|
||||||
|
if self.iterator is not None:
|
||||||
|
self.iterator.close()
|
||||||
|
self.iterator = None
|
||||||
|
self.used = BitSet()
|
||||||
|
print("Rank {}, Reset dataset:{} done.".format(bmt.rank(), self.dataset_name))
|
||||||
|
|
||||||
|
def transform(self, data: dict) -> dict:
|
||||||
|
weight = np.array(self.cfg["incontext_weight"], dtype=np.float32)
|
||||||
|
weight = weight / weight.sum()
|
||||||
|
num_incontext = np.random.choice(weight.shape[0], p=weight)
|
||||||
|
return self.transform_func(data, num_incontext, random.Random())
|
||||||
|
|
||||||
|
def segment_iterate(self, sample_iter):
|
||||||
|
for index, data in self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used):
|
||||||
|
for src_ids, tgt_ids in self.segment(self.transform(data)):
|
||||||
|
self.ave_tokens_update(len(src_ids)) # 0 for input ids
|
||||||
|
yield src_ids, tgt_ids, index
|
||||||
|
|
||||||
|
def iterate(self):
|
||||||
|
# make the dataset itself an iterator
|
||||||
|
sample_iter = self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used)
|
||||||
|
self.iterator = self.segment_iterate(sample_iter)
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
# advance the task iterator
|
||||||
|
if self.iterator is None:
|
||||||
|
self.iterate()
|
||||||
|
try:
|
||||||
|
return next(self.iterator)
|
||||||
|
except StopIteration:
|
||||||
|
self.exhausted = True
|
||||||
|
return None
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
if state_dict.get("exhausted", False):
|
||||||
|
self.exhausted = True
|
||||||
|
self.used = BitSet()
|
||||||
|
else:
|
||||||
|
used = state_dict.get("used", BitSet())
|
||||||
|
if len(used) == len(self.dataset):
|
||||||
|
self.exhausted = True
|
||||||
|
self.used = BitSet()
|
||||||
|
else:
|
||||||
|
self.exhausted = False
|
||||||
|
self.used = used
|
||||||
|
self.ave_tokens_update(state_dict.get("ave_tokens", -1))
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
if len(self.used) == len(self.dataset):
|
||||||
|
return dict(exhausted=True, used=BitSet(), ave_tokens=self.ave_tokens)
|
||||||
|
else:
|
||||||
|
return dict(exhausted=False, used=self.used, ave_tokens=self.ave_tokens)
|
||||||
|
|
||||||
|
def update_state(self, indice):
|
||||||
|
self.used.update(indice)
|
||||||
|
|
||||||
|
|
||||||
|
class MixedIndexedDataset(torch.utils.data.IterableDataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg_path: str,
|
||||||
|
cfg_json_str,
|
||||||
|
tokenizer,
|
||||||
|
max_length,
|
||||||
|
weight_by_size: bool = True,
|
||||||
|
nthreads=5,
|
||||||
|
prefetch_slice=100,
|
||||||
|
parallel_loading=False,
|
||||||
|
vdc_sampling=False,
|
||||||
|
update_weights_frequency=1,
|
||||||
|
seed=42,
|
||||||
|
):
|
||||||
|
super(MixedIndexedDataset, self).__init__()
|
||||||
|
self.set_seed(seed + bmt.rank())
|
||||||
|
self.weight_by_size = weight_by_size
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.eos_token_id = self.tokenizer.eos_id
|
||||||
|
self.bos_token_id = self.tokenizer.bos_id
|
||||||
|
self.path2transform = dict()
|
||||||
|
self.task_dict = OrderedDict()
|
||||||
|
self.nthreads = nthreads
|
||||||
|
self.prefetch_slice = prefetch_slice
|
||||||
|
# useful for indexing
|
||||||
|
self.tasks = []
|
||||||
|
self.names = []
|
||||||
|
# ending of iteration
|
||||||
|
self.remain = 0
|
||||||
|
self.max_length = max_length
|
||||||
|
self.vdc_sampling = vdc_sampling
|
||||||
|
if self.vdc_sampling:
|
||||||
|
self._vdc_values = [van_der_corput(i) for i in range(10**6)] # 精度提高 10^{-6}
|
||||||
|
self.vdc_gen = van_der_corput_sampling_gen(self._vdc_values)
|
||||||
|
|
||||||
|
self.update_weights_frequency = update_weights_frequency
|
||||||
|
|
||||||
|
self.path2transform = dict()
|
||||||
|
|
||||||
|
cfgs = load_dataset_cfgs(cfg_path, cfg_json_str)
|
||||||
|
_sum_weight = sum([cfg["abs_weight"] for cfg in cfgs])
|
||||||
|
_weights = {cfg["task_name"]: cfg["abs_weight"] / _sum_weight for cfg in cfgs}
|
||||||
|
bmt.print_rank("Absolute Weight of DataSet {}".format(_weights))
|
||||||
|
|
||||||
|
if parallel_loading:
|
||||||
|
self.parallel_load(cfgs, max_workers=None)
|
||||||
|
else:
|
||||||
|
self.sequential_load(cfgs)
|
||||||
|
|
||||||
|
self.weights = None
|
||||||
|
self.update_weights()
|
||||||
|
|
||||||
|
def set_seed(self, seed):
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
|
||||||
|
def load_task(self, cfg):
|
||||||
|
logger.info(f"Loading {cfg['path']}")
|
||||||
|
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
|
||||||
|
task = SegmentedDataset(
|
||||||
|
cfg,
|
||||||
|
self.tokenizer,
|
||||||
|
self.max_length,
|
||||||
|
transform_func=transform_func,
|
||||||
|
nthreads=self.nthreads,
|
||||||
|
prefetch_slice=self.prefetch_slice,
|
||||||
|
do_compact=cfg.get("do_compact", False), # dataset level do_compact
|
||||||
|
)
|
||||||
|
return task
|
||||||
|
|
||||||
|
def sequential_load(self, cfgs):
|
||||||
|
self.cfgs = cfgs
|
||||||
|
for cfg in cfgs:
|
||||||
|
# python3.7 and later preserves insertion order to dictionary
|
||||||
|
logger.info(f"Loading {cfg['path']}")
|
||||||
|
|
||||||
|
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
|
||||||
|
task = SegmentedDataset(
|
||||||
|
cfg,
|
||||||
|
self.tokenizer,
|
||||||
|
self.max_length,
|
||||||
|
transform_func=transform_func,
|
||||||
|
nthreads=self.nthreads,
|
||||||
|
prefetch_slice=self.prefetch_slice,
|
||||||
|
do_compact=cfg.get("do_compact", False), # dataset level do_compact
|
||||||
|
)
|
||||||
|
self.task_dict[task.task_name] = task
|
||||||
|
self.tasks.append(task)
|
||||||
|
self.names.append(task.task_name)
|
||||||
|
self.remain += 1
|
||||||
|
self.weights = None
|
||||||
|
self.update_weights()
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
missing_keys = []
|
||||||
|
for name, task in self.task_dict.items():
|
||||||
|
if name in state_dict:
|
||||||
|
task.load_state_dict(state_dict[name])
|
||||||
|
else:
|
||||||
|
missing_keys.append(name)
|
||||||
|
self.update_weights()
|
||||||
|
return missing_keys
|
||||||
|
|
||||||
|
def save_state_dict(self, path):
|
||||||
|
state_dict = {}
|
||||||
|
for name, task in self.task_dict.items():
|
||||||
|
_state_dict = task.state_dict()
|
||||||
|
if isinstance(_state_dict["used"], BitSet):
|
||||||
|
bitset = _state_dict["used"]
|
||||||
|
_file_name = bitset.save(path)
|
||||||
|
_state_dict["used"] = _file_name
|
||||||
|
state_dict[name] = _state_dict
|
||||||
|
else:
|
||||||
|
state_dict[name] = task.state_dict()
|
||||||
|
torch.save(state_dict, path)
|
||||||
|
logger.info("Dataset state saved")
|
||||||
|
|
||||||
|
def update_states(self, task_ids, indice):
|
||||||
|
is_dict = isinstance(indice, dict)
|
||||||
|
uniq = torch.unique(task_ids)
|
||||||
|
for idx in uniq:
|
||||||
|
idx = idx.item()
|
||||||
|
indexes = indice[idx] if is_dict else indice[task_ids == idx].tolist()
|
||||||
|
self.tasks[idx].update_state(indexes)
|
||||||
|
|
||||||
|
def get_transform_func(self, module_name: str, transform_script_path):
|
||||||
|
if transform_script_path is None:
|
||||||
|
# allow null transform
|
||||||
|
return lambda data, num_incontext, rand: data
|
||||||
|
module_name = "fm9g_live.transforms.{}".format(module_name)
|
||||||
|
if transform_script_path not in self.path2transform:
|
||||||
|
loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
|
||||||
|
spec = importlib.util.spec_from_loader(loader.name, loader)
|
||||||
|
if spec is None:
|
||||||
|
raise RuntimeError("Spec is none! {}".format(module_name))
|
||||||
|
mod = importlib.util.module_from_spec(spec)
|
||||||
|
self.path2transform[transform_script_path] = {
|
||||||
|
"loader": loader,
|
||||||
|
"module": mod,
|
||||||
|
"last_mtime": 0,
|
||||||
|
}
|
||||||
|
transform_script_info = self.path2transform[transform_script_path]
|
||||||
|
curr_mtime = float(transform_script_info["loader"].path_stats(transform_script_path)["mtime"])
|
||||||
|
if curr_mtime > transform_script_info["last_mtime"]:
|
||||||
|
transform_script_info["last_mtime"] = curr_mtime
|
||||||
|
transform_script_info["loader"].exec_module(transform_script_info["module"])
|
||||||
|
transform_func = getattr(transform_script_info["module"], "transform", None)
|
||||||
|
if transform_func is None:
|
||||||
|
raise NotImplementedError("Find no transform funcion in script '{}'".format(transform_script_path))
|
||||||
|
return transform_func
|
||||||
|
|
||||||
|
def update_weights(self):
|
||||||
|
task0 = self.tasks[0]
|
||||||
|
if task0.abs_weight is not None: # 这一份config是指定绝对比例的
|
||||||
|
weights = []
|
||||||
|
for task in self.tasks:
|
||||||
|
if task.exhausted:
|
||||||
|
weights.append(0)
|
||||||
|
else:
|
||||||
|
if task.ave_tokens == -1:
|
||||||
|
weights.append(task.abs_weight / self.max_length)
|
||||||
|
else:
|
||||||
|
weights.append(task.abs_weight / task.ave_tokens)
|
||||||
|
weights = np.array(weights)
|
||||||
|
else:
|
||||||
|
weights = np.array([0 if task.exhausted else task.weight for task in self.tasks])
|
||||||
|
if self.weight_by_size:
|
||||||
|
sizes = np.array([task.size() for task in self.tasks], dtype=np.float32)
|
||||||
|
weights *= sizes
|
||||||
|
self.weights = weights / weights.sum()
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for task in self.tasks:
|
||||||
|
task.iterate()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
step = 1
|
||||||
|
while True:
|
||||||
|
if self.remain == 0:
|
||||||
|
print("Rank {}, All task exhaust !!!!".format(bmt.rank()))
|
||||||
|
raise StopIteration
|
||||||
|
if self.vdc_sampling:
|
||||||
|
idx = next(self.vdc_gen)(self.weights)
|
||||||
|
else:
|
||||||
|
idx = np.random.choice(len(self.weights), p=self.weights)
|
||||||
|
|
||||||
|
data = next(self.tasks[idx])
|
||||||
|
if step % self.update_weights_frequency == 0:
|
||||||
|
self.update_weights()
|
||||||
|
if data is None:
|
||||||
|
if self.tasks[idx].allow_repeat:
|
||||||
|
# _runtime_ave = self.tasks[idx].ave_tokens
|
||||||
|
print("Rank {}, dataset {} exhaust, repeat...".format(bmt.rank(), self.tasks[idx].dataset_name))
|
||||||
|
# self.tasks[idx] = SegmentedDataset(
|
||||||
|
# self.tasks[idx].cfg, self.tokenizer, self.max_length, transform_func=self.tasks[idx].transform_func, nthreads=self.nthreads, prefetch_slice=self.prefetch_slice
|
||||||
|
# )
|
||||||
|
# self.tasks[idx].ave_tokens_update(_runtime_ave)
|
||||||
|
self.tasks[idx].reset()
|
||||||
|
else:
|
||||||
|
print("Rank {}, dataset {} exhaust, not repeat.".format(bmt.rank(), self.tasks[idx].dataset_name))
|
||||||
|
self.tasks[idx].exhaust = True
|
||||||
|
self.remain -= 1
|
||||||
|
continue
|
||||||
|
step += 1
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
task_id=idx,
|
||||||
|
input=data[0],
|
||||||
|
target=data[1],
|
||||||
|
index=data[2],
|
||||||
|
is_long=self.tasks[idx].cfg.get("is_long", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UnpadBatchedMixedDataset(torch.utils.data.IterableDataset):
|
||||||
|
def __init__(self, mixed_dataset, batch_size, max_length, pose_prob=0.0, pose_scaling_factor=1.0, compact=False):
|
||||||
|
self.max_total_length = batch_size * max_length
|
||||||
|
self.batch_size = 1
|
||||||
|
# setting compact=True concats segments orignated from the same input
|
||||||
|
# into a long sequence. the relative order of segments should be preserved
|
||||||
|
# in mixed_dataset, e.g.,
|
||||||
|
# - ok: task1_seg1, task2_seg1, task1_seg2, task1_seg3
|
||||||
|
# - not_ok: task1_seg1, task1_seg3, task2_seg1, task1_seg2
|
||||||
|
self.compact = compact
|
||||||
|
|
||||||
|
self.total_length = 0
|
||||||
|
self.task2seqs = defaultdict(list)
|
||||||
|
self.mixed_dataset = mixed_dataset
|
||||||
|
self._max_length = max_length
|
||||||
|
self._pose_prob = pose_prob
|
||||||
|
self._pose_scaling_factor = pose_scaling_factor
|
||||||
|
if self._pose_prob > 0.0:
|
||||||
|
self._scaled_max_length = int(self.max_total_length * self._pose_scaling_factor)
|
||||||
|
else:
|
||||||
|
self._scaled_max_length = max_length
|
||||||
|
|
||||||
|
def put(self, sample):
|
||||||
|
self.total_length += len(sample["target"])
|
||||||
|
task_id = sample["task_id"]
|
||||||
|
if self.compact and self.task2seqs[task_id]:
|
||||||
|
last = self.task2seqs[task_id][-1]
|
||||||
|
if last["target"][-1] != self.mixed_dataset.eos_token_id:
|
||||||
|
# concatenate sequantial segments for longer context modeling: why not?
|
||||||
|
last["input"].extend(sample["input"])
|
||||||
|
last["target"].extend(sample["target"])
|
||||||
|
return
|
||||||
|
self.task2seqs[task_id].append(sample)
|
||||||
|
|
||||||
|
def _pose_preprocess(
|
||||||
|
self,
|
||||||
|
input_ids: NDArray[np.int32],
|
||||||
|
) -> NDArray[np.int32]:
|
||||||
|
"""[PoSE](https://arxiv.org/abs/2309.10400v2)
|
||||||
|
GitHub implementation: https://github.com/dwzhu-pku/PoSE/blob/master/src/train_pose.py#L156
|
||||||
|
"""
|
||||||
|
len_chunk = min(len(input_ids), self._max_length)
|
||||||
|
len_input = len(input_ids)
|
||||||
|
# Chunk input randomly to fit max_length if needed
|
||||||
|
lt1 = 0
|
||||||
|
rt1 = random.randint(0, (len_chunk + 1) // 2) # Fist chunk only contains 1/2 tokens at most
|
||||||
|
rt2 = random.randint(lt1 + len_chunk, len_input) # Second chunk can randomly shift when not filled max_length
|
||||||
|
lt2 = rt2 - (len_chunk - (rt1 - lt1)) # assure all tokens are used
|
||||||
|
chunked_input_ids = np.concatenate([input_ids[lt1:rt1], input_ids[lt2:rt2]], axis=-1)
|
||||||
|
# Generate PoSE position ids
|
||||||
|
position_ids = np.arange(len(chunked_input_ids), dtype=np.int32)
|
||||||
|
len_position_ids = len(position_ids)
|
||||||
|
lt = 0
|
||||||
|
rt = random.randint(lt, self._scaled_max_length - len_position_ids)
|
||||||
|
position_ids[: rt1 - lt1] += lt
|
||||||
|
position_ids[rt1 - lt1 :] += rt
|
||||||
|
return position_ids
|
||||||
|
|
||||||
|
def pop(self):
|
||||||
|
indexes = defaultdict(list)
|
||||||
|
lengths = []
|
||||||
|
|
||||||
|
inputs = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
|
||||||
|
targets = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=IGNORE_TGT)
|
||||||
|
task_ids = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=-1)
|
||||||
|
position_ids = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
|
||||||
|
|
||||||
|
span_begin = 0
|
||||||
|
for samples in self.task2seqs.values():
|
||||||
|
while samples:
|
||||||
|
sample = samples.pop()
|
||||||
|
span_end = span_begin + len(sample["input"])
|
||||||
|
inputs[0, span_begin:span_end] = torch.tensor(sample["input"], dtype=torch.int32)
|
||||||
|
targets[0, span_begin:span_end] = torch.tensor(sample["target"], dtype=torch.int32)
|
||||||
|
task_ids[0, span_begin:span_end] = torch.tensor(sample["task_id"], dtype=torch.int32)
|
||||||
|
if not sample["is_long"] and self._pose_prob > 0.0 and random.uniform(0, 1) < self._pose_prob:
|
||||||
|
_span_position_ids = self._pose_preprocess(sample["input"])
|
||||||
|
else:
|
||||||
|
_span_position_ids = np.arange(len(sample["input"]), dtype=np.int32)
|
||||||
|
position_ids[0, span_begin:span_end] = torch.from_numpy(_span_position_ids)
|
||||||
|
# position_ids[0, span_begin:span_end] = torch.arange(len(sample["input"]), dtype=torch.int32)
|
||||||
|
lengths.append(len(sample["target"]))
|
||||||
|
indexes[int(sample["task_id"])].append(sample["index"])
|
||||||
|
self.total_length -= len(sample["target"])
|
||||||
|
span_begin = span_end
|
||||||
|
|
||||||
|
cu_seqlens = torch.cat(
|
||||||
|
[torch.tensor([0] + lengths).cumsum(dim=-1), torch.tensor([self.max_total_length], dtype=torch.int32)],
|
||||||
|
dim=0,
|
||||||
|
).int()
|
||||||
|
batch = {
|
||||||
|
"inputs": inputs,
|
||||||
|
"targets": targets,
|
||||||
|
"task_ids": task_ids,
|
||||||
|
"indexes": indexes,
|
||||||
|
# adhere to flash attention interface
|
||||||
|
"cu_seqlens": cu_seqlens,
|
||||||
|
"max_seqlen": int(torch.max(cu_seqlens[1:] - cu_seqlens[:-1])),
|
||||||
|
"lengths": torch.tensor(sum(lengths)).int(),
|
||||||
|
"task_names": self.mixed_dataset.names,
|
||||||
|
"position_ids": position_ids,
|
||||||
|
}
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def will_be_full(self, sample):
|
||||||
|
return self.total_length + len(sample["target"]) > self.max_total_length
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for sample in self.mixed_dataset:
|
||||||
|
if self.will_be_full(sample):
|
||||||
|
yield self.pop()
|
||||||
|
self.put(sample)
|
||||||
|
|
||||||
|
|
||||||
|
class CudaPrefetcher(Iterable):
|
||||||
|
"""
|
||||||
|
Wrap around a batch iterator for asynchornously copying data to gpu to shield memcpy latency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, loader, tp_size=1, tp_rank=0):
|
||||||
|
self.loader = iter(loader)
|
||||||
|
self.tp_size = tp_size
|
||||||
|
self.tp_rank = tp_rank
|
||||||
|
self.stream = torch.cuda.Stream()
|
||||||
|
self.preload()
|
||||||
|
|
||||||
|
def preload(self):
|
||||||
|
try:
|
||||||
|
if self.tp_size > 1:
|
||||||
|
if self.tp_rank == 0:
|
||||||
|
data = next(self.loader)
|
||||||
|
print("Rank {}, Preload data done.".format(bmt.rank()))
|
||||||
|
d = {}
|
||||||
|
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "wb") as fb:
|
||||||
|
for key in data.keys():
|
||||||
|
if isinstance(data[key], torch.Tensor):
|
||||||
|
np_cur_data = data[key].cpu().numpy()
|
||||||
|
bs = np_cur_data.tobytes()
|
||||||
|
fb.write(bs)
|
||||||
|
d[key] = ["TORCH", str(np_cur_data.dtype), len(bs)] + list(np_cur_data.shape)
|
||||||
|
elif isinstance(data[key], np.ndarray):
|
||||||
|
bs = data[key].tobytes()
|
||||||
|
fb.write(bs)
|
||||||
|
d[key] = ["NUMPY", str(data[key].dtype), len(bs)] + list(data[key].shape)
|
||||||
|
else:
|
||||||
|
d[key] = data[key]
|
||||||
|
try:
|
||||||
|
_ = json.dumps(d)
|
||||||
|
except TypeError:
|
||||||
|
print(d)
|
||||||
|
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "w") as f:
|
||||||
|
json.dump(d, f)
|
||||||
|
bmt.synchronize()
|
||||||
|
if self.tp_rank != 0:
|
||||||
|
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "rb") as fb:
|
||||||
|
bs = fb.read()
|
||||||
|
offset = 0
|
||||||
|
for key in data.keys():
|
||||||
|
if isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "NUMPY":
|
||||||
|
nw_offset = offset + data[key][2]
|
||||||
|
data[key] = np.frombuffer(bs[offset:nw_offset], dtype=data[key][1]).reshape(
|
||||||
|
data[key][3:]
|
||||||
|
)
|
||||||
|
offset = nw_offset
|
||||||
|
elif isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "TORCH":
|
||||||
|
nw_offset = offset + data[key][2]
|
||||||
|
data[key] = torch.from_numpy(
|
||||||
|
np.frombuffer(bs[offset:nw_offset], dtype=data[key][1])
|
||||||
|
.reshape(data[key][3:])
|
||||||
|
.copy()
|
||||||
|
)
|
||||||
|
offset = nw_offset
|
||||||
|
self.data = data
|
||||||
|
else:
|
||||||
|
self.data = next(self.loader)
|
||||||
|
except StopIteration:
|
||||||
|
self.data = None
|
||||||
|
return
|
||||||
|
with torch.cuda.stream(self.stream):
|
||||||
|
for key in self.data.keys():
|
||||||
|
if isinstance(self.data[key], torch.Tensor):
|
||||||
|
self.data[key] = self.data[key].cuda(non_blocking=True)
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
torch.cuda.current_stream().wait_stream(self.stream)
|
||||||
|
for key in self.data.keys():
|
||||||
|
if isinstance(self.data[key], torch.Tensor):
|
||||||
|
self.data[key].record_stream(torch.cuda.current_stream())
|
||||||
|
data = copy.deepcopy(self.data)
|
||||||
|
self.preload()
|
||||||
|
return data
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
|
@ -0,0 +1,827 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
||||||
|
#
|
||||||
|
# @author: ouzebin <ouzebin@zhihu.com>
|
||||||
|
# @date: 2023/09/27
|
||||||
|
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import ctypes
|
||||||
|
import functools
|
||||||
|
import importlib
|
||||||
|
import importlib.util
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import multiprocessing as mp
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from collections import defaultdict
|
||||||
|
from collections import OrderedDict
|
||||||
|
from multiprocessing import Lock
|
||||||
|
from multiprocessing import Process
|
||||||
|
from multiprocessing.shared_memory import SharedMemory
|
||||||
|
from typing import Any
|
||||||
|
from typing import Callable
|
||||||
|
from typing import Dict
|
||||||
|
from typing import Iterable
|
||||||
|
from typing import Iterator
|
||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Set
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import bmtrain as bmt
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
import torch
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
from tenacity import retry
|
||||||
|
from tenacity import stop_after_attempt
|
||||||
|
from tenacity import wait_fixed
|
||||||
|
from tenacity import wait_random
|
||||||
|
|
||||||
|
from fm9g.dataset import PrefetchDecodeDataset
|
||||||
|
from fm9g.utils.bitset import BitSet
|
||||||
|
from fm9g.utils.vdc_sampling import van_der_corput
|
||||||
|
from fm9g.utils.vdc_sampling import van_der_corput_sampling_gen
|
||||||
|
|
||||||
|
from .flask_ps import app as flask_ps
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
IGNORE_TGT = -100
|
||||||
|
|
||||||
|
|
||||||
|
def load_dataset_cfgs(cfg_path, cfg_json_str=None):
|
||||||
|
if cfg_json_str is not None:
|
||||||
|
cfgs = json.loads(cfg_json_str)
|
||||||
|
else:
|
||||||
|
with open(cfg_path, "r", encoding="utf-8") as fin:
|
||||||
|
cfgs = json.load(fin)
|
||||||
|
transform_basedir = os.path.dirname(os.path.abspath(cfg_path))
|
||||||
|
|
||||||
|
path_dict = None
|
||||||
|
platform_config_path = os.getenv("PLATFORM_CONFIG_PATH")
|
||||||
|
try:
|
||||||
|
with open(platform_config_path, "r") as f:
|
||||||
|
platform_cfg = json.load(f)
|
||||||
|
path_dict = platform_cfg["dataset_map"]
|
||||||
|
if bmt.rank() == 0:
|
||||||
|
logger.info(f"Loaded jeeves platform config from '{platform_config_path}', update dataset paths...")
|
||||||
|
except Exception as e:
|
||||||
|
if bmt.rank() == 0:
|
||||||
|
logger.info(f"Failing to load jeeves platform config '{platform_config_path}', error message:\n{str(e)}")
|
||||||
|
|
||||||
|
task_name2dataset_name = dict()
|
||||||
|
for idx, cfg in enumerate(cfgs):
|
||||||
|
assert "dataset_name" in cfg and isinstance(cfg["dataset_name"], str)
|
||||||
|
assert "task_name" in cfg and isinstance(cfg["task_name"], str)
|
||||||
|
# to be delibrately annoying :)
|
||||||
|
if cfg["task_name"] in task_name2dataset_name:
|
||||||
|
raise ValueError(
|
||||||
|
f"task_name '{cfg['task_name']}' in dataset '{cfg['dataset_name']}'"
|
||||||
|
f"has already been used in '{task_name2dataset_name[cfg['task_name']]}'."
|
||||||
|
)
|
||||||
|
task_name2dataset_name[cfg["task_name"]] = cfg["dataset_name"]
|
||||||
|
|
||||||
|
assert "path" in cfg and isinstance(cfg["path"], str)
|
||||||
|
# if path_dict is not None:
|
||||||
|
# cfg["path"] = os.path.join(path_dict[cfg["dataset_name"]], cfg["path"])
|
||||||
|
|
||||||
|
# dealing with optional configs
|
||||||
|
if "weight" in cfg:
|
||||||
|
assert isinstance(cfg["weight"], (float, int))
|
||||||
|
else:
|
||||||
|
cfg["weight"] = 1.0
|
||||||
|
|
||||||
|
if "oversize_rule" in cfg:
|
||||||
|
assert cfg["oversize_rule"] in ("drop", "head", "segment")
|
||||||
|
else:
|
||||||
|
cfg["oversize_rule"] = "segment"
|
||||||
|
|
||||||
|
if "transforms" in cfg:
|
||||||
|
assert isinstance(cfg["transforms"], str)
|
||||||
|
# dealing with relative path
|
||||||
|
if not cfg["transforms"].startswith("/"):
|
||||||
|
cfg["transforms"] = os.path.join(transform_basedir, cfg["transforms"])
|
||||||
|
if not cfg["transforms"]:
|
||||||
|
cfg["transforms"] = None
|
||||||
|
else:
|
||||||
|
cfg["transforms"] = None
|
||||||
|
|
||||||
|
if "incontext_weight" in cfg:
|
||||||
|
assert isinstance(cfg["incontext_weight"], (list, tuple))
|
||||||
|
else:
|
||||||
|
cfg["incontext_weight"] = [1.0]
|
||||||
|
cfg["id"] = idx
|
||||||
|
# dataset and iterator will be built
|
||||||
|
return cfgs
|
||||||
|
|
||||||
|
|
||||||
|
def data2ids(data, tokenizer, max_length):
|
||||||
|
text = "\n".join(
|
||||||
|
[
|
||||||
|
data.get("title", "").strip(),
|
||||||
|
data.get("question", "").strip(),
|
||||||
|
data.get("answer", "").strip(),
|
||||||
|
data.get("abstract", "").strip(),
|
||||||
|
data.get("text", "").strip(),
|
||||||
|
data.get("code", "").strip(),
|
||||||
|
]
|
||||||
|
).strip()
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
logger.warning(f"Warning: skip invalid sample without valid fields: {data}")
|
||||||
|
yield from ()
|
||||||
|
return
|
||||||
|
# suppress the annoying warning from tokenizer
|
||||||
|
ids = (
|
||||||
|
[tokenizer.bos_token_id]
|
||||||
|
+ tokenizer.encode(text, max_length=int(1e12), truncation=True)
|
||||||
|
+ [tokenizer.eos_token_id]
|
||||||
|
)
|
||||||
|
src_ids = ids[0:-1]
|
||||||
|
tgt_ids = ids[0:-1] # do not shift because it'll be shifted during loss calculation.
|
||||||
|
|
||||||
|
if len(src_ids) > max_length:
|
||||||
|
for st in range(0, len(src_ids), max_length):
|
||||||
|
yield src_ids[st : st + max_length], tgt_ids[st : st + max_length]
|
||||||
|
else:
|
||||||
|
yield src_ids, tgt_ids
|
||||||
|
|
||||||
|
|
||||||
|
def cricket_data2ids(data, tokenizer, max_length: int, oversize_rule="segment", do_compact=False):
|
||||||
|
assert oversize_rule in ("drop", "head", "segment")
|
||||||
|
if data is None:
|
||||||
|
yield from ()
|
||||||
|
return
|
||||||
|
if "output" not in data or not data["output"]:
|
||||||
|
yield from ()
|
||||||
|
return
|
||||||
|
if "input" not in data:
|
||||||
|
data["input"] = ""
|
||||||
|
|
||||||
|
src_ids = [tokenizer.bos_token_id]
|
||||||
|
tgt_ids = []
|
||||||
|
has_input = False
|
||||||
|
is_segment_reenter = False
|
||||||
|
|
||||||
|
# Use incremental tokenization to avoid waiting for a long document
|
||||||
|
MAX_CHUNK_LENGTH = max_length * 10
|
||||||
|
for part in ("input", "output"):
|
||||||
|
l, r = 0, min(MAX_CHUNK_LENGTH, len(data[part]))
|
||||||
|
while l < len(data[part]):
|
||||||
|
current_slice = data[part][l:r]
|
||||||
|
if not current_slice:
|
||||||
|
break
|
||||||
|
token_ids = tokenizer.encode(current_slice, add_special_tokens=False)
|
||||||
|
if part == "input":
|
||||||
|
if len(token_ids) > 0:
|
||||||
|
has_input = True
|
||||||
|
if len(token_ids) >= max_length - 2: # input len must < max_length
|
||||||
|
yield from ()
|
||||||
|
return
|
||||||
|
src_ids.extend(token_ids)
|
||||||
|
tgt_ids.extend([IGNORE_TGT] * len(token_ids))
|
||||||
|
l = r
|
||||||
|
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
|
||||||
|
else:
|
||||||
|
if len(token_ids) + len(tgt_ids) >= max_length:
|
||||||
|
if oversize_rule == "drop":
|
||||||
|
yield from ()
|
||||||
|
return
|
||||||
|
elif oversize_rule == "head":
|
||||||
|
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||||
|
src_ids.extend(selected_token_ids[:-1])
|
||||||
|
tgt_ids.extend(selected_token_ids)
|
||||||
|
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||||
|
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||||
|
return
|
||||||
|
elif oversize_rule == "segment":
|
||||||
|
instruction_rest_space = max_length - 1 - len(token_ids)
|
||||||
|
if has_input: # is instruction data
|
||||||
|
if (
|
||||||
|
do_compact
|
||||||
|
and len(src_ids) >= 128 # avoid too short instruction info lost
|
||||||
|
and instruction_rest_space / len(src_ids) > 0.8
|
||||||
|
): # can be squeezed into max length
|
||||||
|
inputs_len = len(src_ids)
|
||||||
|
keep_len = instruction_rest_space // 2
|
||||||
|
src_ids = src_ids[:keep_len] + src_ids[inputs_len - keep_len :]
|
||||||
|
tgt_ids = [IGNORE_TGT] * (len(src_ids) - 1)
|
||||||
|
src_ids.extend(token_ids)
|
||||||
|
tgt_ids.extend(token_ids)
|
||||||
|
tgt_ids.append(tokenizer.eos_token_id)
|
||||||
|
assert len(src_ids) < max_length, f"len src_ids: {len(src_ids)}"
|
||||||
|
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||||
|
yield src_ids, tgt_ids
|
||||||
|
else: # else use head rule
|
||||||
|
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||||
|
src_ids.extend(selected_token_ids[:-1])
|
||||||
|
tgt_ids.extend(selected_token_ids)
|
||||||
|
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||||
|
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||||
|
return
|
||||||
|
else: # normal segment
|
||||||
|
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||||
|
src_ids.extend(selected_token_ids)
|
||||||
|
tgt_ids.extend(selected_token_ids)
|
||||||
|
assert len(src_ids) == max_length + 1, f"len src_ids: {len(src_ids)}"
|
||||||
|
assert len(tgt_ids) == max_length, f"len tgt_ids: {len(tgt_ids)}"
|
||||||
|
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||||
|
src_ids = src_ids[max_length:]
|
||||||
|
tgt_ids = tgt_ids[max_length:]
|
||||||
|
# sliding input str window
|
||||||
|
consumed_str = tokenizer.decode(selected_token_ids)
|
||||||
|
l += len(consumed_str)
|
||||||
|
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
|
||||||
|
is_segment_reenter = True
|
||||||
|
else:
|
||||||
|
if (is_segment_reenter and len(token_ids) > 8) or (
|
||||||
|
not is_segment_reenter and len(token_ids) > 0
|
||||||
|
): # is segmented LM data
|
||||||
|
src_ids.extend(token_ids)
|
||||||
|
tgt_ids.extend(token_ids)
|
||||||
|
tgt_ids.append(tokenizer.eos_token_id)
|
||||||
|
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||||
|
yield src_ids, tgt_ids
|
||||||
|
else:
|
||||||
|
yield from ()
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentedDataset(torch.utils.data.IterableDataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg,
|
||||||
|
tokenizer,
|
||||||
|
max_length=1024,
|
||||||
|
transform_func=None,
|
||||||
|
nthreads=1,
|
||||||
|
prefetch_slice=3,
|
||||||
|
slice_size=500,
|
||||||
|
do_compact=False,
|
||||||
|
is_local=False,
|
||||||
|
):
|
||||||
|
def get_full_qualified_name(func):
|
||||||
|
module_name = func.__module__
|
||||||
|
qual_name = func.__qualname__
|
||||||
|
return f"{module_name}.{qual_name}"
|
||||||
|
|
||||||
|
super(SegmentedDataset, self).__init__()
|
||||||
|
self.segment = functools.partial(
|
||||||
|
cricket_data2ids, tokenizer=tokenizer, max_length=max_length, do_compact=do_compact
|
||||||
|
)
|
||||||
|
self.cfg = cfg
|
||||||
|
self.max_length = max_length
|
||||||
|
self.nthreads = nthreads
|
||||||
|
self.transform_func = transform_func
|
||||||
|
self.prefetch_slice = prefetch_slice
|
||||||
|
self.slice_size = slice_size
|
||||||
|
self.abs_weight = cfg.get("abs_weight", None)
|
||||||
|
self.task_name = cfg["task_name"]
|
||||||
|
self.dataset_name = cfg["dataset_name"]
|
||||||
|
self.oversize_rule = cfg["oversize_rule"]
|
||||||
|
self.dataset = PrefetchDecodeDataset(path=cfg["path"], allow_repeat=cfg.get("allow_repeat", False))
|
||||||
|
self.exhausted = False
|
||||||
|
self.iterator = None
|
||||||
|
|
||||||
|
self.counter = 0
|
||||||
|
self.allow_repeat = cfg.get("allow_repeat", True)
|
||||||
|
self.used = set()
|
||||||
|
self.is_local = is_local
|
||||||
|
self.port_offset = random.randint(1, 8000) + 1000
|
||||||
|
if is_local or bmt.rank() == 0:
|
||||||
|
addr = os.environ["MASTER_ADDR"]
|
||||||
|
port = int(os.environ["MASTER_PORT"]) + self.port_offset
|
||||||
|
_avg_tokens = cfg.get("ave_tokens_per_line", -1)
|
||||||
|
_avg_tokens = cfg.get("avg_tokens", _avg_tokens)
|
||||||
|
requests.post(f"http://{addr}:{port}/avg_tokens/{self.task_name}?action=set&length={_avg_tokens}")
|
||||||
|
|
||||||
|
@retry(stop=stop_after_attempt(3), wait=wait_random(5, 20))
|
||||||
|
def set_avg_tokens(self, avg_tokens):
|
||||||
|
addr = os.environ["MASTER_ADDR"]
|
||||||
|
port = int(os.environ["MASTER_PORT"]) + self.port_offset
|
||||||
|
url = f"http://{addr}:{port}/avg_tokens/{self.task_name}?action=set&length={avg_tokens}"
|
||||||
|
response = requests.post(url)
|
||||||
|
if response.status_code != 200:
|
||||||
|
self.reset_port_offset()
|
||||||
|
print(f"Failed to set avg_tokens for task {self.task_name}, request url: {url}")
|
||||||
|
raise RuntimeError(f"Failed to set avg_tokens for task {self.task_name}, request url: {url}")
|
||||||
|
|
||||||
|
@retry(stop=stop_after_attempt(3), wait=wait_random(5, 20))
|
||||||
|
def update_avg_tokens_by_ema(self, length):
|
||||||
|
addr = os.environ["MASTER_ADDR"]
|
||||||
|
port = int(os.environ["MASTER_PORT"]) + self.port_offset
|
||||||
|
url = f"http://{addr}:{port}/avg_tokens/{self.task_name}?action=update&length={length}"
|
||||||
|
response = requests.post(url)
|
||||||
|
if response.status_code != 200:
|
||||||
|
self.reset_port_offset()
|
||||||
|
print(f"Failed to update avg_tokens for task {self.task_name}, request url: {url}")
|
||||||
|
raise RuntimeError(f"Failed to update avg_tokens for task {self.task_name}, request url: {url}")
|
||||||
|
|
||||||
|
@property
|
||||||
|
@retry(stop=stop_after_attempt(3), wait=wait_random(5, 20))
|
||||||
|
def avg_tokens(self):
|
||||||
|
addr = os.environ["MASTER_ADDR"]
|
||||||
|
port = int(os.environ["MASTER_PORT"]) + self.port_offset
|
||||||
|
url = f"http://{addr}:{port}/avg_tokens/{self.task_name}"
|
||||||
|
response = requests.get(url)
|
||||||
|
if response.status_code != 200:
|
||||||
|
self.reset_port_offset()
|
||||||
|
print(f"Failed to get avg_tokens for task {self.task_name}, request url: {url}")
|
||||||
|
raise RuntimeError(f"Failed to get avg_tokens for task {self.task_name}, request url: {url}")
|
||||||
|
data = response.json()
|
||||||
|
return data["avg_tokens"]
|
||||||
|
|
||||||
|
def reset_port_offset(self):
|
||||||
|
self.port_offset = random.randint(1, 8000) + 1000
|
||||||
|
|
||||||
|
def size(self):
|
||||||
|
return self.dataset.size()
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
self.iterate()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.exhausted = False
|
||||||
|
if self.iterator is not None:
|
||||||
|
self.iterator.close()
|
||||||
|
self.iterator = None
|
||||||
|
self.used = BitSet()
|
||||||
|
print("Rank {}, Reset dataset:{} done.".format(bmt.rank(), self.dataset_name))
|
||||||
|
|
||||||
|
def transform(self, data: dict) -> dict:
|
||||||
|
weight = np.array(self.cfg["incontext_weight"], dtype=np.float32)
|
||||||
|
weight = weight / weight.sum()
|
||||||
|
num_incontext = np.random.choice(weight.shape[0], p=weight)
|
||||||
|
return self.transform_func(data, num_incontext, random.Random())
|
||||||
|
|
||||||
|
def segment_iterate(self, sample_iter):
|
||||||
|
for index, data in self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used):
|
||||||
|
for src_ids, tgt_ids in self.segment(self.transform(data)):
|
||||||
|
self.update_avg_tokens_by_ema(len(src_ids)) # 0 for input ids
|
||||||
|
yield src_ids, tgt_ids, index
|
||||||
|
|
||||||
|
def iterate(self):
|
||||||
|
# make the dataset itself an iterator
|
||||||
|
sample_iter = self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used)
|
||||||
|
self.iterator = self.segment_iterate(sample_iter)
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
# advance the task iterator
|
||||||
|
if self.iterator is None:
|
||||||
|
self.iterate()
|
||||||
|
try:
|
||||||
|
return next(self.iterator)
|
||||||
|
except StopIteration:
|
||||||
|
self.exhausted = True
|
||||||
|
return None
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
if state_dict.get("exhausted", False):
|
||||||
|
self.exhausted = True
|
||||||
|
self.used = BitSet()
|
||||||
|
else:
|
||||||
|
used = state_dict.get("used", BitSet())
|
||||||
|
if len(used) == len(self.dataset):
|
||||||
|
self.exhausted = True
|
||||||
|
self.used = BitSet()
|
||||||
|
else:
|
||||||
|
self.exhausted = False
|
||||||
|
self.used = used
|
||||||
|
_avg_tokens = state_dict.get("ave_tokens", -1)
|
||||||
|
_avg_tokens = state_dict.get("avg_tokens", _avg_tokens)
|
||||||
|
if self.avg_tokens == -1 or self.avg_tokens < _avg_tokens:
|
||||||
|
self.set_avg_tokens(_avg_tokens)
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
if len(self.used) == len(self.dataset):
|
||||||
|
return dict(exhausted=True, used=BitSet(), avg_tokens=self.avg_tokens)
|
||||||
|
else:
|
||||||
|
return dict(exhausted=False, used=self.used, avg_tokens=self.avg_tokens)
|
||||||
|
|
||||||
|
def update_state(self, indice):
|
||||||
|
self.used.update(indice)
|
||||||
|
|
||||||
|
|
||||||
|
class MixedIndexedDataset(torch.utils.data.IterableDataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg_path: str,
|
||||||
|
cfg_json_str,
|
||||||
|
tokenizer,
|
||||||
|
max_length,
|
||||||
|
weight_by_size: bool = True,
|
||||||
|
nthreads=5,
|
||||||
|
prefetch_slice=100,
|
||||||
|
parallel_loading=False,
|
||||||
|
vdc_sampling=False,
|
||||||
|
update_weights_frequency=1,
|
||||||
|
seed=42,
|
||||||
|
):
|
||||||
|
if bmt.rank() == 0:
|
||||||
|
port = int(os.environ["MASTER_PORT"]) + 2188
|
||||||
|
self.flask_ps_proc = mp.Process(target=flask_ps.run, kwargs={"host": "0.0.0.0", "port": port})
|
||||||
|
self.flask_ps_proc.start()
|
||||||
|
super(MixedIndexedDataset, self).__init__()
|
||||||
|
self.set_seed(seed + bmt.rank())
|
||||||
|
self.weight_by_size = weight_by_size
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.eos_token_id = self.tokenizer.eos_token_id
|
||||||
|
self.bos_token_id = self.tokenizer.bos_token_id
|
||||||
|
self.path2transform = dict()
|
||||||
|
self.task_dict = OrderedDict()
|
||||||
|
self.nthreads = nthreads
|
||||||
|
self.prefetch_slice = prefetch_slice
|
||||||
|
# useful for indexing
|
||||||
|
self.tasks = []
|
||||||
|
self.names = []
|
||||||
|
# ending of iteration
|
||||||
|
self.remain = 0
|
||||||
|
self.max_length = max_length
|
||||||
|
self.vdc_sampling = vdc_sampling
|
||||||
|
if self.vdc_sampling:
|
||||||
|
self._vdc_values = [van_der_corput(i) for i in range(100000)]
|
||||||
|
self.vdc_gen = van_der_corput_sampling_gen(self._vdc_values)
|
||||||
|
|
||||||
|
self.update_weights_frequency = update_weights_frequency
|
||||||
|
|
||||||
|
self.path2transform = dict()
|
||||||
|
|
||||||
|
cfgs = load_dataset_cfgs(cfg_path, cfg_json_str)
|
||||||
|
_sum_weight = sum([cfg["abs_weight"] for cfg in cfgs])
|
||||||
|
_weights = {cfg["task_name"]: cfg["abs_weight"] / _sum_weight for cfg in cfgs}
|
||||||
|
bmt.print_rank("Absolute Weight of DataSet {}".format(_weights))
|
||||||
|
|
||||||
|
if parallel_loading:
|
||||||
|
self.parallel_load(cfgs, max_workers=None)
|
||||||
|
else:
|
||||||
|
self.sequential_load(cfgs)
|
||||||
|
|
||||||
|
self.weights = None
|
||||||
|
self.update_weights()
|
||||||
|
|
||||||
|
def set_seed(self, seed):
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
|
||||||
|
def load_task(self, cfg):
|
||||||
|
logger.info(f"Loading {cfg['path']}")
|
||||||
|
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
|
||||||
|
task = SegmentedDataset(
|
||||||
|
cfg,
|
||||||
|
self.tokenizer,
|
||||||
|
self.max_length,
|
||||||
|
transform_func=transform_func,
|
||||||
|
nthreads=self.nthreads,
|
||||||
|
prefetch_slice=self.prefetch_slice,
|
||||||
|
do_compact=cfg.get("do_compact", False), # dataset level do_compact
|
||||||
|
)
|
||||||
|
return task
|
||||||
|
|
||||||
|
def sequential_load(self, cfgs):
|
||||||
|
self.cfgs = cfgs
|
||||||
|
for cfg in cfgs:
|
||||||
|
# python3.7 and later preserves insertion order to dictionary
|
||||||
|
logger.info(f"Loading {cfg['path']}")
|
||||||
|
|
||||||
|
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
|
||||||
|
task = SegmentedDataset(
|
||||||
|
cfg,
|
||||||
|
self.tokenizer,
|
||||||
|
self.max_length,
|
||||||
|
transform_func=transform_func,
|
||||||
|
nthreads=self.nthreads,
|
||||||
|
prefetch_slice=self.prefetch_slice,
|
||||||
|
do_compact=cfg.get("do_compact", False), # dataset level do_compact
|
||||||
|
)
|
||||||
|
self.task_dict[task.task_name] = task
|
||||||
|
self.tasks.append(task)
|
||||||
|
self.names.append(task.task_name)
|
||||||
|
self.remain += 1
|
||||||
|
self.weights = None
|
||||||
|
self.update_weights()
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
missing_keys = []
|
||||||
|
for name, task in self.task_dict.items():
|
||||||
|
if name in state_dict:
|
||||||
|
task.load_state_dict(state_dict[name])
|
||||||
|
else:
|
||||||
|
missing_keys.append(name)
|
||||||
|
self.update_weights()
|
||||||
|
return missing_keys
|
||||||
|
|
||||||
|
def save_state_dict(self, path):
|
||||||
|
state_dict = {}
|
||||||
|
for name, task in self.task_dict.items():
|
||||||
|
_state_dict = task.state_dict()
|
||||||
|
if isinstance(_state_dict["used"], BitSet):
|
||||||
|
bitset = _state_dict["used"]
|
||||||
|
_file_name = bitset.save(path)
|
||||||
|
_state_dict["used"] = _file_name
|
||||||
|
state_dict[name] = _state_dict
|
||||||
|
else:
|
||||||
|
state_dict[name] = task.state_dict()
|
||||||
|
torch.save(state_dict, path)
|
||||||
|
logger.info("Dataset state saved")
|
||||||
|
|
||||||
|
def update_states(self, task_ids, indice):
|
||||||
|
is_dict = isinstance(indice, dict)
|
||||||
|
uniq = torch.unique(task_ids)
|
||||||
|
for idx in uniq:
|
||||||
|
idx = idx.item()
|
||||||
|
indexes = indice[idx] if is_dict else indice[task_ids == idx].tolist()
|
||||||
|
self.tasks[idx].update_state(indexes)
|
||||||
|
|
||||||
|
def get_transform_func(self, module_name: str, transform_script_path):
|
||||||
|
if transform_script_path is None:
|
||||||
|
# allow null transform
|
||||||
|
return lambda data, num_incontext, rand: data
|
||||||
|
if "/" in module_name:
|
||||||
|
module_name = "fm9g_live.transforms.{}".format(module_name.split("/")[-1])
|
||||||
|
else:
|
||||||
|
module_name = "fm9g_live.transforms.{}".format(module_name)
|
||||||
|
if transform_script_path not in self.path2transform:
|
||||||
|
# loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
|
||||||
|
# spec = importlib.util.spec_from_loader(loader.name, loader)
|
||||||
|
spec = importlib.util.spec_from_file_location(module_name, transform_script_path)
|
||||||
|
if spec is None:
|
||||||
|
raise RuntimeError("Spec is none! {}".format(module_name))
|
||||||
|
mod = importlib.util.module_from_spec(spec)
|
||||||
|
self.path2transform[transform_script_path] = {
|
||||||
|
"module": mod,
|
||||||
|
"last_mtime": 0,
|
||||||
|
}
|
||||||
|
transform_script_info = self.path2transform[transform_script_path]
|
||||||
|
curr_mtime = float(os.path.getmtime(transform_script_path))
|
||||||
|
if curr_mtime > transform_script_info["last_mtime"]:
|
||||||
|
transform_script_info["last_mtime"] = curr_mtime
|
||||||
|
# load module
|
||||||
|
spec.loader.exec_module(transform_script_info["module"])
|
||||||
|
transform_func = getattr(transform_script_info["module"], "transform", None)
|
||||||
|
if transform_func is None:
|
||||||
|
raise NotImplementedError("Find no transform funcion in script '{}'".format(transform_script_path))
|
||||||
|
return transform_func
|
||||||
|
|
||||||
|
def update_weights(self):
|
||||||
|
task0 = self.tasks[0]
|
||||||
|
if task0.abs_weight is not None: # 这一份config是指定绝对比例的
|
||||||
|
weights = []
|
||||||
|
for task in self.tasks:
|
||||||
|
if task.exhausted:
|
||||||
|
weights.append(0)
|
||||||
|
else:
|
||||||
|
if task.avg_tokens == -1:
|
||||||
|
weights.append(task.abs_weight / self.max_length)
|
||||||
|
else:
|
||||||
|
weights.append(task.abs_weight / task.avg_tokens)
|
||||||
|
weights = np.array(weights)
|
||||||
|
else:
|
||||||
|
weights = np.array([0 if task.exhausted else task.weight for task in self.tasks])
|
||||||
|
if self.weight_by_size:
|
||||||
|
sizes = np.array([task.size() for task in self.tasks], dtype=np.float32)
|
||||||
|
weights *= sizes
|
||||||
|
self.weights = weights / weights.sum()
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for task in self.tasks:
|
||||||
|
task.iterate()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
step = 1
|
||||||
|
while True:
|
||||||
|
if self.remain == 0:
|
||||||
|
print("Rank {}, All task exhaust !!!!".format(bmt.rank()))
|
||||||
|
raise StopIteration
|
||||||
|
if self.vdc_sampling:
|
||||||
|
idx = next(self.vdc_gen)(self.weights)
|
||||||
|
else:
|
||||||
|
idx = np.random.choice(len(self.weights), p=self.weights)
|
||||||
|
|
||||||
|
data = next(self.tasks[idx])
|
||||||
|
if step % self.update_weights_frequency == 0:
|
||||||
|
self.update_weights()
|
||||||
|
if data is None:
|
||||||
|
if self.tasks[idx].allow_repeat:
|
||||||
|
# _runtime_ave = self.tasks[idx].avg_tokens
|
||||||
|
print("Rank {}, dataset {} exhaust, repeat...".format(bmt.rank(), self.tasks[idx].dataset_name))
|
||||||
|
# self.tasks[idx] = SegmentedDataset(
|
||||||
|
# self.tasks[idx].cfg, self.tokenizer, self.max_length, transform_func=self.tasks[idx].transform_func, nthreads=self.nthreads, prefetch_slice=self.prefetch_slice
|
||||||
|
# )
|
||||||
|
# self.tasks[idx].avg_tokens_update(_runtime_ave)
|
||||||
|
self.tasks[idx].reset()
|
||||||
|
else:
|
||||||
|
print("Rank {}, dataset {} exhaust, not repeat.".format(bmt.rank(), self.tasks[idx].dataset_name))
|
||||||
|
self.tasks[idx].exhaust = True
|
||||||
|
self.remain -= 1
|
||||||
|
continue
|
||||||
|
step += 1
|
||||||
|
return dict(
|
||||||
|
task_id=idx,
|
||||||
|
input=data[0],
|
||||||
|
target=data[1],
|
||||||
|
index=data[2],
|
||||||
|
is_long=self.tasks[idx].cfg.get("is_long", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UnpadBatchedMixedDataset(torch.utils.data.IterableDataset):
|
||||||
|
def __init__(self, mixed_dataset, batch_size, max_length, pose_prob=0.0, pose_scaling_factor=1.0, compact=False):
|
||||||
|
self.max_total_length = batch_size * max_length
|
||||||
|
self.batch_size = 1
|
||||||
|
# setting compact=True concats segments orignated from the same input
|
||||||
|
# into a long sequence. the relative order of segments should be preserved
|
||||||
|
# in mixed_dataset, e.g.,
|
||||||
|
# - ok: task1_seg1, task2_seg1, task1_seg2, task1_seg3
|
||||||
|
# - not_ok: task1_seg1, task1_seg3, task2_seg1, task1_seg2
|
||||||
|
self.compact = compact
|
||||||
|
|
||||||
|
self.total_length = 0
|
||||||
|
self.task2seqs = defaultdict(list)
|
||||||
|
self.mixed_dataset = mixed_dataset
|
||||||
|
self._max_length = max_length
|
||||||
|
self._pose_prob = pose_prob
|
||||||
|
self._pose_scaling_factor = pose_scaling_factor
|
||||||
|
if self._pose_prob > 0.0:
|
||||||
|
self._scaled_max_length = int(self.max_total_length * self._pose_scaling_factor)
|
||||||
|
else:
|
||||||
|
self._scaled_max_length = max_length
|
||||||
|
|
||||||
|
def put(self, sample):
|
||||||
|
self.total_length += len(sample["target"])
|
||||||
|
task_id = sample["task_id"]
|
||||||
|
if self.compact and self.task2seqs[task_id]:
|
||||||
|
last = self.task2seqs[task_id][-1]
|
||||||
|
if last["target"][-1] != self.mixed_dataset.eos_token_id:
|
||||||
|
# concatenate sequantial segments for longer context modeling: why not?
|
||||||
|
last["input"].extend(sample["input"])
|
||||||
|
last["target"].extend(sample["target"])
|
||||||
|
return
|
||||||
|
self.task2seqs[task_id].append(sample)
|
||||||
|
|
||||||
|
def _pose_preprocess(
|
||||||
|
self,
|
||||||
|
input_ids: NDArray[np.int32],
|
||||||
|
) -> NDArray[np.int32]:
|
||||||
|
"""[PoSE](https://arxiv.org/abs/2309.10400v2)
|
||||||
|
GitHub implementation: https://github.com/dwzhu-pku/PoSE/blob/master/src/train_pose.py#L156
|
||||||
|
"""
|
||||||
|
len_chunk = min(len(input_ids), self._max_length)
|
||||||
|
len_input = len(input_ids)
|
||||||
|
# Chunk input randomly to fit max_length if needed
|
||||||
|
lt1 = 0
|
||||||
|
rt1 = random.randint(0, (len_chunk + 1) // 2) # Fist chunk only contains 1/2 tokens at most
|
||||||
|
rt2 = random.randint(lt1 + len_chunk, len_input) # Second chunk can randomly shift when not filled max_length
|
||||||
|
lt2 = rt2 - (len_chunk - (rt1 - lt1)) # assure all tokens are used
|
||||||
|
chunked_input_ids = np.concatenate([input_ids[lt1:rt1], input_ids[lt2:rt2]], axis=-1)
|
||||||
|
# Generate PoSE position ids
|
||||||
|
position_ids = np.arange(len(chunked_input_ids), dtype=np.int32)
|
||||||
|
len_position_ids = len(position_ids)
|
||||||
|
lt = 0
|
||||||
|
rt = random.randint(lt, self._scaled_max_length - len_position_ids)
|
||||||
|
position_ids[: rt1 - lt1] += lt
|
||||||
|
position_ids[rt1 - lt1 :] += rt
|
||||||
|
return position_ids
|
||||||
|
|
||||||
|
def pop(self):
|
||||||
|
indexes = defaultdict(list)
|
||||||
|
lengths = []
|
||||||
|
|
||||||
|
inputs = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
|
||||||
|
targets = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=IGNORE_TGT)
|
||||||
|
task_ids = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=-1)
|
||||||
|
position_ids = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
|
||||||
|
|
||||||
|
span_begin = 0
|
||||||
|
for samples in self.task2seqs.values():
|
||||||
|
while samples:
|
||||||
|
sample = samples.pop()
|
||||||
|
span_end = span_begin + len(sample["input"])
|
||||||
|
inputs[0, span_begin:span_end] = torch.tensor(sample["input"], dtype=torch.int32)
|
||||||
|
targets[0, span_begin:span_end] = torch.tensor(sample["target"], dtype=torch.int32)
|
||||||
|
task_ids[0, span_begin:span_end] = torch.tensor(sample["task_id"], dtype=torch.int32)
|
||||||
|
if not sample["is_long"] and self._pose_prob > 0.0 and random.uniform(0, 1) < self._pose_prob:
|
||||||
|
_span_position_ids = self._pose_preprocess(sample["input"])
|
||||||
|
else:
|
||||||
|
_span_position_ids = np.arange(len(sample["input"]), dtype=np.int32)
|
||||||
|
position_ids[0, span_begin:span_end] = torch.from_numpy(_span_position_ids)
|
||||||
|
# position_ids[0, span_begin:span_end] = torch.arange(len(sample["input"]), dtype=torch.int32)
|
||||||
|
lengths.append(len(sample["target"]))
|
||||||
|
indexes[int(sample["task_id"])].append(sample["index"])
|
||||||
|
self.total_length -= len(sample["target"])
|
||||||
|
span_begin = span_end
|
||||||
|
|
||||||
|
cu_seqlens = torch.cat(
|
||||||
|
[torch.tensor([0] + lengths).cumsum(dim=-1), torch.tensor([self.max_total_length], dtype=torch.int32)],
|
||||||
|
dim=0,
|
||||||
|
).int()
|
||||||
|
batch = {
|
||||||
|
"inputs": inputs,
|
||||||
|
"targets": targets,
|
||||||
|
"task_ids": task_ids,
|
||||||
|
"indexes": indexes,
|
||||||
|
# adhere to flash attention interface
|
||||||
|
"cu_seqlens": cu_seqlens,
|
||||||
|
"max_seqlen": int(torch.max(cu_seqlens[1:] - cu_seqlens[:-1])),
|
||||||
|
"lengths": torch.tensor(sum(lengths)).int(),
|
||||||
|
"task_names": self.mixed_dataset.names,
|
||||||
|
"position_ids": position_ids,
|
||||||
|
}
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def will_be_full(self, sample):
|
||||||
|
return self.total_length + len(sample["target"]) > self.max_total_length
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for sample in self.mixed_dataset:
|
||||||
|
if self.will_be_full(sample):
|
||||||
|
yield self.pop()
|
||||||
|
self.put(sample)
|
||||||
|
|
||||||
|
|
||||||
|
class CudaPrefetcher(Iterable):
|
||||||
|
"""
|
||||||
|
Wrap around a batch iterator for asynchornously copying data to gpu to shield memcpy latency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, loader, tp_size=1, tp_rank=0):
|
||||||
|
self.loader = iter(loader)
|
||||||
|
self.tp_size = tp_size
|
||||||
|
self.tp_rank = tp_rank
|
||||||
|
self.stream = torch.cuda.Stream()
|
||||||
|
self.preload()
|
||||||
|
|
||||||
|
def preload(self):
|
||||||
|
try:
|
||||||
|
if self.tp_size > 1:
|
||||||
|
if self.tp_rank == 0:
|
||||||
|
data = next(self.loader)
|
||||||
|
print("Rank {}, Preload data done.".format(bmt.rank()))
|
||||||
|
d = {}
|
||||||
|
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "wb") as fb:
|
||||||
|
for key in data.keys():
|
||||||
|
if isinstance(data[key], torch.Tensor):
|
||||||
|
np_cur_data = data[key].cpu().numpy()
|
||||||
|
bs = np_cur_data.tobytes()
|
||||||
|
fb.write(bs)
|
||||||
|
d[key] = ["TORCH", str(np_cur_data.dtype), len(bs)] + list(np_cur_data.shape)
|
||||||
|
elif isinstance(data[key], np.ndarray):
|
||||||
|
bs = data[key].tobytes()
|
||||||
|
fb.write(bs)
|
||||||
|
d[key] = ["NUMPY", str(data[key].dtype), len(bs)] + list(data[key].shape)
|
||||||
|
else:
|
||||||
|
d[key] = data[key]
|
||||||
|
try:
|
||||||
|
_ = json.dumps(d)
|
||||||
|
except TypeError:
|
||||||
|
print(d)
|
||||||
|
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "w") as f:
|
||||||
|
json.dump(d, f)
|
||||||
|
bmt.synchronize()
|
||||||
|
if self.tp_rank != 0:
|
||||||
|
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "rb") as fb:
|
||||||
|
bs = fb.read()
|
||||||
|
offset = 0
|
||||||
|
for key in data.keys():
|
||||||
|
if isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "NUMPY":
|
||||||
|
nw_offset = offset + data[key][2]
|
||||||
|
data[key] = np.frombuffer(bs[offset:nw_offset], dtype=data[key][1]).reshape(
|
||||||
|
data[key][3:]
|
||||||
|
)
|
||||||
|
offset = nw_offset
|
||||||
|
elif isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "TORCH":
|
||||||
|
nw_offset = offset + data[key][2]
|
||||||
|
data[key] = torch.from_numpy(
|
||||||
|
np.frombuffer(bs[offset:nw_offset], dtype=data[key][1])
|
||||||
|
.reshape(data[key][3:])
|
||||||
|
.copy()
|
||||||
|
)
|
||||||
|
offset = nw_offset
|
||||||
|
self.data = data
|
||||||
|
else:
|
||||||
|
self.data = next(self.loader)
|
||||||
|
except StopIteration:
|
||||||
|
self.data = None
|
||||||
|
return
|
||||||
|
with torch.cuda.stream(self.stream):
|
||||||
|
for key in self.data.keys():
|
||||||
|
if isinstance(self.data[key], torch.Tensor):
|
||||||
|
self.data[key] = self.data[key].cuda(non_blocking=True)
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
torch.cuda.current_stream().wait_stream(self.stream)
|
||||||
|
data = copy.deepcopy(self.data)
|
||||||
|
self.preload()
|
||||||
|
return data
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
|
@ -12,3 +12,4 @@ from .position_embedding import RotaryEmbedding
|
||||||
from .position_embedding import RotaryEmbeddingESM
|
from .position_embedding import RotaryEmbeddingESM
|
||||||
from .position_embedding import SegmentPositionEmbedding
|
from .position_embedding import SegmentPositionEmbedding
|
||||||
from .transformer import Encoder
|
from .transformer import Encoder
|
||||||
|
#from _attention_pp_sp import OpAttnPipeSP
|
|
@ -0,0 +1,79 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import bmtrain as bmt
|
||||||
|
|
||||||
|
def _linear_backward(grad_output, x, weight, bias):
|
||||||
|
grad_x = grad_weight = grad_bias = None
|
||||||
|
if x.requires_grad:
|
||||||
|
grad_x = grad_output.matmul(weight)
|
||||||
|
if weight.requires_grad:
|
||||||
|
grad_weight = grad_output.reshape(-1,
|
||||||
|
grad_output.shape[-1]).t().matmul(x.reshape(-1, x.shape[-1]))
|
||||||
|
if bias is not None and bias.requires_grad:
|
||||||
|
grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)
|
||||||
|
return grad_x, grad_weight, grad_bias
|
||||||
|
|
||||||
|
class OpAttnPipeSP(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, q_w, k_w, v_w, q_b, w_b, v_b, x, cache_kv, cache_kv_inp, cu_seqlens_q, cu_seqlens_k, max_seqlen):
|
||||||
|
ctx.save_for_backward(x, q_w, k_w, v_w, q_b, w_b, v_b)
|
||||||
|
if cache_kv.numel() = 0:
|
||||||
|
q = F.linear(x, q_w, q_b)
|
||||||
|
k = F.linear(x, k_w, w_b)
|
||||||
|
v = F.linear(x, v_w, v_b)
|
||||||
|
else:
|
||||||
|
q = F.linear(x, q_w, q_b)
|
||||||
|
k = F.linear(x, k_w, w_b)
|
||||||
|
v = F.linear(x, v_w, v_b)
|
||||||
|
k = torch.cat([cache_kv[0], k], dim=1)
|
||||||
|
v = torch.cat([cache_kv[1], v], dim=1)
|
||||||
|
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
|
||||||
|
q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen, max_seqlen, 0, causal=True, window_size=(-1,-1), alibi_slopes=None, deterministic=False, return_attn_probs=False
|
||||||
|
)
|
||||||
|
ctx.save_for_backward(
|
||||||
|
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
|
||||||
|
)
|
||||||
|
ctx.max_seqlen_q = max_seqlen
|
||||||
|
ctx.max_seqlen_k = max_seqlen
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
return F.linear(x, weight, bias)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
|
||||||
|
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
|
||||||
|
_flash_attn_varlen_backward(
|
||||||
|
dout,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
softmax_lse,
|
||||||
|
dq,
|
||||||
|
dk,
|
||||||
|
dv,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
ctx.max_seqlen_q,
|
||||||
|
ctx.max_seqlen_k,
|
||||||
|
ctx.dropout_p,
|
||||||
|
ctx.softmax_scale,
|
||||||
|
False,
|
||||||
|
(-1,-1),
|
||||||
|
None,
|
||||||
|
False,
|
||||||
|
rng_state=rng_state,
|
||||||
|
)
|
||||||
|
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
|
||||||
|
dk = dk[..., : dout.shape[-1]]
|
||||||
|
dv = dv[..., : dout.shape[-1]]
|
||||||
|
|
||||||
|
d_xq, d_wq, d_bq = _linear_backward(dq, x, q_w, q_b)
|
||||||
|
d_xq, d_wq, d_bq = _linear_backward(dq, x, q_w, q_b)
|
||||||
|
d_xk, d_wk, d_bk = _linear_backward(dk, x, k_w, k_b)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None
|
|
@ -13,13 +13,13 @@ from einops import rearrange
|
||||||
from .linear import ColumnParallelLinear
|
from .linear import ColumnParallelLinear
|
||||||
from .linear import Linear
|
from .linear import Linear
|
||||||
from .position_embedding import apply_chatglm_rotary_pos_emb
|
from .position_embedding import apply_chatglm_rotary_pos_emb
|
||||||
|
|
||||||
try:
|
|
||||||
from flash_attn.flash_attn_interface import _flash_attn_varlen_backward
|
|
||||||
from flash_attn.flash_attn_interface import _flash_attn_varlen_forward
|
|
||||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||||
except:
|
#try:
|
||||||
flash_attn_varlen_func = None
|
# from flash_attn.flash_attn_interface import _flash_attn_varlen_backward
|
||||||
|
# from flash_attn.flash_attn_interface import _flash_attn_varlen_forward
|
||||||
|
# from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||||
|
#except:
|
||||||
|
# flash_attn_varlen_func = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from flash_attn.bert_padding import pad_input
|
from flash_attn.bert_padding import pad_input
|
||||||
|
@ -54,6 +54,8 @@ class OpFlash(torch.autograd.Function):
|
||||||
dropout_p,
|
dropout_p,
|
||||||
ctx.softmax_scale,
|
ctx.softmax_scale,
|
||||||
causal=causal,
|
causal=causal,
|
||||||
|
window_size=(-1, -1),
|
||||||
|
alibi_slopes=None,
|
||||||
return_softmax=False,
|
return_softmax=False,
|
||||||
)
|
)
|
||||||
if record:
|
if record:
|
||||||
|
@ -85,6 +87,9 @@ class OpFlash(torch.autograd.Function):
|
||||||
ctx.dropout_p,
|
ctx.dropout_p,
|
||||||
ctx.softmax_scale,
|
ctx.softmax_scale,
|
||||||
ctx.causal,
|
ctx.causal,
|
||||||
|
(-1,-1),
|
||||||
|
None,
|
||||||
|
False,
|
||||||
rng_state=rng_state,
|
rng_state=rng_state,
|
||||||
)
|
)
|
||||||
return None, None, dq, dk, dv, None, None, None, None
|
return None, None, dq, dk, dv, None, None, None, None
|
||||||
|
@ -294,12 +299,28 @@ class Attention(bmt.DistributedModule):
|
||||||
h_q, h_k = position_bias(
|
h_q, h_k = position_bias(
|
||||||
h_q, h_k, -3, cu_seqlens=cu_seqlens, max_length=max_seqlen, position_ids=position_ids
|
h_q, h_k, -3, cu_seqlens=cu_seqlens, max_length=max_seqlen, position_ids=position_ids
|
||||||
)
|
)
|
||||||
# score = flash_attn_varlen_func(
|
print(type(h_q), type(cu_seqlens), type(max_seqlen), type(self.dropout_p))
|
||||||
# h_q, h_k, h_v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, self.dropout_p, causal=True
|
print("h_q: ", h_q)
|
||||||
# )
|
print("cu_seqlens: ", cu_seqlens)
|
||||||
score = OpFlash.apply(
|
print("max_seqlen: ", max_seqlen)
|
||||||
self, not torch.is_grad_enabled(), h_q, h_k, h_v, cu_seqlens, max_seqlen, self.dropout_p, True
|
score = flash_attn_varlen_func(
|
||||||
|
h_q,
|
||||||
|
h_k,
|
||||||
|
h_v,
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen,
|
||||||
|
max_seqlen,
|
||||||
|
self.dropout_p,
|
||||||
|
causal=True,
|
||||||
|
deterministic=True,
|
||||||
)
|
)
|
||||||
|
print(type(h_q), type(cu_seqlens), type(max_seqlen), type(self.dropout_p))
|
||||||
|
# Rongqiao change
|
||||||
|
#score = OpFlash.apply(
|
||||||
|
# self, not torch.is_grad_enabled(), h_q, h_k, h_v, cu_seqlens, max_seqlen, self.dropout_p, True
|
||||||
|
#)
|
||||||
|
|
||||||
|
|
||||||
score = score.view(batch_size, len_q, -1)
|
score = score.view(batch_size, len_q, -1)
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class EMAValue(object):
|
||||||
|
def __init__(self, init_value: Optional[float] = None, decay_factor: float = 0.999) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._value = init_value
|
||||||
|
self._decay_factor = decay_factor
|
||||||
|
|
||||||
|
@property
|
||||||
|
def value(self) -> Optional[float]:
|
||||||
|
return self._value
|
||||||
|
|
||||||
|
def update(self, value: float) -> None:
|
||||||
|
if self._value is None:
|
||||||
|
self._value = value
|
||||||
|
else:
|
||||||
|
self._value = self._decay_factor * self._value + (1 - self._decay_factor) * value
|
||||||
|
|
||||||
|
def update_with_return(self, value: float) -> Optional[float]:
|
||||||
|
self.update(value)
|
||||||
|
return self._value
|
|
@ -0,0 +1 @@
|
||||||
|
from .fm9g import FM9GTokenizer
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue