forked from jiuyuan/CPM-9G-8B
add fm9g 2b and 8b models
This commit is contained in:
parent
03c55e1fee
commit
cfd2fca57c
|
@ -1,16 +0,0 @@
|
|||
{
|
||||
"vocab_size": 119696,
|
||||
"dropout_p": 0.0,
|
||||
"eps": 1e-05,
|
||||
"half": true,
|
||||
"use_flash_attn": true,
|
||||
"flash_attn_mask_shape": "2d",
|
||||
"dim_model": 4096,
|
||||
"dim_ff": 12288,
|
||||
"dim_head": 128,
|
||||
"num_heads": 32,
|
||||
"num_kv_heads": 32,
|
||||
"num_layers": 48,
|
||||
"activate_fn": "silu",
|
||||
"scale": false
|
||||
}
|
|
@ -1,14 +0,0 @@
|
|||
{
|
||||
"vocab_size": 119696,
|
||||
"dropout_p": 0.0,
|
||||
"eps": 1e-05,
|
||||
"half": true,
|
||||
"dim_model": 4096,
|
||||
"dim_ff": 11008,
|
||||
"dim_head": 128,
|
||||
"num_heads": 32,
|
||||
"num_kv_heads": 32,
|
||||
"num_layers": 32,
|
||||
"activate_fn": "silu",
|
||||
"scale": false
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -1,12 +0,0 @@
|
|||
[
|
||||
{
|
||||
"dataset_name": "wikipedia",
|
||||
"task_name": "wikipedia",
|
||||
"weight": 1.0,
|
||||
"path": "path/to/data",
|
||||
"incontext_weight": [
|
||||
1.0
|
||||
],
|
||||
"transforms": "/home/USERNAME/cpm9g/apps/cpm9g/config/datasets/wikipedia/script_cpmc.py"
|
||||
}
|
||||
]
|
|
@ -1,9 +0,0 @@
|
|||
import random
|
||||
|
||||
|
||||
def rand(n: int, r: random.Random):
|
||||
return int(r.random() * n)
|
||||
|
||||
|
||||
def transform(data, num_sample: int, r: random.Random):
|
||||
return {"input": "", "output": data["text"]}
|
|
@ -1,485 +0,0 @@
|
|||
import inspect
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Union
|
||||
|
||||
import bmtrain as bmt
|
||||
import torch
|
||||
|
||||
sys.path.insert(0, "/home/wangshuo1/code/9G-Train")
|
||||
from cpm.arguments import get_args
|
||||
from cpm.cpm9g.models import CPM9G
|
||||
from cpm.cpm9g.models import CPM9GConfig
|
||||
from cpm.cpm9g.tokenizers import CPM9GTokenizer
|
||||
from cpm.cpm9g.training_tasks import MixedDataset
|
||||
from cpm.utils import allgather_objects
|
||||
from cpm.utils import exporter
|
||||
from cpm.utils import logger
|
||||
from cpm.utils import LogManager
|
||||
|
||||
|
||||
def get_tokenizer(args):
|
||||
tokenizer = CPM9GTokenizer(path=args.vocab)
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_model(args):
|
||||
config = CPM9GConfig.from_json_file(args.model_config)
|
||||
config.tp = 1 if args.tp != 1 else 0
|
||||
if args.flash == "none":
|
||||
config.use_flash_attn = False
|
||||
else:
|
||||
config.use_flash_attn = True
|
||||
if args.flash == "1d":
|
||||
config.flash_attn_mask_shape = "1d"
|
||||
else:
|
||||
config.flash_attn_mask_shape = "2d"
|
||||
if args.flash == "triton":
|
||||
config.flash_impl = "triton"
|
||||
elif args.flash == "cuda":
|
||||
config.flash_impl = "cuda"
|
||||
model = CPM9G(config)
|
||||
if args.load is not None:
|
||||
bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
|
||||
bmt.load(model, args.load)
|
||||
else:
|
||||
bmt.print_rank("args.load is None, start to initialize parameters")
|
||||
bmt.init_parameters(model)
|
||||
return model
|
||||
|
||||
|
||||
def get_optimizer(args, model):
|
||||
for name, para in model.named_parameters():
|
||||
# if not ('input_embedding' in name or 'lm_head' in name):
|
||||
# para.requires_grad_(False)
|
||||
bmt.print_rank(name, para.requires_grad)
|
||||
if args.offload:
|
||||
optimizer = bmt.optim.AdamOffloadOptimizer(
|
||||
model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay
|
||||
)
|
||||
else:
|
||||
optimizer = bmt.optim.AdamOptimizer(model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
||||
if args.load is not None and args.load_grad:
|
||||
start = time.time()
|
||||
print(
|
||||
sum([1 if re.search(r"-{}.rank-\d+.opt".format(args.start_step), i) else 0 for i in os.listdir(args.save)])
|
||||
)
|
||||
if (
|
||||
sum([1 if re.search(r"-{}.rank-\d+.opt".format(args.start_step), i) else 0 for i in os.listdir(args.save)])
|
||||
== bmt.world_size()
|
||||
):
|
||||
file_name = os.path.join(
|
||||
args.save,
|
||||
args.save_name + "-{}.rank-{}.opt".format(args.start_step, bmt.rank()),
|
||||
)
|
||||
print(file_name)
|
||||
if os.path.exists(file_name):
|
||||
print("start to load grad ckpt {}".format(file_name))
|
||||
states = torch.load(file_name)
|
||||
optimizer.load_state_dict(states)
|
||||
logger.info("load grad in {:.2f}s".format(time.time() - start))
|
||||
return optimizer
|
||||
|
||||
|
||||
class Cosine(bmt.lr_scheduler.WarmupLRScheduler):
|
||||
r"""
|
||||
After a warmup period during which learning rate increases linearly between 0 and the start_lr,
|
||||
The decay period performs :math:`\text{lr}=\text{start_lr}\times \dfrac{1+\cos \left( \pi \cdot \dfrac{\text{num_iter}-\text{warmup_iter}}{\text{end_iter}-\text{warmup_iter}}\right)}{2}`
|
||||
"""
|
||||
|
||||
def get_lr_warmup(self, num_iter) -> float:
|
||||
return self.start_lr * num_iter / self.warmup_iter
|
||||
|
||||
def get_lr_decay(self, num_iter) -> float:
|
||||
progress = (num_iter - self.warmup_iter) / max(1, (self.end_iter - self.warmup_iter))
|
||||
return max(self.start_lr * 0.1, self.start_lr * (0.1 + 0.45 * (1.0 + math.cos(progress * math.pi))))
|
||||
|
||||
|
||||
def get_learning_rate_scheduler(args, optimizer):
|
||||
if args.lr_decay_iters is None:
|
||||
args.lr_decay_iters = args.train_iters
|
||||
# lr_scheduler = bmt.lr_scheduler.Noam(
|
||||
lr_scheduler = Cosine(
|
||||
optimizer,
|
||||
start_lr=args.lr,
|
||||
warmup_iter=args.warmup_iters,
|
||||
end_iter=args.lr_decay_iters,
|
||||
num_iter=args.start_step,
|
||||
)
|
||||
return lr_scheduler
|
||||
|
||||
|
||||
def setup_model_and_optimizer(args):
|
||||
|
||||
start = time.time()
|
||||
model = get_model(args)
|
||||
logger.info("load model in {:.2f}s".format(time.time() - start))
|
||||
|
||||
start = time.time()
|
||||
tokenizer = get_tokenizer(args)
|
||||
bmt.synchronize()
|
||||
logger.info("load tokenizer in {:.2f}s".format(time.time() - start))
|
||||
|
||||
start = time.time()
|
||||
optimizer = get_optimizer(args, model)
|
||||
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
|
||||
bmt.synchronize()
|
||||
logger.info("load lr_scheduler in {:.2f}s".format(time.time() - start))
|
||||
|
||||
return tokenizer, model, optimizer, lr_scheduler
|
||||
|
||||
|
||||
def initialize():
|
||||
args = get_args(pretrain=True)
|
||||
bmt.init_distributed(seed=args.seed, zero_level=3)
|
||||
if args.save is not None:
|
||||
os.makedirs(args.save, exist_ok=True)
|
||||
if args.load is not None:
|
||||
if args.start_step == 0:
|
||||
args.start_step = (int)(re.search("(\d+).pt", args.load)[1])
|
||||
return args
|
||||
|
||||
|
||||
def see_memory(detail=False):
|
||||
if detail:
|
||||
res = torch.cuda.memory_summary()
|
||||
else:
|
||||
res = (
|
||||
round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024), 2),
|
||||
round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024), 2),
|
||||
)
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
return res
|
||||
|
||||
|
||||
def add_mem_time(info, mem_usage, tim_usage):
|
||||
torch.cuda.synchronize()
|
||||
bmt.synchronize()
|
||||
mem_usage[info] = see_memory()
|
||||
tim_usage[info] = time.time()
|
||||
return mem_usage, tim_usage
|
||||
|
||||
|
||||
class LossSpikeDetector:
|
||||
def __init__(self, log_path: str) -> None:
|
||||
self._last_loss: Dict[str, float] = {}
|
||||
self._last_data: List[Any] = [None]
|
||||
self._log_path = log_path
|
||||
|
||||
def update_data(self, data: Any):
|
||||
self._last_data.append(data)
|
||||
if len(self._last_data) > 2:
|
||||
self._last_data = self._last_data[-2:]
|
||||
|
||||
def update_loss(self, iteration: int, loss_map: Dict[str, float]):
|
||||
loss_spike_result = []
|
||||
for task, loss in loss_map.items():
|
||||
if task in self._last_loss:
|
||||
if loss > self._last_loss[task] * 3:
|
||||
# loss spike!
|
||||
loss_spike_result.append(
|
||||
{
|
||||
"prev": self._last_loss[task],
|
||||
"curr": loss,
|
||||
"task": task,
|
||||
}
|
||||
)
|
||||
self._last_loss[task] = float(loss)
|
||||
if len(loss_spike_result) > 0:
|
||||
self._write_log(iteration, self._last_data[-1], loss_spike_result)
|
||||
|
||||
def _write_log(self, iteration: int, data: Any, result: List[Dict[str, Any]]):
|
||||
while True:
|
||||
try:
|
||||
with open(self._log_path, "a", encoding="utf-8") as fp:
|
||||
fp.write("=" * 20)
|
||||
fp.write("\nloss spike at {}\n".format(iteration))
|
||||
fp.write("{}\n".format(json.dumps(result, indent=4, ensure_ascii=False)))
|
||||
fp.write("data: \n")
|
||||
for d in data:
|
||||
fp.write("{}\n".format(json.dumps(d, indent=4, ensure_ascii=False)))
|
||||
fp.write("\n\n")
|
||||
break
|
||||
except Exception as e:
|
||||
print("cannot output log to the file {}", self._log_path)
|
||||
|
||||
|
||||
def pretrain(
|
||||
args,
|
||||
tokenizer: CPM9GTokenizer,
|
||||
model: CPM9G,
|
||||
optimizer,
|
||||
lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
|
||||
):
|
||||
average_time = bmt.utils.AverageRecorder()
|
||||
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100)
|
||||
optim_manager = bmt.optim.OptimManager(
|
||||
loss_scale=None if args.bf16 else args.loss_scale,
|
||||
loss_scale_steps=args.loss_scale_steps,
|
||||
loss_scale_factor=2,
|
||||
max_loss_scale=args.max_loss_scale,
|
||||
min_loss_scale=args.min_loss_scale,
|
||||
)
|
||||
optim_manager.add_optimizer(optimizer, lr_scheduler)
|
||||
|
||||
start_step = args.start_step
|
||||
lsd = LossSpikeDetector("./log/debug/spile.%d.log" % bmt.rank())
|
||||
|
||||
if args.tensorboard is not None and bmt.rank() == 0:
|
||||
import distutils.version # noqa: F401
|
||||
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
if not os.path.exists(args.tensorboard):
|
||||
os.makedirs(args.tensorboard)
|
||||
writer = SummaryWriter(log_dir=args.tensorboard)
|
||||
|
||||
if args.log_dir is not None and bmt.rank() == 0:
|
||||
log_mgr = LogManager(args.log_dir)
|
||||
|
||||
global_token_pass = 0.0
|
||||
global_world_size = bmt.world_size()
|
||||
dataloader = MixedDataset(args.dataset, args.batch_size, args.max_length, tokenizer, unpad=(args.flash == "cuda"))
|
||||
|
||||
if args.load is not None:
|
||||
dataset_states_path = args.load.replace(".pt", ".data")
|
||||
if os.path.exists(dataset_states_path):
|
||||
start = time.time()
|
||||
bmt.print_rank("start to load data ckpt")
|
||||
dataset_states = torch.load(dataset_states_path)
|
||||
logger.info("load data ckpt in {:.2f}s".format(time.time() - start))
|
||||
|
||||
start = time.time()
|
||||
missing = dataloader.load_state_dict(dataset_states)
|
||||
logger.info("load state dict in {:.2f}s".format(time.time() - start))
|
||||
if len(missing) > 0:
|
||||
bmt.print_rank("Missing keys when loading dataset states: ", missing)
|
||||
else:
|
||||
bmt.print_rank("cannot find data ckpt {}".format(dataset_states_path))
|
||||
|
||||
dataloader.start()
|
||||
bmt.print_rank("finish dataset start")
|
||||
try:
|
||||
total = 0
|
||||
hash = {}
|
||||
for iteration, data in enumerate(dataloader):
|
||||
iteration = iteration + start_step + 1
|
||||
input_ids = torch.from_numpy(data["inputs"]).cuda().to(torch.int32)
|
||||
input_length = torch.from_numpy(data["length"]).cuda().to(torch.int32)
|
||||
targets = torch.from_numpy(data["target"]).cuda().to(torch.int32)
|
||||
task_ids = torch.from_numpy(data["task_ids"]).cuda().to(torch.int32)
|
||||
task_names = data["task_names"]
|
||||
lsd.update_data(data["raw_data"])
|
||||
if args.flash == "cuda":
|
||||
cu_seqlens = torch.from_numpy(data["cu_seqlens"]).cuda().to(torch.int32)
|
||||
max_seqlen = data["max_seqlen"]
|
||||
position_ids = torch.from_numpy(data["position_ids"]).cuda().to(torch.int32)
|
||||
else:
|
||||
input_ids = torch.from_numpy(data["inputs"]).cuda().to(torch.int32)
|
||||
input_context = torch.zeros_like(input_ids).cuda().bool()
|
||||
input_span = torch.from_numpy(data["spans"]).cuda().to(torch.int32)
|
||||
|
||||
# ===========
|
||||
optim_manager.zero_grad()
|
||||
# torch.cuda.empty_cache()
|
||||
mem_usage = {}
|
||||
tim_usage = {}
|
||||
mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)
|
||||
|
||||
# bmt.print_rank(torch.cuda.max_memory_allocated())
|
||||
# ===========
|
||||
if args.flash == "cuda":
|
||||
logits, _ = model(
|
||||
input_ids,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
else:
|
||||
logits, _ = model(
|
||||
input_ids,
|
||||
input_length,
|
||||
input_context,
|
||||
input_span,
|
||||
)
|
||||
mem_usage, tim_usage = add_mem_time("forward_1", mem_usage, tim_usage)
|
||||
loss = loss_func(logits.view(-1, logits.size(-1)), targets.view(-1))
|
||||
global_loss = bmt.sum_loss(loss).item()
|
||||
mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)
|
||||
|
||||
# bmt.print_rank(torch.cuda.max_memory_allocated())
|
||||
# ===========
|
||||
optim_manager.backward(loss)
|
||||
mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)
|
||||
|
||||
# bmt.print_rank(torch.cuda.max_memory_allocated())
|
||||
# ===========
|
||||
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
|
||||
optim_manager.step()
|
||||
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
|
||||
# bmt.print_rank(torch.cuda.max_memory_allocated())
|
||||
|
||||
# ==========
|
||||
iter_time = tim_usage["optim"] - tim_usage["init"]
|
||||
average_time.record(iter_time)
|
||||
|
||||
with torch.no_grad():
|
||||
task_num = len(task_names)
|
||||
targets_tmp = targets.expand(task_num, -1, -1)
|
||||
task = torch.arange(task_num, dtype=torch.int32, device="cuda")[:, None, None]
|
||||
targets_tmp = torch.where(
|
||||
task_ids == task,
|
||||
targets_tmp,
|
||||
torch.scalar_tensor(-100, dtype=torch.int32, device="cuda"),
|
||||
)
|
||||
|
||||
task_loss_map: Dict[str, float] = {}
|
||||
task_loss_tot: Dict[str, float] = {}
|
||||
for i in range(task_num):
|
||||
task_loss_map[task_names[i]] = loss_func(
|
||||
logits.view(-1, logits.size(-1)), targets_tmp[i, :].view(-1)
|
||||
).item()
|
||||
task_loss_tot[task_names[i]] = (targets_tmp[i, :].view(-1) >= 0).sum().float().item()
|
||||
gatherd_task_loss_map: List[Dict[str, float]] = allgather_objects(task_loss_map)
|
||||
gatherd_task_loss_tot: List[Dict[str, float]] = allgather_objects(task_loss_tot)
|
||||
|
||||
global_task_loss_map: Dict[str, Union[List[float], float]] = {}
|
||||
global_task_loss_tot: Dict[str, Union[List[float], float]] = {}
|
||||
|
||||
for idx, local_task_loss_map in enumerate(gatherd_task_loss_map):
|
||||
for task_name, task_loss in local_task_loss_map.items():
|
||||
if task_name not in global_task_loss_map:
|
||||
global_task_loss_map[task_name] = []
|
||||
global_task_loss_map[task_name].append(task_loss)
|
||||
for task_name, task_tot in gatherd_task_loss_tot[idx].items():
|
||||
if task_name not in global_task_loss_tot:
|
||||
global_task_loss_tot[task_name] = []
|
||||
global_task_loss_tot[task_name].append(task_tot)
|
||||
|
||||
task_loss_map = {}
|
||||
for task_name in sorted(list(global_task_loss_map.keys())):
|
||||
avg_loss = 0.0
|
||||
sum_token = sum(global_task_loss_tot[task_name])
|
||||
for loss, token in zip(global_task_loss_map[task_name], global_task_loss_tot[task_name]):
|
||||
avg_loss += loss * token / sum_token
|
||||
task_loss_map[task_name] = avg_loss
|
||||
|
||||
local_total_rate = torch.Tensor([input_length.float().mean() / args.max_length]).cuda()
|
||||
local_total_rate = bmt.sum_loss(local_total_rate).item()
|
||||
global_token_pass += global_world_size * local_total_rate * args.max_length * args.batch_size
|
||||
avg_time = average_time.value
|
||||
lsd.update_loss(iteration, task_loss_map)
|
||||
|
||||
for task_id in data["task_ids"]:
|
||||
for task in task_id:
|
||||
if task != -1:
|
||||
if not data["task_names"][task] in hash:
|
||||
hash[data["task_names"][task]] = 0
|
||||
hash[data["task_names"][task]] += 1.0
|
||||
total += 1.0
|
||||
|
||||
gathered_hash = allgather_objects(hash)
|
||||
sum_total = sum(allgather_objects(total))
|
||||
|
||||
final_hash = defaultdict(int)
|
||||
for local_hash in gathered_hash:
|
||||
for task, num in local_hash.items():
|
||||
final_hash[task] += num
|
||||
|
||||
# for i in final_hash:
|
||||
# bmt.print_rank(i, final_hash[i] / sum_total)
|
||||
# bmt.print_rank("=========================================")
|
||||
|
||||
train_info = {
|
||||
"time": tim_usage["init"],
|
||||
"iteration": iteration,
|
||||
"loss": global_loss,
|
||||
"lr": lr_scheduler.current_lr,
|
||||
"lr_scale": int(optim_manager.loss_scale),
|
||||
"time_usage": tim_usage,
|
||||
"mem_usage": mem_usage,
|
||||
"avg_time": avg_time,
|
||||
"token_max": local_total_rate,
|
||||
"token_pass": global_token_pass,
|
||||
"throughout": args.max_length * args.batch_size * local_total_rate / avg_time,
|
||||
"grad_norm": grad_norm.item(),
|
||||
"mask_max": ((targets >= 0).sum(-1).float().mean() / args.max_length).item(),
|
||||
"num_gpus": global_world_size,
|
||||
"task_loss": task_loss_map,
|
||||
}
|
||||
# bmt.print_rank(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
|
||||
bmt.print_rank(
|
||||
(
|
||||
"| Iter: {:6d} | loss: {:.4f} | lr: {:.4e}, scale: {:10.4f} | avg_time: {:.4f}; cur_time:{:.4f}={:.4f}+{:.4f} |"
|
||||
+ " token/max: {:.4f} | mask/max: {:.4f} | grad_norm: {:.4f} | mem: {:.2f} |"
|
||||
).format(
|
||||
iteration,
|
||||
global_loss,
|
||||
lr_scheduler.current_lr,
|
||||
int(optim_manager.loss_scale),
|
||||
iter_time,
|
||||
tim_usage["optim"] - tim_usage["init"],
|
||||
tim_usage["backward"] - tim_usage["init"],
|
||||
tim_usage["optim"] - tim_usage["backward"],
|
||||
input_length.float().mean() / args.max_length / (args.batch_size if args.flash == "cuda" else 1),
|
||||
(targets >= 0).sum(-1).float().mean()
|
||||
/ args.max_length
|
||||
/ (args.batch_size if args.flash == "cuda" else 1),
|
||||
grad_norm,
|
||||
max(mem_usage["forward"][1], mem_usage["backward"][1]),
|
||||
)
|
||||
)
|
||||
|
||||
bmt.print_rank(
|
||||
"| "
|
||||
+ " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
|
||||
+ " |"
|
||||
)
|
||||
|
||||
if iteration % args.inspect_iters == 0:
|
||||
model_inspect = bmt.inspect.inspect_model(model, "*")
|
||||
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
||||
train_info["model_inspect"] = model_inspect
|
||||
|
||||
if args.log_dir is not None and bmt.rank() == 0:
|
||||
log_mgr.write(**train_info)
|
||||
if args.tensorboard is not None and bmt.rank() == 0:
|
||||
writer.add_scalar("Loss/train", global_loss, iteration)
|
||||
writer.add_scalar("Optimizer/lr", lr_scheduler.current_lr, iteration)
|
||||
writer.add_scalar("Optimizer/scale", optim_manager.loss_scale, iteration)
|
||||
writer.add_scalar("Optimizer/grad_norm", grad_norm.item(), iteration)
|
||||
for task_name, loss in task_loss_map.items():
|
||||
writer.add_scalar("Loss/train/{}".format(task_name), loss, iteration)
|
||||
|
||||
# -------- save file. If need to backup by Klara platform, use export.xx_save --------
|
||||
if args.save is not None and iteration % args.save_iters == 0:
|
||||
exporter.export(model, dataloader, optimizer, iteration, args, final_save=False)
|
||||
|
||||
if iteration >= args.train_iters:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print(f"train loop err: {e}")
|
||||
raise e
|
||||
finally:
|
||||
dataloader.close()
|
||||
exporter.export(model, dataloader, optimizer, -1, args, final_save=False)
|
||||
|
||||
|
||||
def main():
|
||||
args = initialize()
|
||||
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
|
||||
bmt.print_rank("finish loading")
|
||||
pretrain(args, tokenizer, model, optimizer, lr_scheduler)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,60 +0,0 @@
|
|||
#! /bin/bash
|
||||
#SBATCH --nodes=1
|
||||
#SBATCH --ntasks-per-node=8
|
||||
#SBATCH --gres=gpu:8
|
||||
#SBATCH --cpus-per-task=8
|
||||
#SBATCH --mem=512GB
|
||||
|
||||
# use 8 GPU for example, pretrain may need 32 GPU
|
||||
export MASTER_ADDR=`hostname`
|
||||
export MASTER_PORT=12345
|
||||
|
||||
mkdir -p /home/${USERNAME}/logs/debug
|
||||
mkdir -p /home/${USERNAME}/logs/tensorboard/cpm9g/
|
||||
|
||||
cd apps/cpm9g
|
||||
CONFIG_NAME="config/11b"
|
||||
# --------------- 运行参数 ---------------
|
||||
OPTS=""
|
||||
OPTS+=" --model-config ${CONFIG_NAME}/config.json"
|
||||
OPTS+=" --vocab ${CONFIG_NAME}/vocab.txt"
|
||||
OPTS+=" --batch-size 4"
|
||||
OPTS+=" --train-iters 400000"
|
||||
OPTS+=" --save-iters 250"
|
||||
OPTS+=" --save-name cpm9g_checkpoint"
|
||||
OPTS+=" --max-length 4096"
|
||||
OPTS+=" --lr 1.5e-5"
|
||||
OPTS+=" --inspect-iters 100"
|
||||
OPTS+=" --warmup-iters 2000"
|
||||
OPTS+=" --lr-decay-style noam"
|
||||
OPTS+=" --weight-decay 0.1"
|
||||
OPTS+=" --clip-grad 1.0"
|
||||
OPTS+=" --loss-scale 1048576"
|
||||
OPTS+=" --loss-scale-steps 32"
|
||||
OPTS+=" --offload"
|
||||
OPTS+=" --flash cuda"
|
||||
# OPTS+=" --load-grad"
|
||||
|
||||
# --------------- 写文件路径 ---------------
|
||||
## checkpoint
|
||||
OPTS+=" --save /home/${USERNAME}/checkpoints/cpm9g/"
|
||||
OPTS+=" --save-model /home/${USERNAME}/models/cpm9g/"
|
||||
|
||||
## logs,/local/logs 等价于 /data/logs(软链)
|
||||
OPTS+=" --log-dir /home/${USERNAME}/logs/train/"
|
||||
OPTS+=" --tensorboard /home/${USERNAME}/tensorboard/cpm9g/"`date +"%Y%m%d%H%M%S"`
|
||||
|
||||
# --------------- 读文件路径 ---------------
|
||||
OPTS+=" --dataset config/datasets.json"
|
||||
OPTS+=" --load ${CHECKPOINT}"
|
||||
OPTS+=" --start-step 1"
|
||||
|
||||
# --------------- 透传参数 ---------------
|
||||
OPTS+=" $@"
|
||||
|
||||
# --------------- 最终指令 ---------------
|
||||
|
||||
CMD="torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} pretrain_cpm9g.py ${OPTS}"
|
||||
echo "${CMD}"
|
||||
|
||||
$CMD
|
|
@ -1,59 +0,0 @@
|
|||
#! /bin/bash
|
||||
#SBATCH --nodes=1
|
||||
#SBATCH --ntasks-per-node=8
|
||||
#SBATCH --gres=gpu:8
|
||||
#SBATCH --cpus-per-task=8
|
||||
|
||||
# use 8 GPU for example, pretrain may need 32 GPU
|
||||
export MASTER_ADDR=`hostname`
|
||||
export MASTER_PORT=12345
|
||||
|
||||
mkdir -p /home/${USERNAME}/logs/debug
|
||||
mkdir -p /home/${USERNAME}/logs/tensorboard/cpm9g/
|
||||
|
||||
cd apps/cpm9g
|
||||
CONFIG_NAME="config/7b"
|
||||
# --------------- 运行参数 ---------------
|
||||
OPTS=""
|
||||
OPTS+=" --model-config ${CONFIG_NAME}/config.json"
|
||||
OPTS+=" --vocab ${CONFIG_NAME}/vocab.txt"
|
||||
OPTS+=" --batch-size 4"
|
||||
OPTS+=" --train-iters 400000"
|
||||
OPTS+=" --save-iters 250"
|
||||
OPTS+=" --save-name cpm9g_checkpoint"
|
||||
OPTS+=" --max-length 4096"
|
||||
OPTS+=" --lr 1.5e-5"
|
||||
OPTS+=" --inspect-iters 100"
|
||||
OPTS+=" --warmup-iters 2000"
|
||||
OPTS+=" --lr-decay-style noam"
|
||||
OPTS+=" --weight-decay 0.1"
|
||||
OPTS+=" --clip-grad 1.0"
|
||||
OPTS+=" --loss-scale 1048576"
|
||||
OPTS+=" --loss-scale-steps 32"
|
||||
OPTS+=" --offload"
|
||||
OPTS+=" --flash cuda"
|
||||
# OPTS+=" --load-grad"
|
||||
|
||||
# --------------- 写文件路径 ---------------
|
||||
## checkpoint
|
||||
OPTS+=" --save /home/${USERNAME}/checkpoints/cpm9g/"
|
||||
OPTS+=" --save-model /home/${USERNAME}/models/cpm9g/"
|
||||
|
||||
## logs,/local/logs 等价于 /data/logs(软链)
|
||||
OPTS+=" --log-dir /home/${USERNAME}/logs/train/"
|
||||
OPTS+=" --tensorboard /home/${USERNAME}/tensorboard/cpm9g/"`date +"%Y%m%d%H%M%S"`
|
||||
|
||||
# --------------- 读文件路径 ---------------
|
||||
OPTS+=" --dataset config/datasets.json"
|
||||
OPTS+=" --load ${CHECKPOINT}"
|
||||
OPTS+=" --start-step 1"
|
||||
|
||||
# --------------- 透传参数 ---------------
|
||||
OPTS+=" $@"
|
||||
|
||||
# --------------- 最终指令 ---------------
|
||||
|
||||
CMD="torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} pretrain_cpm9g.py ${OPTS}"
|
||||
echo "${CMD}"
|
||||
|
||||
$CMD
|
|
@ -1,484 +0,0 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 The OpenBMB team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Union
|
||||
|
||||
import bmtrain as bmt
|
||||
import torch
|
||||
|
||||
sys.path.insert(0, "/home/wangshuo1/code/9G-Train")
|
||||
from cpm.arguments import get_args
|
||||
from cpm.cpm9g.models import CPM9G
|
||||
from cpm.cpm9g.models import CPM9GConfig
|
||||
from cpm.cpm9g.tokenizers import CPM9GTokenizer
|
||||
from cpm.cpm9g.training_tasks import FinetuneDataset
|
||||
from cpm.utils import allgather_objects
|
||||
from cpm.utils import logger
|
||||
import shutil
|
||||
|
||||
def get_tokenizer(args):
|
||||
tokenizer = CPM9GTokenizer(path=args.vocab)
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_model(args):
|
||||
config = CPM9GConfig.from_json_file(args.model_config)
|
||||
if args.flash == "none":
|
||||
config.use_flash_attn = False
|
||||
else:
|
||||
config.use_flash_attn = True
|
||||
if args.flash == "1d":
|
||||
config.flash_attn_mask_shape = "1d"
|
||||
else:
|
||||
config.flash_attn_mask_shape = "2d"
|
||||
if args.flash == "triton":
|
||||
config.flash_impl = "triton"
|
||||
elif args.flash == "cuda":
|
||||
config.flash_impl = "cuda"
|
||||
model = CPM9G(config)
|
||||
if args.load is not None:
|
||||
bmt.init_parameters(model)
|
||||
bmt.synchronize()
|
||||
bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
|
||||
bmt.load(model, args.load, strict=False)
|
||||
|
||||
model_inspect = bmt.inspect.inspect_model(model, "*")
|
||||
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
||||
else:
|
||||
bmt.print_rank("args.load is None, start to initialize parameters")
|
||||
bmt.init_parameters(model)
|
||||
return model
|
||||
|
||||
|
||||
def get_optimizer(args, model):
|
||||
if args.offload:
|
||||
optimizer = bmt.optim.AdamOffloadOptimizer(
|
||||
model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay
|
||||
)
|
||||
else:
|
||||
optimizer = bmt.optim.AdamOptimizer(model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
||||
if args.load is not None and args.load_grad:
|
||||
start = time.time()
|
||||
print(
|
||||
sum(
|
||||
[
|
||||
1
|
||||
if i.find(".opt") != -1 and i.find("-{}.rank".format(args.start_step % (args.save_iters * 5))) != -1
|
||||
else 0
|
||||
for i in os.listdir(args.save)
|
||||
]
|
||||
)
|
||||
)
|
||||
if (
|
||||
sum(
|
||||
[
|
||||
1
|
||||
if i.find(".opt") != -1 and i.find("-{}.rank".format(args.start_step % (args.save_iters * 5))) != -1
|
||||
else 0
|
||||
for i in os.listdir(args.save)
|
||||
]
|
||||
)
|
||||
== bmt.world_size()
|
||||
):
|
||||
file_name = os.path.join(
|
||||
args.save,
|
||||
args.save_name + "-{}.rank-{}.opt".format(args.start_step % (args.save_iters * 5), bmt.rank()),
|
||||
)
|
||||
print(file_name)
|
||||
if os.path.exists(file_name):
|
||||
print("start to load grad ckpt {}".format(file_name))
|
||||
states = torch.load(file_name)
|
||||
optimizer.load_state_dict(states)
|
||||
logger.info("load grad in {:.2f}s".format(time.time() - start))
|
||||
return optimizer
|
||||
|
||||
|
||||
class Cosine(bmt.lr_scheduler.WarmupLRScheduler):
|
||||
r"""
|
||||
After a warmup period during which learning rate increases linearly between 0 and the start_lr,
|
||||
The decay period performs :math:`\text{lr}=\text{start_lr}\times \dfrac{1+\cos \left( \pi \cdot \dfrac{\text{num_iter}-\text{warmup_iter}}{\text{end_iter}-\text{warmup_iter}}\right)}{2}`
|
||||
"""
|
||||
|
||||
def get_lr_warmup(self, num_iter) -> float:
|
||||
return self.start_lr * num_iter / self.warmup_iter
|
||||
|
||||
def get_lr_decay(self, num_iter) -> float:
|
||||
progress = (num_iter - self.warmup_iter) / max(1, (self.end_iter - self.warmup_iter))
|
||||
return max(self.start_lr * 0.1, self.start_lr * (0.1 + 0.45 * (1.0 + math.cos(progress * math.pi))))
|
||||
|
||||
|
||||
def get_learning_rate_scheduler(args, optimizer):
|
||||
if args.lr_decay_iters is None:
|
||||
args.lr_decay_iters = args.train_iters
|
||||
# lr_scheduler = bmt.lr_scheduler.Noam(
|
||||
lr_scheduler = Cosine(
|
||||
optimizer,
|
||||
start_lr=args.lr,
|
||||
warmup_iter=args.warmup_iters,
|
||||
end_iter=args.lr_decay_iters,
|
||||
num_iter=args.start_step,
|
||||
)
|
||||
return lr_scheduler
|
||||
|
||||
|
||||
def setup_model_and_optimizer(args):
|
||||
start = time.time()
|
||||
model = get_model(args)
|
||||
logger.info("load model in {:.2f}s".format(time.time() - start))
|
||||
|
||||
start = time.time()
|
||||
tokenizer = get_tokenizer(args)
|
||||
bmt.synchronize()
|
||||
logger.info("load tokenizer in {:.2f}s".format(time.time() - start))
|
||||
|
||||
start = time.time()
|
||||
optimizer = get_optimizer(args, model)
|
||||
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
|
||||
bmt.synchronize()
|
||||
logger.info("load lr_scheduler in {:.2f}s".format(time.time() - start))
|
||||
|
||||
return tokenizer, model, optimizer, lr_scheduler
|
||||
|
||||
|
||||
def initialize():
|
||||
args = get_args(finetune=True)
|
||||
|
||||
# hack
|
||||
if "checkpointing" in inspect.signature(bmt.init_distributed).parameters:
|
||||
bmt.init_distributed(checkpointing=False, seed=args.seed)
|
||||
else:
|
||||
bmt.init_distributed(seed=args.seed)
|
||||
if args.save is not None:
|
||||
os.makedirs(args.save, exist_ok=True)
|
||||
# if args.load is not None:
|
||||
# if args.start_step == 0:
|
||||
# args.start_step = (int)(re.search("(\d+).pt", args.load)[1])
|
||||
return args
|
||||
|
||||
|
||||
def see_memory(detail=False):
|
||||
if detail:
|
||||
res = torch.cuda.memory_summary()
|
||||
else:
|
||||
res = (
|
||||
round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024), 2),
|
||||
round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2),
|
||||
)
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
return res
|
||||
|
||||
|
||||
def add_mem_time(info, mem_usage, tim_usage):
|
||||
torch.cuda.synchronize()
|
||||
mem_usage[info] = see_memory()
|
||||
tim_usage[info] = time.time()
|
||||
return mem_usage, tim_usage
|
||||
|
||||
|
||||
class LossSpikeDetector:
|
||||
def __init__(self, log_path: str) -> None:
|
||||
self._last_loss: Dict[str, float] = {}
|
||||
self._last_data: List[Any] = [None]
|
||||
self._log_path = log_path
|
||||
|
||||
def update_data(self, data: Any):
|
||||
self._last_data.append(data)
|
||||
if len(self._last_data) > 2:
|
||||
self._last_data = self._last_data[-2:]
|
||||
|
||||
def update_loss(self, iteration: int, loss_map: Dict[str, float]):
|
||||
loss_spike_result = []
|
||||
for task, loss in loss_map.items():
|
||||
if task in self._last_loss:
|
||||
if loss > self._last_loss[task] * 3:
|
||||
# loss spike!
|
||||
loss_spike_result.append(
|
||||
{
|
||||
"prev": self._last_loss[task],
|
||||
"curr": loss,
|
||||
"task": task,
|
||||
}
|
||||
)
|
||||
self._last_loss[task] = float(loss)
|
||||
if len(loss_spike_result) > 0:
|
||||
self._write_log(iteration, self._last_data[-1], loss_spike_result)
|
||||
|
||||
def _write_log(self, iteration: int, data: Any, result: List[Dict[str, Any]]):
|
||||
return
|
||||
with open(self._log_path, "a", encoding="utf-8") as fp:
|
||||
fp.write("=" * 20)
|
||||
fp.write("\nloss spike at {}\n".format(iteration))
|
||||
fp.write("{}\n".format(json.dumps(result, indent=4, ensure_ascii=False)))
|
||||
fp.write("data: \n")
|
||||
for d in data:
|
||||
fp.write("{}\n".format(json.dumps(d, indent=4, ensure_ascii=False)))
|
||||
fp.write("\n\n")
|
||||
|
||||
|
||||
def finetune(
|
||||
args,
|
||||
bin_file: str,
|
||||
tokenizer: CPM9GTokenizer,
|
||||
model: CPM9G,
|
||||
optimizer,
|
||||
lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
|
||||
):
|
||||
average_time = bmt.utils.AverageRecorder()
|
||||
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100)
|
||||
optim_manager = bmt.optim.OptimManager(loss_scale=args.loss_scale, loss_scale_steps=args.loss_scale_steps, loss_scale_factor=2, max_loss_scale=args.max_loss_scale, min_loss_scale=args.min_loss_scale,)
|
||||
optim_manager.add_optimizer(optimizer, lr_scheduler)
|
||||
|
||||
if args.tensorboard is not None and bmt.rank() == 0:
|
||||
import distutils.version # noqa: F401
|
||||
from tensorboardX import SummaryWriter
|
||||
if not os.path.exists(args.tensorboard):
|
||||
os.makedirs(args.tensorboard)
|
||||
writer = SummaryWriter(log_dir=args.tensorboard)
|
||||
|
||||
global_token_pass = 0.0
|
||||
global_world_size = bmt.world_size()
|
||||
|
||||
for epoch in range(args.epoch):
|
||||
epoch = epoch + 1
|
||||
last_data = None
|
||||
dataloader = FinetuneDataset(
|
||||
bin_file, args.batch_size, args.max_length, tokenizer, unpad=(args.flash == "cuda"), task_name="task", drop_last=True
|
||||
)
|
||||
optim_manager.zero_grad()
|
||||
for iteration, data in enumerate(dataloader):
|
||||
iteration = iteration + 1
|
||||
skip_this_batch = False
|
||||
if data is None:
|
||||
if last_data is None:
|
||||
raise RuntimeError(
|
||||
"Dataset is too small, please use a smaller batch size or sequence length!"
|
||||
)
|
||||
data = last_data # use last data
|
||||
skip_this_batch = True
|
||||
else:
|
||||
last_data = data
|
||||
|
||||
assert data["inputs"].shape[0] == args.batch_size
|
||||
input_ids = torch.from_numpy(data["inputs"]).cuda().to(torch.int32)
|
||||
input_length = torch.from_numpy(data["length"]).cuda().to(torch.int32)
|
||||
targets = torch.from_numpy(data["target"]).cuda().long()
|
||||
# bmt.print_rank(input_ids[0].tolist())
|
||||
# bmt.print_rank(targets[0].tolist())
|
||||
# bmt.print_rank(data["spans"].tolist())
|
||||
# bmt.print_rank(tokenizer.decode(input_ids[0]))
|
||||
# bmt.print_rank(tokenizer.path)
|
||||
# bmt.synchronize()
|
||||
# exit()
|
||||
task_ids = torch.from_numpy(data["task_ids"]).cuda().to(torch.int32)
|
||||
task_names = data["task_names"]
|
||||
if args.flash == "cuda":
|
||||
cu_seqlens = torch.from_numpy(data["cu_seqlens"]).cuda().to(torch.int32)
|
||||
max_seqlen = data["max_seqlen"]
|
||||
position_ids = torch.from_numpy(data["position_ids"]).cuda().to(torch.int32)
|
||||
else:
|
||||
input_context = torch.zeros_like(input_ids).cuda().bool()
|
||||
input_span = torch.from_numpy(data["spans"]).cuda().to(torch.int32)
|
||||
|
||||
# ===========
|
||||
# optim_manager.zero_grad()
|
||||
# torch.cuda.empty_cache()
|
||||
mem_usage = {}
|
||||
tim_usage = {}
|
||||
mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)
|
||||
|
||||
# ===========
|
||||
if args.flash == "cuda":
|
||||
logits, _ = model(
|
||||
input_ids,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
else:
|
||||
logits, _ = model(
|
||||
input_ids,
|
||||
input_length,
|
||||
input_context,
|
||||
input_span,
|
||||
)
|
||||
mem_usage, tim_usage = add_mem_time("forward_1", mem_usage, tim_usage)
|
||||
loss = loss_func(logits.view(-1, logits.size(-1)), targets.view(-1))
|
||||
if skip_this_batch:
|
||||
loss = loss * 0
|
||||
mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)
|
||||
|
||||
# ===========
|
||||
optim_manager.backward(loss)
|
||||
mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)
|
||||
|
||||
# ===========
|
||||
if iteration % args.gradient_accumulation_steps == 0:
|
||||
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
|
||||
optim_manager.step()
|
||||
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
|
||||
optim_manager.zero_grad()
|
||||
else:
|
||||
grad_norm = None
|
||||
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
|
||||
|
||||
# ==========
|
||||
iter_time = tim_usage["optim"] - tim_usage["init"]
|
||||
average_time.record(iter_time)
|
||||
|
||||
with torch.no_grad():
|
||||
task_num = len(task_names)
|
||||
targets_tmp = targets.expand(task_num, -1, -1)
|
||||
task = torch.arange(task_num, dtype=torch.long, device="cuda")[:, None, None]
|
||||
targets_tmp = torch.where(
|
||||
task_ids == task,
|
||||
targets_tmp,
|
||||
torch.scalar_tensor(-100, dtype=torch.long, device="cuda"),
|
||||
)
|
||||
|
||||
task_loss_map: Dict[str, float] = {}
|
||||
if not skip_this_batch:
|
||||
for i in range(task_num):
|
||||
task_loss = loss_func(logits.view(-1, logits.size(-1)), targets_tmp[i, :].view(-1))
|
||||
task_loss_map[task_names[i]] = task_loss.item()
|
||||
gatherd_task_loss_map: List[Dict[str, float]] = allgather_objects(task_loss_map)
|
||||
|
||||
global_task_loss_map: Dict[str, Union[List[float], float]] = {}
|
||||
for local_task_loss_map in gatherd_task_loss_map:
|
||||
for task_name, task_loss in local_task_loss_map.items():
|
||||
if task_name not in global_task_loss_map:
|
||||
global_task_loss_map[task_name] = []
|
||||
global_task_loss_map[task_name].append(task_loss)
|
||||
|
||||
task_loss_map = {}
|
||||
for task_name in sorted(list(global_task_loss_map.keys())):
|
||||
avg_loss = sum(global_task_loss_map[task_name]) / len(global_task_loss_map[task_name])
|
||||
task_loss_map[task_name] = avg_loss
|
||||
|
||||
local_total_rate = torch.Tensor([input_length.float().mean() / args.max_length]).cuda()
|
||||
local_total_rate = bmt.sum_loss(local_total_rate).item()
|
||||
global_token_pass += global_world_size * local_total_rate * args.max_length * args.batch_size
|
||||
avg_time = average_time.value
|
||||
|
||||
train_info = {
|
||||
"time": tim_usage["init"],
|
||||
"epoch": epoch,
|
||||
"iteration": iteration,
|
||||
"loss": task_loss_map[args.task_name],
|
||||
"lr": lr_scheduler.current_lr,
|
||||
"lr_scale": int(optim_manager.loss_scale),
|
||||
"time_usage": tim_usage,
|
||||
"mem_usage": mem_usage,
|
||||
"avg_time": avg_time,
|
||||
"token_max": local_total_rate,
|
||||
"token_pass": global_token_pass,
|
||||
"throughout": args.max_length * args.batch_size * local_total_rate / avg_time,
|
||||
"grad_norm": 0 if grad_norm == None else grad_norm.item(),
|
||||
"mask_max": ((targets >= 0).sum(-1).float().mean() / args.max_length).item(),
|
||||
"num_gpus": global_world_size,
|
||||
"task_loss": task_loss_map,
|
||||
}
|
||||
# bmt.print_rank(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
|
||||
bmt.print_rank(
|
||||
(
|
||||
"| Epoch: {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f} | lr: {:.6e} | scale: {:10.0f} | time: {:.1f} |"
|
||||
+ " token/max: {:.3f} | mask/max: {:.3f} | grad_norm: {:.3f}"
|
||||
).format(
|
||||
epoch,
|
||||
iteration,
|
||||
args.train_iters,
|
||||
task_loss_map[args.task_name],
|
||||
lr_scheduler.current_lr,
|
||||
int(optim_manager.loss_scale),
|
||||
avg_time,
|
||||
input_length.float().mean() / args.max_length,
|
||||
(targets >= 0).sum(-1).float().mean() / args.max_length,
|
||||
0 if grad_norm == None else grad_norm,
|
||||
)
|
||||
)
|
||||
|
||||
bmt.print_rank(
|
||||
"| "
|
||||
+ " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
|
||||
+ " |"
|
||||
)
|
||||
if iteration % args.inspect_iters == 0:
|
||||
model_inspect = bmt.inspect.inspect_model(model, "*")
|
||||
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
||||
train_info["model_inspect"] = model_inspect
|
||||
|
||||
# save_folder_name = f"{args.save}{epoch}{iteration}"
|
||||
# model_fname = os.path.join(save_folder_name, f"{args.save_name}-iter-{iteration}.pt")
|
||||
# os.makedirs(os.path.dirname(model_fname), exist_ok=True)
|
||||
# if bmt.rank() == 0:
|
||||
# shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
|
||||
# shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))
|
||||
|
||||
if args.tensorboard is not None and bmt.rank() == 0:
|
||||
writer.add_scalar(f"Loss/train/{epoch}", task_loss_map[args.task_name], iteration)
|
||||
writer.add_scalar(f"Optimizer/lr/{epoch}", lr_scheduler.current_lr, iteration)
|
||||
writer.add_scalar(f"Optimizer/scale/{epoch}", optim_manager.loss_scale, iteration)
|
||||
writer.add_scalar(f"Optimizer/grad_norm/{epoch}", 0 if grad_norm == None else grad_norm.item(), iteration)
|
||||
|
||||
# if iteration % 10 == 0:
|
||||
# save_folder_name = f"{args.save}{epoch}"
|
||||
# model_fname = os.path.join(save_folder_name, f"{args.save_name}-epoch-{epoch}.pt")
|
||||
# os.makedirs(os.path.dirname(model_fname), exist_ok=True)
|
||||
# bmt.save(model, model_fname)
|
||||
# if bmt.rank() == 0:
|
||||
# shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
|
||||
# shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))
|
||||
|
||||
save_folder_name = f"{args.save}-epoch-{epoch}"
|
||||
model_fname = os.path.join(save_folder_name, f"{args.save_name}-epoch-{epoch}.pt")
|
||||
os.makedirs(os.path.dirname(model_fname), exist_ok=True)
|
||||
bmt.save(model, model_fname)
|
||||
if bmt.rank() == 0:
|
||||
shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
|
||||
shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))
|
||||
|
||||
|
||||
def main():
|
||||
args = initialize()
|
||||
# To Be Specified
|
||||
bin_file = args.dataset
|
||||
bmt.print_rank(f"dataset: {bin_file}")
|
||||
|
||||
tokenizer = get_tokenizer(args)
|
||||
dataloader = FinetuneDataset(
|
||||
bin_file, args.batch_size, args.max_length, tokenizer, unpad=(args.flash == "cuda"), task_name="task", drop_last=True
|
||||
)
|
||||
bmt.print_rank(f"#batch: {len(dataloader)}")
|
||||
total_steps = len(dataloader) * args.epoch
|
||||
|
||||
setattr(args, 'train_iters', int(total_steps))
|
||||
setattr(args, 'warmup_iters', max(int(total_steps*0.02), args.warmup_iters))
|
||||
bmt.print_rank(json.dumps(vars(args), indent=2, sort_keys=True))
|
||||
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
|
||||
|
||||
bmt.print_rank("finish loading")
|
||||
finetune(args, bin_file, tokenizer, model, optimizer, lr_scheduler)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,55 +0,0 @@
|
|||
#! /bin/bash
|
||||
#SBATCH --partition=gpu3-1
|
||||
#SBATCH --nodes=1
|
||||
#SBATCH --ntasks-per-node=8
|
||||
#SBATCH --gres=gpu:8
|
||||
#SBATCH --cpus-per-task=8
|
||||
|
||||
export MASTER_ADDR=g3002
|
||||
export MASTER_PORT=12345
|
||||
|
||||
CPM_PATH="/data/groups/QY_LLM_Core/projects/202311-release/Models/11B-Chat/9G-Train"
|
||||
CONFIG_NAME="${CPM_PATH}/apps/cpm9g/config/11b"
|
||||
EXP_PATH=./exp
|
||||
mkdir -p $EXP_PATH
|
||||
MODEL_NAME="cpm9g-11b-sft"
|
||||
|
||||
OPTS=""
|
||||
OPTS+=" --model-config ${CONFIG_NAME}/config.json"
|
||||
OPTS+=" --vocab ${CONFIG_NAME}/vocab.txt"
|
||||
|
||||
OPTS+=" --train-iters 10000"
|
||||
OPTS+=" --inspect-iters 200"
|
||||
OPTS+=" --warmup-iters 500"
|
||||
|
||||
OPTS+=" --lr-decay-style cosine"
|
||||
OPTS+=" --weight-decay 0.1"
|
||||
OPTS+=" --clip-grad 1.0"
|
||||
OPTS+=" --loss-scale 1048576"
|
||||
OPTS+=" --max-loss-scale 33554432"
|
||||
OPTS+=" --min-loss-scale 1"
|
||||
OPTS+=" --loss-scale-steps 32"
|
||||
|
||||
OPTS+=" --offload"
|
||||
OPTS+=" --batch-size 4"
|
||||
OPTS+=" --max-length 4096"
|
||||
OPTS+=" --lr 2e-5"
|
||||
OPTS+=" --start-step 0"
|
||||
OPTS+=" --epoch 8"
|
||||
OPTS+=" --load /data/groups/QY_LLM_Core/models/20231010/11b-base/11b.pt"
|
||||
OPTS+=" --dataset /data/groups/QY_LLM_Core/datasets/sft/20231025/merge_qy_sft_bin"
|
||||
# TODO 这些 /data 在启元机器上需要改成 /home 下的路径
|
||||
OPTS+=" --save ${EXP_PATH}/checkpoints"
|
||||
OPTS+=" --save-name ${MODEL_NAME}"
|
||||
# OPTS+=" --tensorboard /data/logs/tensorboard/${MODEL_NAME}/${CUR_DATE}/"
|
||||
# OPTS+=" --flash triton"
|
||||
# OPTS+=" --flash cuda"
|
||||
# OPTS+=" --load-grad"
|
||||
|
||||
OPTS+=" $@"
|
||||
|
||||
|
||||
CMD="torchrun --nnodes=4 --nproc_per_node=8 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} ${CPM_PATH}/apps/cpm9g/sft_cpm9g.py ${OPTS}"
|
||||
|
||||
echo "${CMD}"
|
||||
$CMD
|
|
@ -1,515 +0,0 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 The OpenBMB team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Union
|
||||
|
||||
import bmtrain as bmt
|
||||
import torch
|
||||
|
||||
sys.path.insert(0, "/home/wangshuo1/code/9G-Train")
|
||||
from cpm.arguments import get_args
|
||||
from cpm.cpm9g.models import CPM9G
|
||||
from cpm.cpm9g.models import CPM9GConfig
|
||||
from cpm.cpm9g.tokenizers import CPM9GTokenizer
|
||||
from cpm.cpm9g.training_tasks import FinetuneDataset
|
||||
from cpm.utils import allgather_objects
|
||||
from cpm.utils import logger
|
||||
import shutil
|
||||
|
||||
import opendelta as od
|
||||
from opendelta import LoraModel, AdapterModel, CompacterModel, LowRankAdapterModel, BitFitModel, ParallelAdapterModel
|
||||
from opendelta.utils.inspect import inspect_optimizer_statistics
|
||||
from bigmodelvis import Visualization
|
||||
|
||||
|
||||
def get_tokenizer(args):
|
||||
tokenizer = CPM9GTokenizer(path=args.vocab)
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_model(args):
|
||||
config = CPM9GConfig.from_json_file(args.model_config)
|
||||
if args.flash == "none":
|
||||
config.use_flash_attn = False
|
||||
else:
|
||||
config.use_flash_attn = True
|
||||
if args.flash == "1d":
|
||||
config.flash_attn_mask_shape = "1d"
|
||||
else:
|
||||
config.flash_attn_mask_shape = "2d"
|
||||
if args.flash == "triton":
|
||||
config.flash_impl = "triton"
|
||||
elif args.flash == "cuda":
|
||||
config.flash_impl = "cuda"
|
||||
model = CPM9G(config)
|
||||
if args.load is not None:
|
||||
bmt.init_parameters(model)
|
||||
bmt.synchronize()
|
||||
bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
|
||||
bmt.load(model, args.load, strict=False)
|
||||
|
||||
model_inspect = bmt.inspect.inspect_model(model, "*")
|
||||
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
||||
else:
|
||||
bmt.print_rank("args.load is None, start to initialize parameters")
|
||||
bmt.init_parameters(model)
|
||||
|
||||
if args.delta_type != None:
|
||||
from bigmodelvis import Visualization #0813
|
||||
if bmt.rank() == 0:
|
||||
Visualization(model).structure_graph()
|
||||
print("finetuned layers: ")
|
||||
print(args.lora_layer)
|
||||
|
||||
if args.delta_type == "lora":
|
||||
delta_model = LoraModel(backbone_model=model, modified_modules=args.lora_layer, backend='bmt', lora_r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout)
|
||||
|
||||
if bmt.rank() == 0:
|
||||
print("Before freeze: ")
|
||||
delta_model.log()
|
||||
|
||||
delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
|
||||
|
||||
if bmt.rank() == 0:
|
||||
print("After freeze: ")
|
||||
delta_model.log()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_optimizer(args, model):
|
||||
if args.offload:
|
||||
optimizer = bmt.optim.AdamOffloadOptimizer(
|
||||
model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay
|
||||
)
|
||||
else:
|
||||
optimizer = bmt.optim.AdamOptimizer(model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
||||
if args.load is not None and args.load_grad:
|
||||
start = time.time()
|
||||
print(
|
||||
sum(
|
||||
[
|
||||
1
|
||||
if i.find(".opt") != -1 and i.find("-{}.rank".format(args.start_step % (args.save_iters * 5))) != -1
|
||||
else 0
|
||||
for i in os.listdir(args.save)
|
||||
]
|
||||
)
|
||||
)
|
||||
if (
|
||||
sum(
|
||||
[
|
||||
1
|
||||
if i.find(".opt") != -1 and i.find("-{}.rank".format(args.start_step % (args.save_iters * 5))) != -1
|
||||
else 0
|
||||
for i in os.listdir(args.save)
|
||||
]
|
||||
)
|
||||
== bmt.world_size()
|
||||
):
|
||||
file_name = os.path.join(
|
||||
args.save,
|
||||
args.save_name + "-{}.rank-{}.opt".format(args.start_step % (args.save_iters * 5), bmt.rank()),
|
||||
)
|
||||
print(file_name)
|
||||
if os.path.exists(file_name):
|
||||
print("start to load grad ckpt {}".format(file_name))
|
||||
states = torch.load(file_name)
|
||||
optimizer.load_state_dict(states)
|
||||
logger.info("load grad in {:.2f}s".format(time.time() - start))
|
||||
return optimizer
|
||||
|
||||
|
||||
class Cosine(bmt.lr_scheduler.WarmupLRScheduler):
|
||||
r"""
|
||||
After a warmup period during which learning rate increases linearly between 0 and the start_lr,
|
||||
The decay period performs :math:`\text{lr}=\text{start_lr}\times \dfrac{1+\cos \left( \pi \cdot \dfrac{\text{num_iter}-\text{warmup_iter}}{\text{end_iter}-\text{warmup_iter}}\right)}{2}`
|
||||
"""
|
||||
|
||||
def get_lr_warmup(self, num_iter) -> float:
|
||||
return self.start_lr * num_iter / self.warmup_iter
|
||||
|
||||
def get_lr_decay(self, num_iter) -> float:
|
||||
progress = (num_iter - self.warmup_iter) / max(1, (self.end_iter - self.warmup_iter))
|
||||
return max(self.start_lr * 0.1, self.start_lr * (0.1 + 0.45 * (1.0 + math.cos(progress * math.pi))))
|
||||
|
||||
|
||||
def get_learning_rate_scheduler(args, optimizer):
|
||||
if args.lr_decay_iters is None:
|
||||
args.lr_decay_iters = args.train_iters
|
||||
# lr_scheduler = bmt.lr_scheduler.Noam(
|
||||
lr_scheduler = Cosine(
|
||||
optimizer,
|
||||
start_lr=args.lr,
|
||||
warmup_iter=args.warmup_iters,
|
||||
end_iter=args.lr_decay_iters,
|
||||
num_iter=args.start_step,
|
||||
)
|
||||
return lr_scheduler
|
||||
|
||||
|
||||
def setup_model_and_optimizer(args):
|
||||
start = time.time()
|
||||
model = get_model(args)
|
||||
logger.info("load model in {:.2f}s".format(time.time() - start))
|
||||
|
||||
start = time.time()
|
||||
tokenizer = get_tokenizer(args)
|
||||
bmt.synchronize()
|
||||
logger.info("load tokenizer in {:.2f}s".format(time.time() - start))
|
||||
|
||||
start = time.time()
|
||||
optimizer = get_optimizer(args, model)
|
||||
|
||||
if args.delta_type != None:
|
||||
inspect_optimizer_statistics(optimizer)
|
||||
|
||||
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
|
||||
bmt.synchronize()
|
||||
logger.info("load lr_scheduler in {:.2f}s".format(time.time() - start))
|
||||
|
||||
return tokenizer, model, optimizer, lr_scheduler
|
||||
|
||||
|
||||
def initialize():
|
||||
args = get_args(finetune=True)
|
||||
|
||||
# hack
|
||||
if "checkpointing" in inspect.signature(bmt.init_distributed).parameters:
|
||||
bmt.init_distributed(checkpointing=False, seed=args.seed)
|
||||
else:
|
||||
bmt.init_distributed(seed=args.seed)
|
||||
if args.save is not None:
|
||||
os.makedirs(args.save, exist_ok=True)
|
||||
# if args.load is not None:
|
||||
# if args.start_step == 0:
|
||||
# args.start_step = (int)(re.search("(\d+).pt", args.load)[1])
|
||||
return args
|
||||
|
||||
|
||||
def see_memory(detail=False):
|
||||
if detail:
|
||||
res = torch.cuda.memory_summary()
|
||||
else:
|
||||
res = (
|
||||
round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024), 2),
|
||||
round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2),
|
||||
)
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
return res
|
||||
|
||||
|
||||
def add_mem_time(info, mem_usage, tim_usage):
|
||||
torch.cuda.synchronize()
|
||||
mem_usage[info] = see_memory()
|
||||
tim_usage[info] = time.time()
|
||||
return mem_usage, tim_usage
|
||||
|
||||
|
||||
class LossSpikeDetector:
|
||||
def __init__(self, log_path: str) -> None:
|
||||
self._last_loss: Dict[str, float] = {}
|
||||
self._last_data: List[Any] = [None]
|
||||
self._log_path = log_path
|
||||
|
||||
def update_data(self, data: Any):
|
||||
self._last_data.append(data)
|
||||
if len(self._last_data) > 2:
|
||||
self._last_data = self._last_data[-2:]
|
||||
|
||||
def update_loss(self, iteration: int, loss_map: Dict[str, float]):
|
||||
loss_spike_result = []
|
||||
for task, loss in loss_map.items():
|
||||
if task in self._last_loss:
|
||||
if loss > self._last_loss[task] * 3:
|
||||
# loss spike!
|
||||
loss_spike_result.append(
|
||||
{
|
||||
"prev": self._last_loss[task],
|
||||
"curr": loss,
|
||||
"task": task,
|
||||
}
|
||||
)
|
||||
self._last_loss[task] = float(loss)
|
||||
if len(loss_spike_result) > 0:
|
||||
self._write_log(iteration, self._last_data[-1], loss_spike_result)
|
||||
|
||||
def _write_log(self, iteration: int, data: Any, result: List[Dict[str, Any]]):
|
||||
return
|
||||
with open(self._log_path, "a", encoding="utf-8") as fp:
|
||||
fp.write("=" * 20)
|
||||
fp.write("\nloss spike at {}\n".format(iteration))
|
||||
fp.write("{}\n".format(json.dumps(result, indent=4, ensure_ascii=False)))
|
||||
fp.write("data: \n")
|
||||
for d in data:
|
||||
fp.write("{}\n".format(json.dumps(d, indent=4, ensure_ascii=False)))
|
||||
fp.write("\n\n")
|
||||
|
||||
|
||||
def finetune(
|
||||
args,
|
||||
bin_file: str,
|
||||
tokenizer: CPM9GTokenizer,
|
||||
model: CPM9G,
|
||||
optimizer,
|
||||
lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
|
||||
):
|
||||
average_time = bmt.utils.AverageRecorder()
|
||||
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100)
|
||||
optim_manager = bmt.optim.OptimManager(loss_scale=args.loss_scale, loss_scale_steps=args.loss_scale_steps, loss_scale_factor=2, max_loss_scale=args.max_loss_scale, min_loss_scale=args.min_loss_scale,)
|
||||
optim_manager.add_optimizer(optimizer, lr_scheduler)
|
||||
|
||||
if args.tensorboard is not None and bmt.rank() == 0:
|
||||
import distutils.version # noqa: F401
|
||||
from tensorboardX import SummaryWriter
|
||||
if not os.path.exists(args.tensorboard):
|
||||
os.makedirs(args.tensorboard)
|
||||
writer = SummaryWriter(log_dir=args.tensorboard)
|
||||
|
||||
global_token_pass = 0.0
|
||||
global_world_size = bmt.world_size()
|
||||
|
||||
for epoch in range(args.epoch):
|
||||
epoch = epoch + 1
|
||||
last_data = None
|
||||
dataloader = FinetuneDataset(
|
||||
bin_file, args.batch_size, args.max_length, tokenizer, unpad=(args.flash == "cuda"), task_name="task", drop_last=True
|
||||
)
|
||||
for iteration, data in enumerate(dataloader):
|
||||
iteration = iteration + 1
|
||||
skip_this_batch = False
|
||||
if data is None:
|
||||
if last_data is None:
|
||||
raise RuntimeError(
|
||||
"Dataset is too small, please use a smaller batch size or sequence length!"
|
||||
)
|
||||
data = last_data # use last data
|
||||
skip_this_batch = True
|
||||
else:
|
||||
last_data = data
|
||||
|
||||
assert data["inputs"].shape[0] == args.batch_size
|
||||
input_ids = torch.from_numpy(data["inputs"]).cuda().to(torch.int32)
|
||||
input_length = torch.from_numpy(data["length"]).cuda().to(torch.int32)
|
||||
targets = torch.from_numpy(data["target"]).cuda().long()
|
||||
# bmt.print_rank(input_ids[0].tolist())
|
||||
# bmt.print_rank(targets[0].tolist())
|
||||
# bmt.print_rank(data["spans"].tolist())
|
||||
# bmt.print_rank(tokenizer.decode(input_ids[0]))
|
||||
# bmt.print_rank(tokenizer.path)
|
||||
# bmt.synchronize()
|
||||
# exit()
|
||||
task_ids = torch.from_numpy(data["task_ids"]).cuda().to(torch.int32)
|
||||
task_names = data["task_names"]
|
||||
if args.flash == "cuda":
|
||||
cu_seqlens = torch.from_numpy(data["cu_seqlens"]).cuda().to(torch.int32)
|
||||
max_seqlen = data["max_seqlen"]
|
||||
position_ids = torch.from_numpy(data["position_ids"]).cuda().to(torch.int32)
|
||||
else:
|
||||
input_context = torch.zeros_like(input_ids).cuda().bool()
|
||||
input_span = torch.from_numpy(data["spans"]).cuda().to(torch.int32)
|
||||
|
||||
# ===========
|
||||
optim_manager.zero_grad()
|
||||
# torch.cuda.empty_cache()
|
||||
mem_usage = {}
|
||||
tim_usage = {}
|
||||
mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)
|
||||
|
||||
# ===========
|
||||
if args.flash == "cuda":
|
||||
logits, _ = model(
|
||||
input_ids,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
else:
|
||||
logits, _ = model(
|
||||
input_ids,
|
||||
input_length,
|
||||
input_context,
|
||||
input_span,
|
||||
)
|
||||
mem_usage, tim_usage = add_mem_time("forward_1", mem_usage, tim_usage)
|
||||
loss = loss_func(logits.view(-1, logits.size(-1)), targets.view(-1))
|
||||
if skip_this_batch:
|
||||
loss = loss * 0
|
||||
mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)
|
||||
|
||||
# ===========
|
||||
optim_manager.backward(loss)
|
||||
mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)
|
||||
|
||||
# ===========
|
||||
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
|
||||
optim_manager.step()
|
||||
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
|
||||
|
||||
# ==========
|
||||
iter_time = tim_usage["optim"] - tim_usage["init"]
|
||||
average_time.record(iter_time)
|
||||
|
||||
with torch.no_grad():
|
||||
task_num = len(task_names)
|
||||
targets_tmp = targets.expand(task_num, -1, -1)
|
||||
task = torch.arange(task_num, dtype=torch.long, device="cuda")[:, None, None]
|
||||
targets_tmp = torch.where(
|
||||
task_ids == task,
|
||||
targets_tmp,
|
||||
torch.scalar_tensor(-100, dtype=torch.long, device="cuda"),
|
||||
)
|
||||
|
||||
task_loss_map: Dict[str, float] = {}
|
||||
if not skip_this_batch:
|
||||
for i in range(task_num):
|
||||
task_loss = loss_func(logits.view(-1, logits.size(-1)), targets_tmp[i, :].view(-1))
|
||||
task_loss_map[task_names[i]] = task_loss.item()
|
||||
gatherd_task_loss_map: List[Dict[str, float]] = allgather_objects(task_loss_map)
|
||||
|
||||
global_task_loss_map: Dict[str, Union[List[float], float]] = {}
|
||||
for local_task_loss_map in gatherd_task_loss_map:
|
||||
for task_name, task_loss in local_task_loss_map.items():
|
||||
if task_name not in global_task_loss_map:
|
||||
global_task_loss_map[task_name] = []
|
||||
global_task_loss_map[task_name].append(task_loss)
|
||||
|
||||
task_loss_map = {}
|
||||
for task_name in sorted(list(global_task_loss_map.keys())):
|
||||
avg_loss = sum(global_task_loss_map[task_name]) / len(global_task_loss_map[task_name])
|
||||
task_loss_map[task_name] = avg_loss
|
||||
|
||||
local_total_rate = torch.Tensor([input_length.float().mean() / args.max_length]).cuda()
|
||||
local_total_rate = bmt.sum_loss(local_total_rate).item()
|
||||
global_token_pass += global_world_size * local_total_rate * args.max_length * args.batch_size
|
||||
avg_time = average_time.value
|
||||
|
||||
train_info = {
|
||||
"time": tim_usage["init"],
|
||||
"epoch": epoch,
|
||||
"iteration": iteration,
|
||||
"loss": task_loss_map[args.task_name],
|
||||
"lr": lr_scheduler.current_lr,
|
||||
"lr_scale": int(optim_manager.loss_scale),
|
||||
"time_usage": tim_usage,
|
||||
"mem_usage": mem_usage,
|
||||
"avg_time": avg_time,
|
||||
"token_max": local_total_rate,
|
||||
"token_pass": global_token_pass,
|
||||
"throughout": args.max_length * args.batch_size * local_total_rate / avg_time,
|
||||
"grad_norm": grad_norm.item(),
|
||||
"mask_max": ((targets >= 0).sum(-1).float().mean() / args.max_length).item(),
|
||||
"num_gpus": global_world_size,
|
||||
"task_loss": task_loss_map,
|
||||
}
|
||||
# bmt.print_rank(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
|
||||
bmt.print_rank(
|
||||
(
|
||||
"| Epoch: {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f} | lr: {:.6e} | scale: {:10.0f} | time: {:.1f} |"
|
||||
+ " token/max: {:.3f} | mask/max: {:.3f} | grad_norm: {:.3f}"
|
||||
).format(
|
||||
epoch,
|
||||
iteration,
|
||||
args.train_iters,
|
||||
task_loss_map[args.task_name],
|
||||
lr_scheduler.current_lr,
|
||||
int(optim_manager.loss_scale),
|
||||
avg_time,
|
||||
input_length.float().mean() / args.max_length,
|
||||
(targets >= 0).sum(-1).float().mean() / args.max_length,
|
||||
grad_norm,
|
||||
)
|
||||
)
|
||||
|
||||
bmt.print_rank(
|
||||
"| "
|
||||
+ " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
|
||||
+ " |"
|
||||
)
|
||||
if iteration % args.inspect_iters == 0:
|
||||
model_inspect = bmt.inspect.inspect_model(model, "*")
|
||||
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
||||
train_info["model_inspect"] = model_inspect
|
||||
|
||||
# save_folder_name = f"{args.save}{epoch}{iteration}"
|
||||
# model_fname = os.path.join(save_folder_name, f"{args.save_name}-iter-{iteration}.pt")
|
||||
# os.makedirs(os.path.dirname(model_fname), exist_ok=True)
|
||||
# if bmt.rank() == 0:
|
||||
# shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
|
||||
# shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))
|
||||
|
||||
if args.tensorboard is not None and bmt.rank() == 0:
|
||||
writer.add_scalar(f"Loss/train/{epoch}", task_loss_map[args.task_name], iteration)
|
||||
writer.add_scalar(f"Optimizer/lr/{epoch}", lr_scheduler.current_lr, iteration)
|
||||
writer.add_scalar(f"Optimizer/scale/{epoch}", optim_manager.loss_scale, iteration)
|
||||
writer.add_scalar(f"Optimizer/grad_norm/{epoch}", grad_norm.item(), iteration)
|
||||
|
||||
# if iteration % 10 == 0:
|
||||
# save_folder_name = f"{args.save}{epoch}"
|
||||
# model_fname = os.path.join(save_folder_name, f"{args.save_name}-epoch-{epoch}.pt")
|
||||
# os.makedirs(os.path.dirname(model_fname), exist_ok=True)
|
||||
# bmt.save(model, model_fname)
|
||||
# if bmt.rank() == 0:
|
||||
# shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
|
||||
# shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))
|
||||
|
||||
save_folder_name = f"{args.save}/epoch-{epoch}"
|
||||
model_fname = os.path.join(save_folder_name, f"{args.save_name}.pt")
|
||||
os.makedirs(os.path.dirname(model_fname), exist_ok=True)
|
||||
state_dict = model.state_dict()
|
||||
if args.delta_type == None or args.save_origin_model == True :
|
||||
print("saving base model...")
|
||||
bmt.save(model, model_fname)
|
||||
if args.delta_type != None and bmt.rank() == 0:
|
||||
print("saving delta model...")
|
||||
torch.save(state_dict, os.path.join(save_folder_name, f"{args.save_name}-delta.pt"))
|
||||
if bmt.rank() == 0:
|
||||
shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
|
||||
shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))
|
||||
|
||||
|
||||
def main():
|
||||
args = initialize()
|
||||
# To Be Specified
|
||||
bin_file = args.dataset
|
||||
bmt.print_rank(f"dataset: {bin_file}")
|
||||
|
||||
tokenizer = get_tokenizer(args)
|
||||
dataloader = FinetuneDataset(
|
||||
bin_file, args.batch_size, args.max_length, tokenizer, unpad=(args.flash == "cuda"), task_name="task", drop_last=True
|
||||
)
|
||||
bmt.print_rank(f"#batch: {len(dataloader)}")
|
||||
total_steps = len(dataloader) * args.epoch
|
||||
|
||||
setattr(args, 'train_iters', int(total_steps))
|
||||
setattr(args, 'warmup_iters', max(int(total_steps*0.02), args.warmup_iters))
|
||||
bmt.print_rank(json.dumps(vars(args), indent=2, sort_keys=True))
|
||||
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
|
||||
|
||||
bmt.print_rank("finish loading")
|
||||
finetune(args, bin_file, tokenizer, model, optimizer, lr_scheduler)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,63 +0,0 @@
|
|||
#! /bin/bash
|
||||
#SBATCH --partition=gpu3-1
|
||||
#SBATCH --nodes=1
|
||||
#SBATCH --ntasks-per-node=8
|
||||
#SBATCH --gres=gpu:8
|
||||
#SBATCH --cpus-per-task=8
|
||||
|
||||
export MASTER_ADDR=`hostname`
|
||||
export MASTER_PORT=12345
|
||||
echo $MASTER_ADDR
|
||||
|
||||
CPM_PATH="./9G-Train"
|
||||
CONFIG_NAME="${CPM_PATH}/apps/cpm9g/config/11b"
|
||||
EXP_PATH=./exp
|
||||
mkdir -p $EXP_PATH
|
||||
MODEL_NAME="cpm9g-11b-sft"
|
||||
|
||||
OPTS=""
|
||||
OPTS+=" --model-config ${CONFIG_NAME}/config.json"
|
||||
OPTS+=" --vocab ${CONFIG_NAME}/vocab.txt"
|
||||
|
||||
OPTS+=" --train-iters 10000"
|
||||
OPTS+=" --inspect-iters 200"
|
||||
OPTS+=" --warmup-iters 500"
|
||||
|
||||
OPTS+=" --lr-decay-style cosine"
|
||||
OPTS+=" --weight-decay 0.1"
|
||||
OPTS+=" --clip-grad 1.0"
|
||||
OPTS+=" --loss-scale 1048576"
|
||||
OPTS+=" --max-loss-scale 33554432"
|
||||
OPTS+=" --min-loss-scale 1"
|
||||
OPTS+=" --loss-scale-steps 32"
|
||||
|
||||
OPTS+=" --offload"
|
||||
OPTS+=" --batch-size 4"
|
||||
OPTS+=" --max-length 4096"
|
||||
OPTS+=" --lr 2e-5"
|
||||
OPTS+=" --start-step 0"
|
||||
OPTS+=" --epoch 8"
|
||||
OPTS+=" --load /data/groups/QY_LLM_Core/models/20231010/11b-base/11b.pt"
|
||||
OPTS+=" --dataset /data/groups/QY_LLM_Core/datasets/sft/20231025/merge_qy_sft_bin"
|
||||
# TODO 这些 /data 在启元机器上需要改成 /home 下的路径
|
||||
OPTS+=" --save ${EXP_PATH}/checkpoints"
|
||||
OPTS+=" --save-name ${MODEL_NAME}"
|
||||
# OPTS+=" --tensorboard /data/logs/tensorboard/${MODEL_NAME}/${CUR_DATE}/"
|
||||
# OPTS+=" --flash triton"
|
||||
# OPTS+=" --flash cuda"
|
||||
# OPTS+=" --load-grad"
|
||||
|
||||
OPTS+=" --delta-tuning" #开启delta-tuning
|
||||
OPTS+=" --delta-type lora" #目前仅支持lora
|
||||
OPTS+=" --lora-r 64" #lora矩阵的维度,默认为8
|
||||
OPTS+=" --lora-dropout 0.05" #默认为0
|
||||
OPTS+=" --lora-alpha 64" #lora对原模型的影响比例,默认为8
|
||||
OPTS+=" --lora-layer project_q project_v project_k w_0 w_1 w_out" #参与lora的线性层,默认为project_q,project_k
|
||||
OPTS+=" --save-origin-model" #是否在每个epoch储存基座模型
|
||||
|
||||
OPTS+=" $@"
|
||||
|
||||
CMD="torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} ${CPM_PATH}/apps/cpm9g/sft_cpm9g_delta.py ${OPTS}"
|
||||
|
||||
echo "${CMD}"
|
||||
$CMD
|
|
@ -1,5 +0,0 @@
|
|||
from .models import CPM9G
|
||||
from .models import CPM9GConfig
|
||||
|
||||
from .tokenizers import CPM9GTokenizer
|
||||
from .training_tasks import MixedDataset
|
|
@ -1,16 +0,0 @@
|
|||
{
|
||||
"vocab_size": 119696,
|
||||
"dropout_p": 0.0,
|
||||
"eps": 1e-05,
|
||||
"half": true,
|
||||
"use_flash_attn": false,
|
||||
"flash_attn_mask_shape": "1d",
|
||||
"dim_model": 4096,
|
||||
"dim_ff": 12288,
|
||||
"dim_head": 128,
|
||||
"num_heads": 32,
|
||||
"num_kv_heads": 32,
|
||||
"num_layers": 32,
|
||||
"activate_fn": "silu",
|
||||
"scale": false
|
||||
}
|
|
@ -1,658 +0,0 @@
|
|||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...generation import apply_repetition_penalty
|
||||
from ...generation import BeamHypotheses
|
||||
from ...generation import top_k_top_p_filtering
|
||||
|
||||
from ...utils import pad
|
||||
from ..models import CPM9GTorch
|
||||
from ..tokenizers.cpm9g import CPM9GTokenizer
|
||||
|
||||
|
||||
class CPM9GGeneration:
|
||||
def __init__(
|
||||
self, model: CPM9GTorch, tokenizer: CPM9GTokenizer, max_in_len=1024, use_nbce: bool = False
|
||||
):
|
||||
model.eval()
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.max_in_len = max_in_len
|
||||
|
||||
def _convert_to_tensors(self, data: Any):
|
||||
input_ids = self.tokenizer.encode(
|
||||
data["input"]
|
||||
) # [self.tokenizer.bos_token_id] + self.tokenizer.encode(data["input"])
|
||||
model_input = {}
|
||||
|
||||
model_input["input_ids"] = torch.tensor(input_ids[: self.max_in_len], dtype=torch.int32).unsqueeze(0)
|
||||
model_input["context"] = torch.zeros(
|
||||
(model_input["input_ids"].shape[0], model_input["input_ids"].shape[1]), dtype=torch.int16
|
||||
)
|
||||
model_input["span"] = torch.ones((model_input["input_ids"].shape[1],), dtype=torch.int16).unsqueeze(0)
|
||||
model_input["length"] = torch.tensor([model_input["input_ids"].shape[1]], dtype=torch.int16).unsqueeze(0)
|
||||
return model_input
|
||||
|
||||
def _process_list(self, data_list: List[Any]):
|
||||
input_tensors = list(map(self._convert_to_tensors, data_list))
|
||||
keys = set(input_tensors[0].keys())
|
||||
padded = {}
|
||||
for key in keys:
|
||||
padded[key] = pad(input_tensors, key, padding_side="left").cuda()
|
||||
return padded
|
||||
|
||||
def generate(self, data_list, **kwargs):
|
||||
origin_data_list = data_list.copy()
|
||||
model_inputs = self._process_list(data_list)
|
||||
with torch.inference_mode():
|
||||
result_ids = self._decode(model_inputs, **kwargs)
|
||||
|
||||
return result_ids
|
||||
|
||||
def _decode(self, model_inputs, **kwargs):
|
||||
raise NotImplementedError("_decode is not implemented.")
|
||||
|
||||
|
||||
class CPM9GBeamSearch(CPM9GGeneration):
|
||||
def _decode(
|
||||
self,
|
||||
model_inputs,
|
||||
beam_size=4,
|
||||
max_length=100,
|
||||
repetition_penalty=1.2,
|
||||
repetition_window=None,
|
||||
):
|
||||
"""
|
||||
Beam search
|
||||
Args:
|
||||
model_inputs (dict): input ids.
|
||||
beam_size (int, optional, defaults to 3): beam size of beam search.
|
||||
generate_length (int, optional, defaults to 100): maximum generation length.
|
||||
repetition_penalty (float, optional, defaults to 1.0): repetition penalty coefficient, 1.0 means no penalty.
|
||||
repetition_window (int, optional, defaults to None): window size of repetition penalty, None means that all output tokens are penalized.
|
||||
""" # noqa: E501
|
||||
# generate_length + 1 for EOS token
|
||||
max_length += 1
|
||||
# expand dimmension
|
||||
batch_size = model_inputs["input_ids"].size(0)
|
||||
input: torch.Tensor = (
|
||||
model_inputs["input_ids"]
|
||||
.unsqueeze(1)
|
||||
.expand(batch_size, beam_size, -1)
|
||||
.contiguous()
|
||||
.view(batch_size * beam_size, -1)
|
||||
)
|
||||
length = (
|
||||
model_inputs["length"]
|
||||
.squeeze(1)
|
||||
.unsqueeze(1)
|
||||
.expand(batch_size, beam_size)
|
||||
.contiguous()
|
||||
.view(
|
||||
batch_size * beam_size,
|
||||
)
|
||||
)
|
||||
span: torch.Tensor = (
|
||||
model_inputs["span"]
|
||||
.unsqueeze(1)
|
||||
.expand(batch_size, beam_size, -1)
|
||||
.contiguous()
|
||||
.view(batch_size * beam_size, -1)
|
||||
)
|
||||
context: torch.Tensor = (
|
||||
model_inputs["context"]
|
||||
.unsqueeze(1)
|
||||
.expand(batch_size, beam_size, -1)
|
||||
.contiguous()
|
||||
.view(batch_size * beam_size, -1)
|
||||
)
|
||||
|
||||
done = [False for _ in range(batch_size)]
|
||||
|
||||
beam_scores = torch.zeros((batch_size, beam_size), dtype=torch.float, device=input.device)
|
||||
beam_scores[:, 1:] = -1e9
|
||||
beam_scores = beam_scores.view(-1)
|
||||
|
||||
# generated hypotheses
|
||||
generated_hyps = [
|
||||
BeamHypotheses(beam_size, max_length, length_penalty=1, early_stopping=False) for _ in range(batch_size)
|
||||
]
|
||||
|
||||
pred_start_index = input.size(-1)
|
||||
past_key_values = None
|
||||
|
||||
for i in range(max_length + 1):
|
||||
if i == 0:
|
||||
logits, _, past_key_values = self.model.inference(
|
||||
input=input,
|
||||
context=context,
|
||||
span=span,
|
||||
length=length,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
else:
|
||||
logits, _, past_key_values = self.model.inference(
|
||||
input=input[:, -1:],
|
||||
context=context,
|
||||
span=span,
|
||||
length=length,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
# skip all steps when we are done with each sentence
|
||||
if all(done):
|
||||
break
|
||||
|
||||
# (batch * beam, seqlen, model_dim)
|
||||
logits = logits[:, -1, :]
|
||||
if i == 0:
|
||||
logits[:, self.tokenizer.bos_token_id] = -float("inf")
|
||||
# logits[:, self.tokenizer.newline_id] = -float("inf")
|
||||
apply_repetition_penalty(
|
||||
logits,
|
||||
batch_size,
|
||||
beam_size,
|
||||
input,
|
||||
repetition_penalty,
|
||||
pred_start_index,
|
||||
input.size(-1) - 1,
|
||||
repetition_window,
|
||||
)
|
||||
|
||||
scores = F.log_softmax(logits, dim=-1)
|
||||
|
||||
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * beam_size, vocab_size)
|
||||
|
||||
# re-organize to group the beam together (we are keeping top hypothesis accross beams)
|
||||
next_scores = next_scores.view(batch_size, -1) # (batch_size, beam_size * vocab_size)
|
||||
next_scores, next_words = torch.topk(next_scores, 2 * beam_size, dim=1, largest=True, sorted=True)
|
||||
|
||||
assert next_scores.size() == next_words.size() == (batch_size, 2 * beam_size)
|
||||
next_batch_beam = []
|
||||
|
||||
for sent_id in range(batch_size):
|
||||
# if we are done with this sentence
|
||||
done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item(), i)
|
||||
if done[sent_id]:
|
||||
next_batch_beam.extend([(0, 0, 0)] * beam_size) # pad the batch
|
||||
continue
|
||||
|
||||
# next sentence beam content
|
||||
next_sent_beam = []
|
||||
|
||||
# next words for this sentence
|
||||
for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
|
||||
# get beam and word IDs
|
||||
beam_id = torch.div(idx, scores.size(-1), rounding_mode="floor")
|
||||
word_id = idx % scores.size(-1)
|
||||
|
||||
# end of sentence, or next word
|
||||
if word_id == self.tokenizer.bos_token_id or i == max_length:
|
||||
generated_hyps[sent_id].add(
|
||||
input[sent_id * beam_size + beam_id, pred_start_index:].clone().cpu().tolist(),
|
||||
value.item(),
|
||||
)
|
||||
else:
|
||||
next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))
|
||||
|
||||
# the beam for next step is full
|
||||
if len(next_sent_beam) == beam_size:
|
||||
break
|
||||
|
||||
# update next beam content
|
||||
assert len(next_sent_beam) == 0 if i == max_length else beam_size
|
||||
if len(next_sent_beam) == 0:
|
||||
next_sent_beam = [(0, 0, 0)] * beam_size # pad the batch
|
||||
next_batch_beam.extend(next_sent_beam)
|
||||
assert len(next_batch_beam) == beam_size * (sent_id + 1)
|
||||
|
||||
# we have reached the last step
|
||||
if i == max_length:
|
||||
break
|
||||
|
||||
# sanity check / prepare next batch
|
||||
assert len(next_batch_beam) == batch_size * beam_size
|
||||
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
|
||||
beam_words = input.new([x[1] for x in next_batch_beam])
|
||||
beam_idx = torch.tensor([x[2] for x in next_batch_beam], device=input.device).long()
|
||||
# re-order batch and internal states
|
||||
input = input[beam_idx, :]
|
||||
|
||||
past_key_values["buffer"] = [list(each) if each is not None else each for each in past_key_values["buffer"]] # type: ignore # noqa: E501
|
||||
for key_value_layer in past_key_values["buffer"]:
|
||||
if key_value_layer is not None:
|
||||
key_value_layer[0] = key_value_layer[0][beam_idx]
|
||||
key_value_layer[1] = key_value_layer[1][beam_idx]
|
||||
|
||||
input = torch.cat([input, beam_words.unsqueeze(1)], dim=-1)
|
||||
context = torch.cat(
|
||||
[context, context[:, -1:]],
|
||||
dim=-1,
|
||||
)
|
||||
length += 1
|
||||
|
||||
span = torch.cat([span, span[:, -1:]], dim=-1)
|
||||
|
||||
# select the best hypotheses
|
||||
|
||||
results = []
|
||||
for i, hypotheses in enumerate(generated_hyps):
|
||||
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
|
||||
results.append(best_hyp)
|
||||
|
||||
result_text = list(map(self.tokenizer.decode, results))
|
||||
return result_text
|
||||
|
||||
|
||||
class CPM9GRandomSampling(CPM9GGeneration):
|
||||
def _decode(
|
||||
self,
|
||||
model_inputs,
|
||||
max_length=100,
|
||||
top_k=0,
|
||||
top_p=0.9,
|
||||
temperature=0.9,
|
||||
repetition_penalty=1.0,
|
||||
repetition_window=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Top-k and top-p sampling.
|
||||
Args:
|
||||
model_inputs (dict): input ids
|
||||
generate_length (int, optional, defaults to 100): maximum generation length
|
||||
top_k (int, optional, defaults to 0): keep only top k tokens with highest probability. 0 means keeping all tokens.
|
||||
top_p (int, optional, defaults to 0.9): keep the top tokens with cumulative probability >= top_p.
|
||||
temperature (int, optional, defaults to 0.9): the value that can cool down the logits distribution.
|
||||
repetition_penalty (float, optional, defaults to 1.0): repetition penalty coefficient, 1.0 means no penalty.
|
||||
repetition_window (int, optional, defaults to None): window size of repetition penalty, None means that all output tokens are penalized.
|
||||
""" # noqa: E501
|
||||
# generate_length + 1 for EOS token
|
||||
max_length += 1
|
||||
|
||||
input = model_inputs["input_ids"]
|
||||
context = model_inputs["context"]
|
||||
|
||||
length = model_inputs["length"].squeeze(1)
|
||||
span = model_inputs["span"]
|
||||
batch_size = input.size(0)
|
||||
|
||||
pred_start_index = input.size(-1)
|
||||
past_key_values = None
|
||||
done = [False for _ in range(batch_size)]
|
||||
results = [None for _ in range(batch_size)]
|
||||
for i in range(max_length):
|
||||
if i == 0:
|
||||
logits, _, past_key_values = self.model.inference(
|
||||
input=input,
|
||||
context=context,
|
||||
length=length,
|
||||
span=span,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
else:
|
||||
logits, _, past_key_values = self.model.inference(
|
||||
input=input[:, -1:],
|
||||
context=context,
|
||||
length=length,
|
||||
span=span,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
if i == 0:
|
||||
logits[:, self.tokenizer.bos_token_id] = -float("inf")
|
||||
# logits[:, self.tokenizer.newline_id] = -float("inf")
|
||||
|
||||
apply_repetition_penalty(
|
||||
logits,
|
||||
batch_size,
|
||||
1,
|
||||
input,
|
||||
repetition_penalty,
|
||||
pred_start_index,
|
||||
input.size(-1) - 1,
|
||||
repetition_window,
|
||||
)
|
||||
|
||||
logits = logits / temperature
|
||||
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
||||
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
for idx in range(batch_size):
|
||||
if not done[idx] and (next_token[idx].item() == self.tokenizer.bos_token_id or i == max_length - 1):
|
||||
done[idx] = True
|
||||
results[idx] = input[idx, pred_start_index:].clone().cpu().tolist() # type: ignore # noqa: E501
|
||||
|
||||
if sum(done) == batch_size:
|
||||
break
|
||||
|
||||
# update input ids
|
||||
input = torch.cat([input, next_token], dim=-1)
|
||||
length += 1
|
||||
|
||||
context = torch.cat(
|
||||
[context, context[:, -1:]],
|
||||
dim=-1,
|
||||
)
|
||||
span = torch.cat(
|
||||
[span, span[:, -1:]],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
result_text = list(map(self.tokenizer.decode, results))
|
||||
return result_text
|
||||
|
||||
|
||||
class CPM9GBeamSearchNBCE(CPM9GGeneration):
|
||||
def _decode(
|
||||
self,
|
||||
model_inputs,
|
||||
beam_size=5,
|
||||
max_length=100,
|
||||
repetition_penalty=1.0,
|
||||
repetition_window=None,
|
||||
):
|
||||
"""
|
||||
Beam search
|
||||
Args:
|
||||
model_inputs (dict): input ids.
|
||||
beam_size (int, optional, defaults to 3): beam size of beam search.
|
||||
generate_length (int, optional, defaults to 100): maximum generation length.
|
||||
repetition_penalty (float, optional, defaults to 1.0): repetition penalty coefficient, 1.0 means no penalty.
|
||||
repetition_window (int, optional, defaults to None): window size of repetition penalty, None means that all output tokens are penalized.
|
||||
""" # noqa: E501
|
||||
# generate_length + 1 for EOS token
|
||||
max_length += 1
|
||||
|
||||
# expand dimmension
|
||||
batch_size = model_inputs["input_ids"].size(0)
|
||||
input: torch.Tensor = (
|
||||
model_inputs["input_ids"]
|
||||
.unsqueeze(1)
|
||||
.expand(batch_size, beam_size, -1)
|
||||
.contiguous()
|
||||
.view(batch_size * beam_size, -1)
|
||||
)
|
||||
length = (
|
||||
model_inputs["length"]
|
||||
.squeeze(1)
|
||||
.unsqueeze(1)
|
||||
.expand(batch_size, beam_size)
|
||||
.contiguous()
|
||||
.view(
|
||||
batch_size * beam_size,
|
||||
)
|
||||
)
|
||||
span: torch.Tensor = (
|
||||
model_inputs["span"]
|
||||
.unsqueeze(1)
|
||||
.expand(batch_size, beam_size, -1)
|
||||
.contiguous()
|
||||
.view(batch_size * beam_size, -1)
|
||||
)
|
||||
context: torch.Tensor = (
|
||||
model_inputs["context"]
|
||||
.unsqueeze(1)
|
||||
.expand(batch_size, beam_size, -1)
|
||||
.contiguous()
|
||||
.view(batch_size * beam_size, -1)
|
||||
)
|
||||
|
||||
done = [False]
|
||||
|
||||
beam_scores = torch.zeros((1, beam_size), dtype=torch.float, device=input.device)
|
||||
beam_scores[:, 1:] = -1e9
|
||||
beam_scores = beam_scores.view(-1)
|
||||
|
||||
# generated hypotheses
|
||||
generated_hyps = [
|
||||
BeamHypotheses(beam_size, max_length, length_penalty=1, early_stopping=False) for _ in range(1)
|
||||
]
|
||||
|
||||
pred_start_index = input.size(-1)
|
||||
past_key_values = None
|
||||
|
||||
for i in range(max_length + 1):
|
||||
if i == 0:
|
||||
logits, _, past_key_values = self.model.inference(
|
||||
input=input,
|
||||
context=context,
|
||||
span=span,
|
||||
length=length,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
else:
|
||||
logits, _, past_key_values = self.model.inference(
|
||||
input=input[:, -1:],
|
||||
context=context,
|
||||
span=span,
|
||||
length=length,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
# skip all steps when we are done with each sentence
|
||||
|
||||
if all(done):
|
||||
break
|
||||
|
||||
# (batch * beam, seqlen, model_dim)
|
||||
assert logits.size(0) > 1, "nbce needs to ensure that the length of logits 0 is greater than 1"
|
||||
logits = NBCE(logits) # [vocab_size]
|
||||
logits = logits.tile(beam_size, 1)
|
||||
|
||||
if i == 0:
|
||||
logits[:, self.tokenizer.bos_token_id] = -float("inf")
|
||||
apply_repetition_penalty(
|
||||
logits,
|
||||
1,
|
||||
beam_size,
|
||||
input,
|
||||
repetition_penalty,
|
||||
pred_start_index,
|
||||
input.size(-1) - 1,
|
||||
repetition_window,
|
||||
)
|
||||
|
||||
scores = F.log_softmax(logits, dim=-1)
|
||||
|
||||
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * beam_size, vocab_size)
|
||||
|
||||
# re-organize to group the beam together (we are keeping top hypothesis accross beams)
|
||||
next_scores = next_scores.view(1, -1) # (batch_size, beam_size * vocab_size)
|
||||
next_scores, next_words = torch.topk(next_scores, 2 * beam_size, dim=1, largest=True, sorted=True)
|
||||
|
||||
assert next_scores.size() == next_words.size() == (1, 2 * beam_size)
|
||||
next_batch_beam = []
|
||||
|
||||
for sent_id in range(1):
|
||||
# if we are done with this sentence
|
||||
done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item(), i)
|
||||
if done[sent_id]:
|
||||
next_batch_beam.extend([(0, 0, 0)] * beam_size) # pad the batch
|
||||
continue
|
||||
|
||||
# next sentence beam content
|
||||
next_sent_beam = []
|
||||
|
||||
# next words for this sentence
|
||||
for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
|
||||
# get beam and word IDs
|
||||
beam_id = torch.div(idx, scores.size(-1), rounding_mode="floor")
|
||||
word_id = idx % scores.size(-1)
|
||||
|
||||
# end of sentence, or next word
|
||||
if word_id == self.tokenizer.bos_token_id or i == max_length:
|
||||
generated_hyps[sent_id].add(
|
||||
input[sent_id * beam_size + beam_id, pred_start_index:].clone().cpu().tolist(),
|
||||
value.item(),
|
||||
)
|
||||
else:
|
||||
next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))
|
||||
|
||||
# the beam for next step is full
|
||||
if len(next_sent_beam) == beam_size:
|
||||
break
|
||||
|
||||
# update next beam content
|
||||
assert len(next_sent_beam) == 0 if i == max_length else beam_size
|
||||
if len(next_sent_beam) == 0:
|
||||
next_sent_beam = [(0, 0, 0)] * beam_size # pad the batch
|
||||
next_batch_beam.extend(next_sent_beam)
|
||||
assert len(next_batch_beam) == beam_size * (sent_id + 1)
|
||||
|
||||
# we have reached the last step
|
||||
if i == max_length:
|
||||
break
|
||||
|
||||
# sanity check / prepare next batch
|
||||
assert len(next_batch_beam) == batch_size * beam_size
|
||||
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
|
||||
beam_words = input.new([x[1] for x in next_batch_beam])
|
||||
beam_idx = torch.tensor([x[2] for x in next_batch_beam], device=input.device).long()
|
||||
beam_idx *= batch_size
|
||||
# re-order batch and internal states
|
||||
input = input[beam_idx, :]
|
||||
|
||||
past_key_values["buffer"] = [list(each) if each is not None else each for each in past_key_values["buffer"]] # type: ignore # noqa: E501
|
||||
for key_value_layer in past_key_values["buffer"]:
|
||||
if key_value_layer is not None:
|
||||
key_value_layer[0] = key_value_layer[0][beam_idx]
|
||||
key_value_layer[1] = key_value_layer[1][beam_idx]
|
||||
|
||||
input = torch.cat([input, beam_words.unsqueeze(1)], dim=-1)
|
||||
context = torch.cat(
|
||||
[context, torch.zeros((context.size(0), 1), dtype=torch.int16, device=context.device)],
|
||||
dim=-1,
|
||||
)
|
||||
length = past_key_values["buffer_length"]
|
||||
length = torch.cat(
|
||||
[length, torch.ones((length.size(0), 1), dtype=torch.int16, device=length.device)],
|
||||
dim=-1,
|
||||
)
|
||||
span = torch.cat([span, span[:, -1:]], dim=-1)
|
||||
|
||||
# select the best hypotheses
|
||||
results = []
|
||||
for i, hypotheses in enumerate(generated_hyps):
|
||||
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
|
||||
results.append(best_hyp)
|
||||
|
||||
result_text = list(map(self.tokenizer.decode, results))
|
||||
return result_text
|
||||
|
||||
|
||||
class CPM9GRandomSamplingNBCE(CPM9GGeneration):
|
||||
def _decode(
|
||||
self,
|
||||
model_inputs,
|
||||
max_length=100,
|
||||
top_k=0,
|
||||
top_p=0.9,
|
||||
temperature=0.9,
|
||||
repetition_penalty=1.0,
|
||||
repetition_window=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Top-k and top-p sampling.
|
||||
Args:
|
||||
model_inputs (dict): input ids
|
||||
generate_length (int, optional, defaults to 100): maximum generation length
|
||||
top_k (int, optional, defaults to 0): keep only top k tokens with highest probability. 0 means keeping all tokens.
|
||||
top_p (int, optional, defaults to 0.9): keep the top tokens with cumulative probability >= top_p.
|
||||
temperature (int, optional, defaults to 0.9): the value that can cool down the logits distribution.
|
||||
repetition_penalty (float, optional, defaults to 1.0): repetition penalty coefficient, 1.0 means no penalty.
|
||||
repetition_window (int, optional, defaults to None): window size of repetition penalty, None means that all output tokens are penalized.
|
||||
""" # noqa: E501
|
||||
# generate_length + 1 for EOS token
|
||||
max_length += 1
|
||||
|
||||
input = model_inputs["input_ids"]
|
||||
context = model_inputs["context"]
|
||||
|
||||
length = model_inputs["length"].squeeze(1)
|
||||
span = model_inputs["span"]
|
||||
batch_size = input.size(0)
|
||||
|
||||
pred_start_index = input.size(-1)
|
||||
past_key_values = None
|
||||
done = [False]
|
||||
results = [None]
|
||||
for i in range(max_length):
|
||||
if i == 0:
|
||||
logits, _, past_key_values = self.model.inference(
|
||||
input=input,
|
||||
context=context,
|
||||
length=length,
|
||||
span=span,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
else:
|
||||
logits, _, past_key_values = self.model.inference(
|
||||
input=input[:, -1:],
|
||||
context=context,
|
||||
length=length,
|
||||
span=span,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
assert logits.size(0) > 1, "nbce needs to ensure that the length of logits 0 is greater than 1"
|
||||
logits = NBCE(logits) # [vocab_size]
|
||||
logits = logits[None]
|
||||
|
||||
if i == 0:
|
||||
logits[:, self.tokenizer.bos_token_id] = -float("inf")
|
||||
# logits[:, self.tokenizer.newline_id] = -float("inf")
|
||||
|
||||
apply_repetition_penalty(
|
||||
logits,
|
||||
1,
|
||||
1,
|
||||
input,
|
||||
repetition_penalty,
|
||||
pred_start_index,
|
||||
input.size(-1) - 1,
|
||||
repetition_window,
|
||||
)
|
||||
|
||||
logits = logits / temperature
|
||||
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
||||
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
for idx in range(1):
|
||||
if not done[idx] and (next_token[idx].item() == self.tokenizer.bos_token_id or i == max_length - 1):
|
||||
done[idx] = True
|
||||
results[idx] = input[idx, pred_start_index:].clone().cpu().tolist() # type: ignore # noqa: E501
|
||||
|
||||
if sum(done) == 1:
|
||||
break
|
||||
next_token = next_token.tile(batch_size, 1)
|
||||
# update input ids
|
||||
input = torch.cat([input, next_token], dim=-1)
|
||||
length = past_key_values["buffer_length"]
|
||||
length = torch.cat(
|
||||
[length, torch.ones((length.size(0), 1), dtype=torch.int32, device=length.device)],
|
||||
dim=-1,
|
||||
)
|
||||
# length += 1
|
||||
context = torch.cat(
|
||||
[context, torch.zeros((context.size(0), 1), dtype=torch.int16, device=context.device)],
|
||||
dim=-1,
|
||||
)
|
||||
span = torch.cat(
|
||||
[span, torch.zeros((span.size(0), 1), dtype=torch.int32, device=span.device)],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
result_text = list(map(self.tokenizer.decode, results))
|
||||
return result_text
|
|
@ -1,3 +0,0 @@
|
|||
from .cpm9g import CPM9G
|
||||
from .cpm9g import CPM9GConfig
|
||||
from .cpm9g_torch import CPM9GTorch
|
|
@ -1,272 +0,0 @@
|
|||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import bmtrain as bmt
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from ...layers import Embedding
|
||||
from ...layers import Encoder
|
||||
from ...layers import RotaryEmbeddingESM
|
||||
from ...utils import Config
|
||||
from ...utils import gradient_shrink
|
||||
|
||||
|
||||
class CPM9GInferenceState(TypedDict):
|
||||
buffer_context: torch.Tensor
|
||||
buffer_sample_ids: torch.Tensor
|
||||
buffer: List[Tuple[torch.Tensor, torch.Tensor]]
|
||||
|
||||
|
||||
class CPM9GConfig(Config):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
dim_model=4096,
|
||||
num_heads=32,
|
||||
num_kv_heads=32,
|
||||
dim_head=128,
|
||||
dim_ff=11008,
|
||||
num_layers=32,
|
||||
dropout_p=0.0,
|
||||
activate_fn="silu",
|
||||
scale=True,
|
||||
eps=1e-5,
|
||||
half: bool = True,
|
||||
bf16: bool = False,
|
||||
mask_modules: Optional[List[Tuple[bool, bool]]] = None,
|
||||
use_flash_attn: bool = True,
|
||||
flash_attn_mask_shape="1d",
|
||||
flash_impl="cuda",
|
||||
base=10000,
|
||||
tp=0,
|
||||
disabled_checkpoint=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.dim_model = dim_model
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.dim_head = dim_head
|
||||
self.dim_ff = dim_ff
|
||||
self.num_layers = num_layers
|
||||
self.dropout_p = dropout_p
|
||||
self.activate_fn = activate_fn
|
||||
self.scale = scale
|
||||
self.eps = eps
|
||||
if half:
|
||||
if bf16:
|
||||
self.dtype = torch.bfloat16
|
||||
else:
|
||||
self.dtype = torch.half
|
||||
else:
|
||||
self.dtype = torch.float
|
||||
self.flash_impl = flash_impl
|
||||
self.mask_modules = mask_modules
|
||||
self.use_flash_attn = use_flash_attn
|
||||
self.flash_attn_mask_shape = flash_attn_mask_shape
|
||||
self.base = base
|
||||
self.tp = tp
|
||||
self.disabled_checkpoint = disabled_checkpoint
|
||||
|
||||
|
||||
class CPM9G(bmt.DistributedModule):
|
||||
def __init__(self, config: CPM9GConfig):
|
||||
super().__init__()
|
||||
|
||||
self.encoder = Encoder(
|
||||
num_layers=config.num_layers,
|
||||
dim_model=config.dim_model,
|
||||
dim_ff=config.dim_ff,
|
||||
num_heads=config.num_heads,
|
||||
num_kv_heads=config.num_kv_heads,
|
||||
dim_head=config.dim_head,
|
||||
activate_fn=config.activate_fn,
|
||||
dtype=config.dtype,
|
||||
eps=config.eps,
|
||||
dropout_p=config.dropout_p,
|
||||
scale=config.scale,
|
||||
mask_modules=config.mask_modules,
|
||||
use_flash_attn=config.use_flash_attn,
|
||||
tp=config.tp,
|
||||
disabled_checkpoint=config.disabled_checkpoint,
|
||||
)
|
||||
|
||||
self.input_embedding = Embedding(
|
||||
vocab_size=config.vocab_size,
|
||||
embedding_size=config.dim_model,
|
||||
scale=config.scale,
|
||||
dtype=config.dtype,
|
||||
init_std=0.02,
|
||||
)
|
||||
|
||||
self.position_bias = RotaryEmbeddingESM(
|
||||
dim=config.dim_head, dtype=config.dtype, base=config.base, persistent=False, mixed_precision=True
|
||||
)
|
||||
|
||||
self.lm_head = Embedding(
|
||||
vocab_size=config.vocab_size,
|
||||
embedding_size=config.dim_model,
|
||||
scale=config.scale,
|
||||
dtype=config.dtype,
|
||||
init_std=0.02,
|
||||
)
|
||||
self.flash_impl = config.flash_impl
|
||||
self.use_flash_attn = config.use_flash_attn
|
||||
self.flash_attn_mask_shape = config.flash_attn_mask_shape
|
||||
self.config = config
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input: torch.Tensor, # (batch, seqlen) int32
|
||||
length: torch.Tensor = None, # (batch) int32
|
||||
context: torch.Tensor = None, # (batch, seqlen) bool
|
||||
span: torch.Tensor = None, # (batch, seqlen) int32
|
||||
cu_seqlens: torch.Tensor = None, # (real_batch+2) int32
|
||||
max_seqlen: int = None,
|
||||
position_ids: torch.Tensor = None, # (batch, seqlen) int32
|
||||
):
|
||||
batch = input.size(0)
|
||||
seqlen = input.size(1)
|
||||
device = input.device
|
||||
|
||||
if length is not None and length.dim() == 1:
|
||||
length = torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) < length[:, None]
|
||||
|
||||
# processing masks and position bias bucket
|
||||
if not self.use_flash_attn or (self.flash_attn_mask_shape == "2d" and self.flash_impl == "triton"):
|
||||
with torch.no_grad():
|
||||
# directional mask
|
||||
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(
|
||||
-1, 1
|
||||
)
|
||||
# context mask
|
||||
attention_mask = context[:, None, :] | (
|
||||
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
|
||||
)
|
||||
# span mask
|
||||
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
|
||||
# length mask
|
||||
attention_mask = length.view(batch, seqlen, 1) & length.view(batch, 1, seqlen) & attention_mask
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
hidden_states = self.input_embedding(input)
|
||||
|
||||
if self.config.tp:
|
||||
with torch.no_grad():
|
||||
if length is not None:
|
||||
length = bmt.distributed.all_gather(length, comm=bmt.config["tp_comm"]).flatten(0, 1)
|
||||
if attention_mask is not None:
|
||||
attention_mask = bmt.distributed.all_gather(attention_mask, comm=bmt.config["tp_comm"]).flatten(
|
||||
0, 1
|
||||
)
|
||||
if cu_seqlens is not None:
|
||||
lens = bmt.distributed.all_gather(
|
||||
torch.tensor([cu_seqlens.numel()]).cuda(), comm=bmt.config["tp_comm"]
|
||||
)
|
||||
mx = lens.max().item()
|
||||
cu_seq = torch.zeros(mx).int().cuda()
|
||||
cu_seq[: cu_seqlens.numel()] = cu_seqlens
|
||||
all_seq = bmt.distributed.all_gather(cu_seq, comm=bmt.config["tp_comm"])
|
||||
cu_seqlens = [0]
|
||||
for i, l in enumerate(lens):
|
||||
for j in range(1, l - 1):
|
||||
cu_seqlens.append(all_seq[i][j] + i * seqlen)
|
||||
cu_seqlens.append((i + 1) * seqlen)
|
||||
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).cuda()
|
||||
if max_seqlen is not None:
|
||||
max_seqlen = bmt.distributed.all_reduce(
|
||||
torch.tensor([max_seqlen]).cuda(), "max", comm=bmt.config["tp_comm"]
|
||||
)[0].item()
|
||||
if position_ids is not None:
|
||||
position_ids = bmt.distributed.all_gather(position_ids, comm=bmt.config["tp_comm"]).flatten(0, 1)
|
||||
|
||||
if self.use_flash_attn:
|
||||
if self.flash_attn_mask_shape == "1d":
|
||||
hidden_states = self.encoder(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
position_bias=self.position_bias,
|
||||
pos_bias_type="rotary",
|
||||
length_mask=length,
|
||||
)
|
||||
else:
|
||||
if self.flash_impl == "triton":
|
||||
mask = attention_mask.unsqueeze(dim=1).contiguous()
|
||||
attention_mask_bias = torch.zeros_like(mask, device="cuda", dtype=torch.float16)
|
||||
attention_mask_bias[mask == False] -= torch.inf
|
||||
else:
|
||||
attention_mask_bias = None
|
||||
assert cu_seqlens is not None, "cu_seqlens are needed in Flash Attention cuda impl"
|
||||
hidden_states = self.encoder(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
position_bias=self.position_bias,
|
||||
pos_bias_type="rotary",
|
||||
length_mask=None,
|
||||
attention_mask_bias=attention_mask_bias,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
else:
|
||||
hidden_states = self.encoder(
|
||||
hidden_states, attention_mask=attention_mask, position_bias=self.position_bias, pos_bias_type="rotary"
|
||||
)
|
||||
|
||||
logits = self.lm_head.projection(hidden_states)
|
||||
|
||||
return logits, hidden_states
|
||||
|
||||
def inference(
|
||||
self,
|
||||
input: torch.Tensor, # (batch, len_q) int32
|
||||
length: torch.Tensor, # (batch) int32
|
||||
context: torch.Tensor, # (batch, seqlen) int16
|
||||
span: torch.Tensor, # (batch, seqlen) int32
|
||||
past_key_values: Optional[CPM9GInferenceState] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, CPM9GInferenceState]:
|
||||
batch = input.size(0)
|
||||
len_q = input.size(1)
|
||||
len_buffer = 0
|
||||
if past_key_values is None:
|
||||
present_buffer = None
|
||||
else:
|
||||
present_buffer = past_key_values["buffer"]
|
||||
len_buffer = present_buffer[0][0].shape[-2]
|
||||
seqlen = len_buffer + len_q
|
||||
with torch.no_grad():
|
||||
device = input.device
|
||||
if length.dim() == 1:
|
||||
length = torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) < length[:, None]
|
||||
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(-1, 1)
|
||||
# context mask
|
||||
attention_mask = context[:, None, :] | (
|
||||
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
|
||||
)
|
||||
# span mask
|
||||
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
|
||||
# length mask
|
||||
attention_mask = length.view(batch, seqlen, 1) & length.view(batch, 1, seqlen) & attention_mask
|
||||
|
||||
hidden_states = self.input_embedding(input)
|
||||
|
||||
hidden_states, present_key_values, _ = self.encoder(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask[:, len_buffer:],
|
||||
position_bias=self.position_bias,
|
||||
use_cache=True,
|
||||
past_key_values=present_buffer,
|
||||
pos_bias_type="rotary",
|
||||
)
|
||||
|
||||
logits = self.lm_head.projection(hidden_states)
|
||||
|
||||
return (
|
||||
logits,
|
||||
hidden_states,
|
||||
{"buffer": present_key_values},
|
||||
)
|
|
@ -1,186 +0,0 @@
|
|||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from ...native_layers import Embedding
|
||||
from ...native_layers import Encoder
|
||||
from ...native_layers import RotaryEmbeddingESM
|
||||
from ...utils import Config
|
||||
from ...utils import gradient_shrink
|
||||
from .cpm9g import CPM9GConfig
|
||||
|
||||
|
||||
class CPM9GInferenceState(TypedDict):
|
||||
buffer_context: torch.Tensor
|
||||
buffer_sample_ids: torch.Tensor
|
||||
buffer: List[Tuple[torch.Tensor, torch.Tensor]]
|
||||
|
||||
|
||||
class CPM9GTorch(torch.nn.Module):
|
||||
def __init__(self, config: CPM9GConfig):
|
||||
super().__init__()
|
||||
|
||||
self.encoder = Encoder(
|
||||
num_layers=config.num_layers,
|
||||
dim_model=config.dim_model,
|
||||
dim_ff=config.dim_ff,
|
||||
num_heads=config.num_heads,
|
||||
num_kv_heads=config.num_kv_heads,
|
||||
dim_head=config.dim_head,
|
||||
activate_fn=config.activate_fn,
|
||||
dtype=config.dtype,
|
||||
eps=config.eps,
|
||||
dropout_p=config.dropout_p,
|
||||
scale=config.scale,
|
||||
mask_modules=config.mask_modules,
|
||||
use_flash_attn=config.use_flash_attn,
|
||||
)
|
||||
|
||||
self.input_embedding = Embedding(
|
||||
vocab_size=config.vocab_size,
|
||||
embedding_size=config.dim_model,
|
||||
scale=config.scale,
|
||||
dtype=config.dtype,
|
||||
init_std=0.02,
|
||||
)
|
||||
|
||||
self.position_bias = RotaryEmbeddingESM(
|
||||
dim=config.dim_head, dtype=config.dtype, base=config.base, persistent=False, mixed_precision=True
|
||||
)
|
||||
|
||||
self.lm_head = Embedding(
|
||||
vocab_size=config.vocab_size,
|
||||
embedding_size=config.dim_model,
|
||||
scale=config.scale,
|
||||
dtype=config.dtype,
|
||||
init_std=0.02,
|
||||
)
|
||||
self.flash_impl = False
|
||||
self.use_flash_attn = False
|
||||
self.flash_attn_mask_shape = "1d"
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input: torch.Tensor, # (batch, seqlen) int32
|
||||
length: torch.Tensor = None, # (batch) int32
|
||||
context: torch.Tensor = None, # (batch, seqlen) bool
|
||||
span: torch.Tensor = None, # (batch, seqlen) int32
|
||||
cu_seqlens: torch.Tensor = None, # (real_batch+2) int32
|
||||
max_seqlen: int = None,
|
||||
position_ids: torch.Tensor = None, # (batch, seqlen) int32
|
||||
):
|
||||
batch = input.size(0)
|
||||
seqlen = input.size(1)
|
||||
device = input.device
|
||||
|
||||
if length is not None and length.dim() == 1:
|
||||
length = torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) < length[:, None]
|
||||
|
||||
# processing masks and position bias bucket
|
||||
if not self.use_flash_attn or (self.flash_attn_mask_shape == "2d" and self.flash_impl == "triton"):
|
||||
with torch.no_grad():
|
||||
# directional mask
|
||||
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(
|
||||
-1, 1
|
||||
)
|
||||
# context mask
|
||||
attention_mask = context[:, None, :] | (
|
||||
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
|
||||
)
|
||||
# span mask
|
||||
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
|
||||
# length mask
|
||||
attention_mask = length.view(batch, seqlen, 1) & length.view(batch, 1, seqlen) & attention_mask
|
||||
|
||||
hidden_states = self.input_embedding(input)
|
||||
|
||||
if self.use_flash_attn:
|
||||
if self.flash_attn_mask_shape == "1d":
|
||||
hidden_states = self.encoder(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
position_bias=self.position_bias,
|
||||
pos_bias_type="rotary",
|
||||
length_mask=length,
|
||||
)
|
||||
else:
|
||||
if self.flash_impl == "triton":
|
||||
mask = attention_mask.unsqueeze(dim=1).contiguous()
|
||||
attention_mask_bias = torch.zeros_like(mask, device="cuda", dtype=torch.float16)
|
||||
attention_mask_bias[mask == False] -= torch.inf
|
||||
else:
|
||||
attention_mask_bias = None
|
||||
assert cu_seqlens is not None, "cu_seqlens are needed in Flash Attention cuda impl"
|
||||
hidden_states = self.encoder(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
position_bias=self.position_bias,
|
||||
pos_bias_type="rotary",
|
||||
length_mask=None,
|
||||
attention_mask_bias=attention_mask_bias,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
else:
|
||||
hidden_states = self.encoder(
|
||||
hidden_states, attention_mask=attention_mask, position_bias=self.position_bias, pos_bias_type="rotary"
|
||||
)
|
||||
|
||||
logits = self.lm_head.projection(hidden_states)
|
||||
|
||||
return logits, hidden_states
|
||||
|
||||
def inference(
|
||||
self,
|
||||
input: torch.Tensor, # (batch, len_q) int32
|
||||
length: torch.Tensor, # (batch) int32
|
||||
context: torch.Tensor, # (batch, seqlen) int16
|
||||
span: torch.Tensor, # (batch, seqlen) int32
|
||||
past_key_values: Optional[CPM9GInferenceState] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, CPM9GInferenceState]:
|
||||
batch = input.size(0)
|
||||
len_q = input.size(1)
|
||||
len_buffer = 0
|
||||
if past_key_values is None:
|
||||
present_buffer = None
|
||||
else:
|
||||
present_buffer = past_key_values["buffer"]
|
||||
len_buffer = present_buffer[0][0].shape[-2]
|
||||
seqlen = len_buffer + len_q
|
||||
with torch.no_grad():
|
||||
device = input.device
|
||||
if length.dim() == 1:
|
||||
length = (torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) + length[:, None]) >= seqlen
|
||||
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(-1, 1)
|
||||
# context mask
|
||||
attention_mask = context[:, None, :] | (
|
||||
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
|
||||
)
|
||||
# span mask
|
||||
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
|
||||
# length mask
|
||||
attention_mask = length.view(batch, seqlen, 1) & length.view(batch, 1, seqlen) & attention_mask
|
||||
|
||||
hidden_states = self.input_embedding(input)
|
||||
|
||||
hidden_states, present_key_values, _ = self.encoder(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask[:, len_buffer:],
|
||||
position_bias=self.position_bias,
|
||||
use_cache=True,
|
||||
past_key_values=present_buffer,
|
||||
pos_bias_type="rotary",
|
||||
)
|
||||
|
||||
logits = self.lm_head.projection(hidden_states)
|
||||
|
||||
return (
|
||||
logits,
|
||||
hidden_states,
|
||||
{"buffer": present_key_values},
|
||||
)
|
|
@ -1 +0,0 @@
|
|||
from .cpm9g import CPM9GTokenizer
|
|
@ -1,2 +0,0 @@
|
|||
from .pretrain import MixedDataset
|
||||
from .finetune import FinetuneDataset
|
|
@ -1,87 +0,0 @@
|
|||
import bmtrain as bmt
|
||||
|
||||
from ...dataset import SimpleDataset
|
||||
from ..tokenizers import CPM9GTokenizer
|
||||
from .pretrain import _MixedDatasetBatchPacker
|
||||
from .pretrain import _MixedDatasetConfig
|
||||
from .pretrain import CPM9GBatch
|
||||
|
||||
|
||||
class FinetuneDataset:
|
||||
def __init__(
|
||||
self,
|
||||
dataset_path: str,
|
||||
batch_size: int,
|
||||
max_length: int,
|
||||
tokenizer: CPM9GTokenizer,
|
||||
unpad: bool = False,
|
||||
task_name: str = "task",
|
||||
drop_last: bool = False,
|
||||
) -> None:
|
||||
self._world_size = bmt.world_size()
|
||||
self._rank = bmt.rank()
|
||||
self._batch_size = batch_size
|
||||
self._unpad = unpad
|
||||
self._max_length = max_length
|
||||
|
||||
self._packer = _MixedDatasetBatchPacker(batch_size * self._world_size, max_length, tokenizer, unpad)
|
||||
self._drop_last = drop_last
|
||||
|
||||
ds = SimpleDataset(dataset_path, shuffle=False)
|
||||
self._ds_cfg: _MixedDatasetConfig = {
|
||||
"weight": 1.0,
|
||||
"path": dataset_path,
|
||||
"transforms": [],
|
||||
"task_name": task_name,
|
||||
"dataset_name": "finetune",
|
||||
"incontext_weight": [1.0],
|
||||
"lines": len(ds),
|
||||
"dataset": ds,
|
||||
}
|
||||
self._nlines = len(ds)
|
||||
self._nbytes = ds.get_bytes()
|
||||
|
||||
def __batch_iter(self):
|
||||
while True:
|
||||
try:
|
||||
batch = self._packer.add_data(self._ds_cfg)
|
||||
except EOFError:
|
||||
break
|
||||
if batch is None:
|
||||
continue
|
||||
yield batch
|
||||
if len(self._packer) > 0:
|
||||
batch = self._packer.pack_batch(force=True)
|
||||
if not self._drop_last:
|
||||
yield batch
|
||||
self._ds_cfg["dataset"]._repeat_times = 0
|
||||
|
||||
def __iter__(self):
|
||||
if not self._unpad:
|
||||
batch_st = self._batch_size * self._rank
|
||||
batch_end = self._batch_size * (self._rank + 1)
|
||||
for batch in self.__batch_iter():
|
||||
batch_size = batch["inputs"].shape[0]
|
||||
if batch_size <= batch_st:
|
||||
yield None
|
||||
else:
|
||||
ret: CPM9GBatch = {
|
||||
kw: val[batch_st:batch_end] # type: ignore
|
||||
for kw, val in batch.items()
|
||||
if kw not in ["task_names", "raw_data", "cu_seqlens", "max_seqlen"]
|
||||
} # type: ignore
|
||||
ret["task_names"] = batch["task_names"]
|
||||
ret["raw_data"] = batch["raw_data"]
|
||||
ret["cu_seqlens"] = batch["cu_seqlens"]
|
||||
ret["max_seqlen"] = batch["max_seqlen"]
|
||||
yield ret
|
||||
else:
|
||||
for batch in self.__batch_iter():
|
||||
assert batch["inputs"].shape[0] == 1
|
||||
yield batch
|
||||
|
||||
def __len__(self):
|
||||
iter_count = int(self._nbytes // (3.6 * self._batch_size * self._world_size * self._max_length))
|
||||
if not self._drop_last:
|
||||
iter_count += 1
|
||||
return iter_count
|
|
@ -1,736 +0,0 @@
|
|||
import importlib.machinery
|
||||
import importlib.util
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import types
|
||||
from collections import OrderedDict
|
||||
from queue import Empty
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Set
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import bmtrain as bmt
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from ...dataset import DistributedDataset
|
||||
from ..tokenizers import CPM9GTokenizer
|
||||
|
||||
|
||||
class _MixedDatasetConfig(TypedDict):
|
||||
weight: float
|
||||
path: str
|
||||
transforms: Union[List[Dict[str, Any]], str]
|
||||
task_name: str
|
||||
dataset_name: str
|
||||
incontext_weight: List[float]
|
||||
lines: int
|
||||
dataset: DistributedDataset
|
||||
|
||||
|
||||
CPM9GInputType = Union[str, Dict[str, "CPM9GInputType"]]
|
||||
|
||||
|
||||
class _TransformFuncDict(TypedDict):
|
||||
loader: importlib.machinery.SourceFileLoader
|
||||
module: types.ModuleType
|
||||
last_m: float
|
||||
|
||||
|
||||
_TransformFunction = Callable[[CPM9GInputType, int, random.Random], CPM9GInputType]
|
||||
|
||||
|
||||
class CPM9GBatch(TypedDict):
|
||||
inputs: NDArray[np.int32]
|
||||
length: NDArray[np.int32]
|
||||
context: NDArray[np.bool_]
|
||||
sample_ids: NDArray[np.int32]
|
||||
spans: NDArray[np.int32]
|
||||
target: NDArray[np.int32]
|
||||
task_ids: NDArray[np.int32]
|
||||
task_names: List[str]
|
||||
raw_data: List[Any]
|
||||
|
||||
|
||||
def convert_data_to_id(tokenizer: CPM9GTokenizer, data: Any):
|
||||
input_ids = tokenizer.encode(data["input"])
|
||||
output_ids = tokenizer.encode(data["output"])
|
||||
ids = [tokenizer.bos_id] + input_ids + output_ids + [tokenizer.eos_id]
|
||||
ids = np.array(ids, dtype=np.int32)
|
||||
context = np.zeros((ids.shape[0],), dtype=np.int8)
|
||||
context[: len(input_ids) + 1] = 1
|
||||
return ids, context
|
||||
|
||||
|
||||
def _dataset_identity(c: _MixedDatasetConfig):
|
||||
return "{}.{}".format(c["task_name"], c["dataset_name"])
|
||||
|
||||
|
||||
class _MixedDatasetBatchPacker:
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int,
|
||||
max_length: int,
|
||||
tokenizer: CPM9GTokenizer,
|
||||
unpad: bool = False,
|
||||
) -> None:
|
||||
self._batch_size = batch_size
|
||||
self._max_length = max_length
|
||||
self._clip_length = max_length
|
||||
self.tokenizer = tokenizer
|
||||
self._transform_func_table: Dict[str, _TransformFuncDict] = {}
|
||||
self._unpad = unpad
|
||||
if unpad:
|
||||
self._max_length = max_length * batch_size
|
||||
self._batch_size = 1
|
||||
|
||||
self._inputs: List[NDArray[np.int32]] = []
|
||||
self._context: List[NDArray[np.int8]] = []
|
||||
self._sample_ids: List[NDArray[np.int32]] = []
|
||||
self._spans: List[List[int]] = []
|
||||
self._task_ids: List[List[str]] = []
|
||||
self._raw_data: List[List[Any]] = []
|
||||
|
||||
def __len__(self):
|
||||
return len(self._inputs)
|
||||
|
||||
def apply_transform(
|
||||
self,
|
||||
data: CPM9GInputType,
|
||||
transform: Union[Dict[str, Any], Callable[[CPM9GInputType], CPM9GInputType], None],
|
||||
) -> CPM9GInputType:
|
||||
if transform is None:
|
||||
return data
|
||||
if not isinstance(transform, dict):
|
||||
return transform(data) # transform function
|
||||
|
||||
mapping_list: List[Tuple[str, str]] = []
|
||||
|
||||
def _walk_transform_dict(data: Union[Dict[str, Any], str], prefix: str = ""):
|
||||
if isinstance(data, dict):
|
||||
for k, v in data.items():
|
||||
if len(prefix) > 0:
|
||||
_walk_transform_dict(v, prefix + "." + k)
|
||||
else:
|
||||
_walk_transform_dict(v, k)
|
||||
else:
|
||||
assert isinstance(data, str), "Invalid transform {}".format(data)
|
||||
mapping_list.append((prefix, data))
|
||||
|
||||
_walk_transform_dict(transform)
|
||||
|
||||
expanded_mapping_list: List[Tuple[str, Any]] = []
|
||||
|
||||
def _expand_mapping(data: CPM9GInputType, stars: List[str], path: List[str], target: List[str]):
|
||||
if len(path) == 0:
|
||||
num_stars = 0
|
||||
for it in target:
|
||||
if it == "*":
|
||||
num_stars += 1
|
||||
if num_stars != len(stars):
|
||||
raise ValueError("Invalid transform {}".format(".".join(target)))
|
||||
|
||||
nw_tgt = []
|
||||
num_stars = 0
|
||||
for it in target:
|
||||
if it == "*":
|
||||
nw_tgt.append(stars[num_stars])
|
||||
num_stars += 1
|
||||
else:
|
||||
nw_tgt.append(it)
|
||||
expanded_mapping_list.append((".".join(nw_tgt), data))
|
||||
else:
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("Invalid data {}".format(data))
|
||||
if path[0] == "*":
|
||||
for k, v in data.items():
|
||||
_expand_mapping(v, stars + [k], path[1:], target)
|
||||
else:
|
||||
_expand_mapping(data[path[0]], stars, path[1:], target)
|
||||
|
||||
# expand mapping list
|
||||
for tgt, src in mapping_list:
|
||||
if src.startswith("$"):
|
||||
# copy from src
|
||||
_expand_mapping(data, [], src[1:].split("."), tgt.split("."))
|
||||
else:
|
||||
if "*" in tgt:
|
||||
raise ValueError("Constant value is not allowed to have `*` in prefix")
|
||||
expanded_mapping_list.append((tgt, src))
|
||||
|
||||
ret = {}
|
||||
for tgt, val in expanded_mapping_list:
|
||||
tgt = tgt.split(".")
|
||||
cur = ret
|
||||
while len(tgt) > 1:
|
||||
cur = cur[tgt[0]]
|
||||
tgt = tgt[1:]
|
||||
cur[tgt[0]] = val
|
||||
return ret
|
||||
|
||||
def data_to_id(self, data: Any):
|
||||
return convert_data_to_id(self.tokenizer, data)
|
||||
|
||||
def _ensure_transform_function(self, module_name: str, transform_script_path: str) -> _TransformFunction:
|
||||
module_name = "cpm_live.transforms.{}".format(module_name)
|
||||
if transform_script_path not in self._transform_func_table:
|
||||
loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
|
||||
spec = importlib.util.spec_from_loader(loader.name, loader)
|
||||
if spec is None:
|
||||
raise RuntimeError("spec is none! {}".format(module_name))
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
self._transform_func_table[transform_script_path] = {
|
||||
"loader": loader,
|
||||
"module": mod,
|
||||
"last_m": 0,
|
||||
}
|
||||
|
||||
transform_script_info = self._transform_func_table[transform_script_path]
|
||||
curr_m_time = float(transform_script_info["loader"].path_stats(transform_script_path)["mtime"])
|
||||
if curr_m_time > transform_script_info["last_m"]:
|
||||
transform_script_info["last_m"] = curr_m_time
|
||||
transform_script_info["loader"].exec_module(transform_script_info["module"])
|
||||
transform_func = getattr(transform_script_info["module"], "transform", None)
|
||||
if transform_func is None:
|
||||
|
||||
def _empty_transform_func(data: CPM9GInputType, num_sample: int, r: random.Random):
|
||||
raise NotImplementedError("Transform func for dataset {} not implemented".format(module_name))
|
||||
|
||||
return _empty_transform_func
|
||||
else:
|
||||
return transform_func
|
||||
|
||||
def build_instance(self, config: _MixedDatasetConfig):
|
||||
_sample_weight = np.array(config["incontext_weight"], dtype=np.float32)
|
||||
_sample_weight = _sample_weight / _sample_weight.sum()
|
||||
num_incontext = np.random.choice(_sample_weight.shape[0], p=_sample_weight)
|
||||
ds = config["dataset"]
|
||||
transforms = config["transforms"]
|
||||
if isinstance(transforms, str):
|
||||
if not os.path.exists(transforms):
|
||||
raise RuntimeError("transform script {} file not exists".format(transforms))
|
||||
# load transform script
|
||||
transform_func = self._ensure_transform_function(_dataset_identity(config), transforms)
|
||||
seed = random.random()
|
||||
|
||||
def _transform(data: CPM9GInputType):
|
||||
r = random.Random(seed)
|
||||
return transform_func(data, num_incontext, r)
|
||||
|
||||
transform = _transform
|
||||
elif len(transforms) == 0:
|
||||
transform = None
|
||||
else:
|
||||
transform = transforms[np.random.choice(len(transforms))]
|
||||
|
||||
raw_data = {}
|
||||
while True:
|
||||
inp = ds.read()
|
||||
inp = self.apply_transform(inp, transform)
|
||||
# if None, skip this one
|
||||
if inp is None:
|
||||
continue
|
||||
input_ids, context = self.data_to_id(inp)
|
||||
if input_ids.shape[0] > self._clip_length:
|
||||
continue # too long
|
||||
input_ids = input_ids[: self._clip_length]
|
||||
context = context[: self._clip_length]
|
||||
raw_data["input"] = inp
|
||||
raw_data["samples"] = []
|
||||
break
|
||||
|
||||
sample_ids = np.zeros(input_ids.shape, dtype=np.int32)
|
||||
i = 0
|
||||
while True:
|
||||
if i == num_incontext or input_ids.shape[0] >= self._clip_length:
|
||||
break # early break
|
||||
sample = ds.read()
|
||||
sample = self.apply_transform(sample, transform)
|
||||
# if sample is None, skip this one
|
||||
if sample is None:
|
||||
continue
|
||||
sample_input_ids, _ = self.data_to_id(sample)
|
||||
if input_ids.shape[0] + sample_input_ids.shape[0] > self._clip_length:
|
||||
break # too long, break
|
||||
raw_data["samples"].append(sample)
|
||||
input_ids = np.concatenate([input_ids, sample_input_ids], axis=0)
|
||||
context = np.concatenate([context, np.ones(sample_input_ids.shape, dtype=np.int8)], axis=0)
|
||||
sample_ids = np.concatenate([sample_ids, np.full(sample_input_ids.shape, i + 1, dtype=np.int32)], axis=0)
|
||||
|
||||
i += 1
|
||||
|
||||
return (
|
||||
input_ids,
|
||||
context,
|
||||
sample_ids,
|
||||
raw_data,
|
||||
)
|
||||
|
||||
def pack_batch(self, force: bool = False, unilm=False) -> CPM9GBatch:
|
||||
# pack batch
|
||||
if len(self._inputs) < self._batch_size:
|
||||
if not force:
|
||||
raise RuntimeError("Batch insufficient")
|
||||
batch_size = len(self._inputs)
|
||||
else:
|
||||
batch_size = self._batch_size
|
||||
max_length = self._max_length # self._spans[0][-1] if self._unpad else self._max_length
|
||||
inputs = np.zeros((batch_size, max_length), dtype=np.int32)
|
||||
context = np.zeros((batch_size, max_length), dtype=np.int8)
|
||||
sample_ids = np.zeros((batch_size, max_length), dtype=np.int32)
|
||||
tgt = np.full((batch_size, max_length), -100, dtype=np.int32)
|
||||
|
||||
spans = np.zeros((batch_size, max_length), dtype=np.int32)
|
||||
length = np.zeros((batch_size,), dtype=np.int32)
|
||||
task_ids = np.full((batch_size, max_length), -1, dtype=np.int32)
|
||||
if self._spans[0][-1] != max_length:
|
||||
cu_seqlens = np.array([0] + self._spans[0] + [max_length], dtype=np.int32)
|
||||
else:
|
||||
cu_seqlens = np.array([0] + self._spans[0], dtype=np.int32)
|
||||
max_seqlen = int(np.max(cu_seqlens[1:] - cu_seqlens[:-1]))
|
||||
position_ids = np.zeros((batch_size, max_length), dtype=np.int32)
|
||||
|
||||
all_task_names: Set[str] = set()
|
||||
for i in range(batch_size):
|
||||
for task_name in self._task_ids[i]:
|
||||
all_task_names.add(task_name)
|
||||
task_names: List[str] = list(all_task_names)
|
||||
task_name_to_id = {name: i for i, name in enumerate(task_names)}
|
||||
|
||||
raw_data_list: List[Any] = []
|
||||
for i in range(batch_size):
|
||||
instance_length = self._inputs[i].shape[0]
|
||||
inputs[i, :instance_length] = self._inputs[i]
|
||||
sample_ids[i, :instance_length] = self._sample_ids[i]
|
||||
if unilm:
|
||||
context[i, :instance_length] = self._context[i]
|
||||
|
||||
span_begin = 0
|
||||
for span_id, (span_end, task_name) in enumerate(zip(self._spans[i], self._task_ids[i])):
|
||||
spans[i, span_begin:span_end] = span_id
|
||||
position_ids[i, span_begin:span_end] = np.arange(span_end - span_begin)
|
||||
task_ids[i, span_begin:span_end] = task_name_to_id[task_name]
|
||||
span_begin = span_end
|
||||
length[i] = instance_length
|
||||
raw_data_list.extend(self._raw_data[i])
|
||||
|
||||
for j in range(instance_length):
|
||||
if j > 1 and self._context[i][j] == 0:
|
||||
if self._inputs[i][j] != self.tokenizer.bos_id and self._inputs[i][j - 1] != self.tokenizer.eos_id:
|
||||
tgt[i, j - 1] = self._inputs[i][j]
|
||||
|
||||
self._inputs = self._inputs[batch_size:]
|
||||
self._context = self._context[batch_size:]
|
||||
self._sample_ids = self._sample_ids[batch_size:]
|
||||
self._spans = self._spans[batch_size:]
|
||||
self._task_ids = self._task_ids[batch_size:]
|
||||
self._raw_data = self._raw_data[batch_size:]
|
||||
return {
|
||||
"inputs": inputs,
|
||||
"length": length,
|
||||
"context": context > 0,
|
||||
"sample_ids": sample_ids,
|
||||
"spans": spans,
|
||||
"cu_seqlens": cu_seqlens,
|
||||
"max_seqlen": max_seqlen,
|
||||
"position_ids": position_ids,
|
||||
"target": tgt,
|
||||
"task_ids": task_ids,
|
||||
"task_names": task_names,
|
||||
"raw_data": raw_data_list,
|
||||
}
|
||||
|
||||
def add_data(self, config: _MixedDatasetConfig) -> Optional[CPM9GBatch]:
|
||||
(
|
||||
input_ids,
|
||||
context,
|
||||
sample_ids,
|
||||
raw_data,
|
||||
) = self.build_instance(config)
|
||||
|
||||
# add to batch
|
||||
best_fit: Union[None, int] = None
|
||||
best_fit_space: Union[None, int] = None
|
||||
for i in range(len(self._inputs)):
|
||||
space = self._max_length - self._inputs[i].shape[0]
|
||||
if input_ids.shape[0] <= space:
|
||||
if best_fit_space is None:
|
||||
best_fit = i
|
||||
best_fit_space = space
|
||||
elif best_fit_space > space:
|
||||
best_fit = i
|
||||
best_fit_space = space
|
||||
if best_fit is None:
|
||||
# add a new instance
|
||||
self._inputs.append(input_ids)
|
||||
self._context.append(context)
|
||||
self._sample_ids.append(sample_ids)
|
||||
self._spans.append([input_ids.shape[0]])
|
||||
self._task_ids.append([config["task_name"]])
|
||||
self._raw_data.append([raw_data])
|
||||
else:
|
||||
# add to existing instance
|
||||
self._inputs[best_fit] = np.concatenate([self._inputs[best_fit], input_ids], axis=0)
|
||||
self._context[best_fit] = np.concatenate([self._context[best_fit], context], axis=0)
|
||||
self._sample_ids[best_fit] = np.concatenate([self._sample_ids[best_fit], sample_ids], axis=0)
|
||||
self._spans[best_fit].append(self._inputs[best_fit].shape[0])
|
||||
self._task_ids[best_fit].append(config["task_name"])
|
||||
self._raw_data[best_fit].append(raw_data)
|
||||
|
||||
if len(self._inputs) > self._batch_size:
|
||||
return self.pack_batch()
|
||||
else:
|
||||
return None # not ready
|
||||
|
||||
|
||||
class _MixedDatasetConfigMananger:
|
||||
def __init__(self, config_path: str) -> None:
|
||||
self._config_path: str = config_path
|
||||
self._config: Union[List[_MixedDatasetConfig], None] = None
|
||||
self._last_m = 0
|
||||
|
||||
def changed(self):
|
||||
while True:
|
||||
try:
|
||||
m_time = os.stat(self._config_path).st_mtime
|
||||
if m_time > self._last_m:
|
||||
# try to load new config
|
||||
try:
|
||||
self._config = json.load(open(self._config_path, "r", encoding="utf-8"))
|
||||
except Exception:
|
||||
# failed to load config
|
||||
return False
|
||||
# new config loaded
|
||||
self._last_m = m_time
|
||||
return True
|
||||
return False
|
||||
except Exception:
|
||||
print("Error: reading info list! {}".format(self._config_path))
|
||||
time.sleep(30)
|
||||
|
||||
def get_config(self) -> List[_MixedDatasetConfig]:
|
||||
if self._config is None:
|
||||
if not self.changed():
|
||||
raise RuntimeError("Failed to load config")
|
||||
if self._config is None:
|
||||
raise RuntimeError("Failed to load config")
|
||||
return self._config
|
||||
|
||||
|
||||
def _mixed_dataset_process(
|
||||
config_path: str,
|
||||
q_cmd: multiprocessing.Queue,
|
||||
q_cmd_out: multiprocessing.Queue,
|
||||
q_data: multiprocessing.Queue,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
packer: _MixedDatasetBatchPacker,
|
||||
max_repeat_times: int = None,
|
||||
random_state=None,
|
||||
):
|
||||
for _ in range(rank):
|
||||
np.random.random()
|
||||
random.setstate(random_state)
|
||||
# ignore SIGINT
|
||||
import signal
|
||||
|
||||
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||
config_base_path = os.path.dirname(os.path.abspath(config_path))
|
||||
|
||||
# fix libgomp threading deadlock after fork(), use single-thread in data process.
|
||||
# REF: https://github.com/pytorch/pytorch/issues/17199
|
||||
torch.set_num_threads(1)
|
||||
|
||||
def _convert_to_abs_path(transform_path: str):
|
||||
if transform_path.startswith("/"):
|
||||
return transform_path
|
||||
else:
|
||||
return os.path.join(config_base_path, transform_path)
|
||||
|
||||
def _build_sample_weights(config: List[_MixedDatasetConfig]):
|
||||
if len(config) == 0:
|
||||
return np.array([], dtype=np.float32)
|
||||
weights = [c["weight"] * c["lines"] for c in config]
|
||||
weights = np.array(weights, dtype=np.float32)
|
||||
sm_weight = weights.sum()
|
||||
if sm_weight > 0:
|
||||
weights = weights / sm_weight
|
||||
return weights
|
||||
else:
|
||||
raise RuntimeError("Empty datasets")
|
||||
|
||||
cfg_mgr = _MixedDatasetConfigMananger(config_path)
|
||||
config = cfg_mgr.get_config()
|
||||
|
||||
for c in config:
|
||||
ds = DistributedDataset(
|
||||
_convert_to_abs_path(c["path"]),
|
||||
rank,
|
||||
world_size,
|
||||
max_repeat_times=max_repeat_times,
|
||||
)
|
||||
|
||||
c["lines"] = ds._nlines
|
||||
c["dataset"] = ds
|
||||
if "weight" not in c:
|
||||
c["weight"] = 1.0
|
||||
if "transforms" not in c:
|
||||
c["transforms"] = []
|
||||
elif isinstance(c["transforms"], str):
|
||||
c["transforms"] = _convert_to_abs_path(c["transforms"])
|
||||
if "incontext_weight" not in c:
|
||||
c["incontext_weight"] = [1.0]
|
||||
|
||||
weights = _build_sample_weights(config)
|
||||
|
||||
should_stop = False
|
||||
should_start = False
|
||||
|
||||
while not should_stop:
|
||||
# update config first
|
||||
if cfg_mgr.changed():
|
||||
path_ds_map: Dict[str, _MixedDatasetConfig] = {}
|
||||
nw_path_set: Set[str] = set()
|
||||
|
||||
# load new config
|
||||
nw_config = cfg_mgr.get_config()
|
||||
|
||||
# build path -> dataset map
|
||||
for c in config:
|
||||
path_ds_map[_dataset_identity(c)] = c
|
||||
|
||||
# add new datasets
|
||||
for c in nw_config:
|
||||
if _dataset_identity(c) in path_ds_map:
|
||||
# update values only
|
||||
if "weight" in c:
|
||||
path_ds_map[_dataset_identity(c)]["weight"] = c["weight"]
|
||||
if "transform" in c:
|
||||
if isinstance(c["transforms"], str):
|
||||
path_ds_map[_dataset_identity(c)]["transforms"] = _convert_to_abs_path(c["transforms"])
|
||||
else:
|
||||
path_ds_map[_dataset_identity(c)]["transforms"] = c["transforms"]
|
||||
if "incontext_weight" in c:
|
||||
path_ds_map[_dataset_identity(c)]["incontext_weight"] = c["incontext_weight"]
|
||||
else:
|
||||
# new dataset
|
||||
ds = DistributedDataset(
|
||||
_convert_to_abs_path(c["path"]),
|
||||
rank,
|
||||
world_size,
|
||||
max_repeat_times=max_repeat_times,
|
||||
)
|
||||
c["lines"] = ds._nlines
|
||||
c["dataset"] = ds
|
||||
if "weight" not in c:
|
||||
c["weight"] = 1.0
|
||||
if "transforms" not in c:
|
||||
c["transforms"] = []
|
||||
elif isinstance(c["transforms"], str):
|
||||
c["transforms"] = _convert_to_abs_path(c["transforms"])
|
||||
if "incontext_weight" not in c:
|
||||
c["incontext_weight"] = [1.0]
|
||||
path_ds_map[_dataset_identity(c)] = c
|
||||
nw_path_set.add(_dataset_identity(c))
|
||||
|
||||
# remove unused datasets
|
||||
for c in config:
|
||||
if _dataset_identity(c) not in nw_path_set:
|
||||
del path_ds_map[_dataset_identity(c)]
|
||||
|
||||
config: List[_MixedDatasetConfig] = []
|
||||
for c in nw_config:
|
||||
config.append(path_ds_map[_dataset_identity(c)])
|
||||
del path_ds_map
|
||||
del nw_path_set
|
||||
del nw_config
|
||||
|
||||
weights = _build_sample_weights(config)
|
||||
|
||||
# get cmds
|
||||
while True:
|
||||
try:
|
||||
cmd = q_cmd.get_nowait()
|
||||
except Empty:
|
||||
break
|
||||
if cmd == "stop":
|
||||
should_stop = True
|
||||
q_cmd_out.put(True)
|
||||
break
|
||||
elif cmd == "state_dict":
|
||||
ret = OrderedDict()
|
||||
for c in config:
|
||||
ds_name = _dataset_identity(c)
|
||||
ret[ds_name] = c["dataset"]._state_dict()
|
||||
q_cmd_out.put(ret)
|
||||
elif cmd == "load_state_dict":
|
||||
state_dict = q_cmd.get()
|
||||
missing = []
|
||||
for idx, c in enumerate(config):
|
||||
ds_name = _dataset_identity(c)
|
||||
print(f"loading {idx}/{len(config)} {ds_name}")
|
||||
if ds_name in state_dict:
|
||||
c["dataset"].load_state_dict(state_dict[ds_name], strict=False)
|
||||
else:
|
||||
# new dataset
|
||||
missing.append(ds_name)
|
||||
q_cmd_out.put(missing)
|
||||
elif cmd == "start":
|
||||
should_start = True
|
||||
q_cmd_out.put(True)
|
||||
else:
|
||||
raise RuntimeError("Unknown command: {}".format(cmd))
|
||||
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
if not should_start:
|
||||
# wait for start cmd
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
||||
if len(config) == 0:
|
||||
# no dataset available
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
||||
if q_data.full():
|
||||
# queue full
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
||||
# sample a dataset
|
||||
ds_id: int = 0
|
||||
|
||||
while True:
|
||||
ds_id = np.random.choice(weights.shape[0], p=weights)
|
||||
if config[ds_id]["dataset"]._nlines != config[ds_id]["lines"]:
|
||||
# dataset size changed
|
||||
for c in config:
|
||||
c["lines"] = c["dataset"]._nlines
|
||||
weights = _build_sample_weights(config)
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
batch = packer.add_data(config[ds_id])
|
||||
if batch is not None:
|
||||
# new batch comming
|
||||
q_data.put(batch)
|
||||
|
||||
# clean queue
|
||||
while True:
|
||||
try:
|
||||
q_data.get_nowait()
|
||||
except Empty:
|
||||
break
|
||||
|
||||
|
||||
class MixedDataset:
|
||||
def __init__(
|
||||
self,
|
||||
config_path: str,
|
||||
batch_size: int,
|
||||
max_length: int,
|
||||
tokenizer: CPM9GTokenizer,
|
||||
unpad: bool = False,
|
||||
max_repeat_times: int = None,
|
||||
) -> None:
|
||||
self._q_cmd = multiprocessing.Queue()
|
||||
self._q_cmd_out = multiprocessing.Queue()
|
||||
self._q_data = multiprocessing.Queue(maxsize=2)
|
||||
self._packer = _MixedDatasetBatchPacker(batch_size, max_length, tokenizer, unpad)
|
||||
self._p = multiprocessing.Process(
|
||||
target=_mixed_dataset_process,
|
||||
args=(
|
||||
config_path,
|
||||
self._q_cmd,
|
||||
self._q_cmd_out,
|
||||
self._q_data,
|
||||
bmt.rank(),
|
||||
bmt.world_size(),
|
||||
self._packer,
|
||||
max_repeat_times,
|
||||
random.getstate(),
|
||||
),
|
||||
)
|
||||
self._p.start()
|
||||
self._closed = False
|
||||
|
||||
def close(self):
|
||||
if not self._closed:
|
||||
self._closed = True
|
||||
self._q_cmd.put("stop")
|
||||
assert self._q_cmd_out.get(), "Failed to stop process"
|
||||
self._p.join()
|
||||
|
||||
@property
|
||||
def closed(self):
|
||||
return self._closed
|
||||
|
||||
def start(self):
|
||||
self._q_cmd.put("start")
|
||||
return self._q_cmd_out.get()
|
||||
|
||||
def state_dict(self):
|
||||
self._q_cmd.put("state_dict")
|
||||
states = self._q_cmd_out.get()
|
||||
if not isinstance(states, OrderedDict):
|
||||
raise RuntimeError("Invalid state dict {}".format(states))
|
||||
if bmt.world_size() == 1:
|
||||
for val in states.values():
|
||||
val["states"].unsqueeze_(0)
|
||||
val["block"].unsqueeze_(0)
|
||||
return states
|
||||
|
||||
ret = OrderedDict()
|
||||
for k, v in states.items():
|
||||
num_unused_block = v["states"].size(0)
|
||||
gpu_num_unused_block = torch.tensor([num_unused_block], dtype=torch.long).cuda()
|
||||
max_unused_blocks = bmt.distributed.all_reduce(gpu_num_unused_block, op="max").cpu().item()
|
||||
if max_unused_blocks == 0:
|
||||
max_unused_blocks = 1
|
||||
gpu_states = torch.full((max_unused_blocks,), -1, dtype=torch.long).cuda()
|
||||
gpu_states[:num_unused_block] = v["states"].cuda()
|
||||
|
||||
gpu_block = v["block"].cuda()
|
||||
global_states = bmt.distributed.all_gather(gpu_states).cpu() # (world_size, max_unused_blocks)
|
||||
global_block = bmt.distributed.all_gather(gpu_block).cpu() # (world_size, 4)
|
||||
ret[k] = {"states": global_states, "block": global_block}
|
||||
return ret
|
||||
|
||||
def load_state_dict(self, data: OrderedDict, strict: bool = False):
|
||||
self._q_cmd.put("load_state_dict")
|
||||
self._q_cmd.put(data)
|
||||
missing = self._q_cmd_out.get()
|
||||
if strict:
|
||||
if len(missing) > 0:
|
||||
raise RuntimeError("Missing dataset state: {}".format(missing))
|
||||
return missing
|
||||
|
||||
def get(self) -> CPM9GBatch:
|
||||
ret: CPM9GBatch = self._q_data.get(timeout=300) # type: ignore
|
||||
if not isinstance(ret, dict):
|
||||
raise RuntimeError("Invalid data {}".format(ret))
|
||||
return ret
|
||||
|
||||
def __iter__(self):
|
||||
while True:
|
||||
yield self.get()
|
||||
|
||||
def __del__(self):
|
||||
if not self.closed:
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
File diff suppressed because it is too large
Load Diff
|
@ -1,57 +0,0 @@
|
|||
import itertools
|
||||
import os
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
import bmtrain as bmt
|
||||
import torch
|
||||
|
||||
|
||||
class ListDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
同时支持 map-style 和 iterable-style
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, data_list: List, distributed: bool = False, shuffle: bool = True, infinite: bool = False
|
||||
) -> None:
|
||||
super(ListDataset, self).__init__()
|
||||
if distributed:
|
||||
rank = bmt.rank()
|
||||
world_size = bmt.world_size()
|
||||
self.data_list = list(itertools.islice(data_list, rank, None, world_size))
|
||||
else:
|
||||
self.data_list = data_list
|
||||
self.shuffle = shuffle
|
||||
self.infinite = infinite
|
||||
self.idx = 0
|
||||
|
||||
if shuffle:
|
||||
self._shuffle()
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.idx >= len(self):
|
||||
if self.infinite:
|
||||
if self.shuffle:
|
||||
self._shuffle()
|
||||
self.idx = 0
|
||||
else:
|
||||
raise StopIteration
|
||||
data = self.data_list[self.idx]
|
||||
self.idx += 1
|
||||
return data
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.data_list[idx]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_list)
|
||||
|
||||
def _shuffle(self):
|
||||
random.shuffle(self.data_list)
|
||||
|
||||
def read(self):
|
||||
return self.__next__()
|
|
@ -1,4 +0,0 @@
|
|||
from .generation_utils import apply_repetition_penalty
|
||||
from .generation_utils import BeamHypotheses
|
||||
from .generation_utils import NBCE
|
||||
from .generation_utils import top_k_top_p_filtering
|
|
@ -1,127 +0,0 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
|
||||
# This function has been mostly taken from huggingface conversational ai code at
|
||||
# https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
|
||||
|
||||
if top_k > 0:
|
||||
# Remove all tokens with a probability less than the last token of the top-k
|
||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
||||
logits[indices_to_remove] = filter_value
|
||||
|
||||
batch_size = logits.size()[0]
|
||||
if top_p > 0.0:
|
||||
logits = logits.view(batch_size, -1).contiguous()
|
||||
for index in range(len(logits)):
|
||||
|
||||
sorted_logits, sorted_indices = torch.sort(logits[index].view(-1), descending=True)
|
||||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
# Remove tokens with cumulative probability above the threshold
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
# Shift the indices to the right to keep also the first token above the threshold
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
||||
logits[index][indices_to_remove] = filter_value
|
||||
|
||||
logits = logits.view(batch_size, -1).contiguous()
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def apply_repetition_penalty(
|
||||
logits,
|
||||
batch_size,
|
||||
num_beams,
|
||||
prev_output_tokens,
|
||||
repetition_penalty,
|
||||
start_idx=None,
|
||||
end_idx=None,
|
||||
window_size=None,
|
||||
):
|
||||
# only conduct repetition penalty for the output
|
||||
assert repetition_penalty >= 1, "repetition penalty coefficient should >= 1"
|
||||
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
||||
for i in range(batch_size * num_beams):
|
||||
if start_idx is None or end_idx is None:
|
||||
output_tokens = prev_output_tokens[i].tolist()
|
||||
else:
|
||||
if end_idx >= start_idx:
|
||||
if window_size:
|
||||
output_tokens = prev_output_tokens[i][
|
||||
max(start_idx, end_idx + 1 - window_size) : end_idx + 1
|
||||
].tolist()
|
||||
else:
|
||||
output_tokens = prev_output_tokens[i][start_idx : end_idx + 1].tolist()
|
||||
else:
|
||||
output_tokens = []
|
||||
for previous_token in set(output_tokens):
|
||||
# if score < 0 then repetition penalty has to
|
||||
# multiplied to reduce the previous token probability
|
||||
if logits[i, previous_token] < 0:
|
||||
logits[i, previous_token] *= repetition_penalty
|
||||
else:
|
||||
logits[i, previous_token] /= repetition_penalty
|
||||
|
||||
|
||||
class BeamHypotheses:
|
||||
def __init__(self, n_hyp, max_len, length_penalty, early_stopping):
|
||||
"""
|
||||
Initialize n-best list of hypotheses.
|
||||
"""
|
||||
self.max_len = max_len
|
||||
self.length_penalty = length_penalty
|
||||
self.early_stopping = early_stopping
|
||||
self.n_hyp = n_hyp
|
||||
self.hyp = []
|
||||
self.worst_score = 1e9
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Number of hypotheses in the list.
|
||||
"""
|
||||
return len(self.hyp)
|
||||
|
||||
def add(self, hyp, sum_logprobs):
|
||||
"""
|
||||
Add a new hypothesis to the list.
|
||||
"""
|
||||
score = sum_logprobs / len(hyp) ** self.length_penalty
|
||||
|
||||
if len(self) < self.n_hyp or score > self.worst_score:
|
||||
self.hyp.append((score, hyp))
|
||||
if len(self) > self.n_hyp:
|
||||
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
|
||||
del self.hyp[sorted_scores[0][1]]
|
||||
self.worst_score = sorted_scores[1][0]
|
||||
else:
|
||||
self.worst_score = min(score, self.worst_score)
|
||||
|
||||
def is_done(self, best_sum_logprobs, cur_len):
|
||||
"""
|
||||
If there are enough hypotheses and that none of the hypotheses being generated
|
||||
can become better than the worst one in the heap, then we are done with this sentence.
|
||||
"""
|
||||
if len(self) < self.n_hyp:
|
||||
return False
|
||||
elif self.early_stopping:
|
||||
return True
|
||||
else:
|
||||
return self.worst_score >= best_sum_logprobs / cur_len**self.length_penalty
|
||||
|
||||
|
||||
def NBCE(logits):
|
||||
"""
|
||||
Naive Bayes-based Context Extension
|
||||
blog: https://www.kexue.fm/archives/9617
|
||||
"""
|
||||
beta = 0.25
|
||||
logits = logits[:, -1] # bsh -> bh
|
||||
logits = logits - logits.logsumexp(dim=-1, keepdims=True)
|
||||
k = (logits.exp() * logits).sum(dim=-1)[1:].argmax() + 1
|
||||
logits_max = logits[k]
|
||||
logits_uncond = logits[0]
|
||||
logits = (1 + beta) * logits_max - beta * logits_uncond
|
||||
return logits
|
|
@ -1,12 +0,0 @@
|
|||
from .attention import Attention
|
||||
from .blocks import TransformerBlock
|
||||
from .embedding import Embedding
|
||||
from .embedding import EmbeddingExt
|
||||
from .feedforward import FeedForward
|
||||
from .layernorm import LayerNorm
|
||||
from .linear import Linear
|
||||
from .position_embedding import BucketPositionBias
|
||||
from .position_embedding import RotaryEmbedding
|
||||
from .position_embedding import RotaryEmbeddingESM
|
||||
from .position_embedding import SegmentPositionEmbedding
|
||||
from .transformer import Encoder
|
|
@ -1,134 +0,0 @@
|
|||
import math
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from .linear import Linear
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_model: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
dim_head: int,
|
||||
dtype: torch.dtype = torch.half,
|
||||
dropout_p: Optional[float] = None,
|
||||
use_flash_attn: bool = False,
|
||||
scale: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dim_model = dim_model
|
||||
self.num_heads = num_heads
|
||||
self.dim_head = dim_head
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.head_groups = num_heads // num_kv_heads
|
||||
|
||||
self.project_q = Linear(self.dim_model, self.num_heads * self.dim_head, dtype=dtype, scale=scale)
|
||||
self.project_k = Linear(self.dim_model, self.num_kv_heads * self.dim_head, dtype=dtype, scale=scale)
|
||||
self.project_v = Linear(self.dim_model, self.num_kv_heads * self.dim_head, dtype=dtype, scale=scale)
|
||||
|
||||
self.attention_out = Linear(self.num_heads * self.dim_head, self.dim_model, dtype=dtype, scale=scale)
|
||||
|
||||
self.softmax = torch.nn.Softmax(dim=-1)
|
||||
|
||||
if dropout_p is not None:
|
||||
self.dropout = torch.nn.Dropout(p=dropout_p)
|
||||
else:
|
||||
self.dropout = None
|
||||
|
||||
# if use_flash_attn:
|
||||
# self.core_attention_flash = FlashSelfAttention(causal=False, attention_dropout=0.0)
|
||||
# self.use_flash_attn = use_flash_attn
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_q: torch.Tensor,
|
||||
hidden_kv: torch.Tensor,
|
||||
attention_mask: torch.BoolTensor,
|
||||
position_bias: torch.Tensor,
|
||||
use_cache: bool = False,
|
||||
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
pos_bias_type: Optional[str] = "relative",
|
||||
length_mask: Optional[torch.Tensor] = None,
|
||||
context_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_q (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): Indices of input sequence tokens. It will be embedded by model's internal embedding lookup matrix.
|
||||
hidden_kv (:obj:`torch.Tensor` of shape ``(batch, len_k, dim_model)``): Length of input sequence before padding.
|
||||
attention_mask (:obj:`torch.Tensor` of shape ``(batch, len_q, len_k)``): Used to avoid performing attention on padding token indices.
|
||||
position_bias(:obj:`torch.Tensor` of shape ``(num_heads, len_q, len_k)`` or ``(1, num_heads, len_k, len_q)``): Provide positional information about tensor `key_value` and `query`.
|
||||
Return:
|
||||
out (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): The attention output.
|
||||
""" # noqa: E501
|
||||
|
||||
batch_size = hidden_q.size(0)
|
||||
len_q = hidden_q.size(1)
|
||||
len_k = hidden_kv.size(1)
|
||||
|
||||
h_q = self.project_q(hidden_q)
|
||||
h_k = self.project_k(hidden_kv)
|
||||
h_v = self.project_v(hidden_kv)
|
||||
|
||||
h_q = h_q / math.sqrt(math.sqrt(self.dim_head))
|
||||
h_k = h_k / math.sqrt(math.sqrt(self.dim_head))
|
||||
|
||||
h_q = h_q.view(batch_size, len_q, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
|
||||
h_k = h_k.view(batch_size, len_k, self.num_kv_heads, self.dim_head).permute(0, 2, 1, 3)
|
||||
h_v = h_v.view(batch_size, len_k, self.num_kv_heads, self.dim_head).permute(0, 2, 1, 3)
|
||||
|
||||
if pos_bias_type == "rotary":
|
||||
# b h s d
|
||||
h_q, h_k = position_bias(h_q, h_k, -2, offset=past_kv[0].size(-2) if past_kv is not None else 0)
|
||||
|
||||
if past_kv is not None:
|
||||
h_k = torch.cat([past_kv[0], h_k], dim=-2)
|
||||
h_v = torch.cat([past_kv[1], h_v], dim=-2)
|
||||
len_k = h_k.size(-2)
|
||||
|
||||
# (b, n_h, len_q, d_h) @ (b, n_h, d_h, len_k) -> (b, n_h, len_q, len_k)
|
||||
if self.head_groups == 1:
|
||||
score = torch.matmul(h_q, h_k.transpose(-1, -2)) # / math.sqrt(self.dim_head) moved to line 75~76
|
||||
else:
|
||||
score = torch.matmul(
|
||||
h_q.reshape(batch_size, self.num_kv_heads, self.head_groups * len_q, self.dim_head),
|
||||
h_k.transpose(-1, -2),
|
||||
).view(batch_size, self.num_heads, len_q, len_k)
|
||||
|
||||
if pos_bias_type == "relative":
|
||||
score = score + position_bias
|
||||
score = torch.masked_fill(
|
||||
score,
|
||||
attention_mask.view(batch_size, 1, len_q, len_k) == False,
|
||||
torch.scalar_tensor(float("-inf"), device=score.device, dtype=score.dtype),
|
||||
)
|
||||
|
||||
score = self.softmax(score)
|
||||
|
||||
score = torch.masked_fill(
|
||||
score,
|
||||
attention_mask.view(batch_size, 1, len_q, len_k) == False,
|
||||
torch.scalar_tensor(0, device=score.device, dtype=score.dtype),
|
||||
)
|
||||
|
||||
if self.dropout is not None:
|
||||
score = self.dropout(score)
|
||||
|
||||
# (b, n_kv_h, n_h_groups*len_q, len_k) @ (b, n_kv_h, len_k, d_h) -> (b, n_kv_h, n_h_groups*len_q, d_h) -> (b, n_h, len_q, d_h)
|
||||
score = torch.matmul(score.view(batch_size, self.num_kv_heads, self.head_groups * len_q, len_k), h_v).view(
|
||||
batch_size, self.num_heads, len_q, self.dim_head
|
||||
)
|
||||
|
||||
score = score.view(batch_size, self.num_heads, len_q, self.dim_head).permute(0, 2, 1, 3)
|
||||
score = score.contiguous().view(batch_size, len_q, self.num_heads * self.dim_head)
|
||||
|
||||
score = self.attention_out(score)
|
||||
if use_cache:
|
||||
return score, (h_k, h_v)
|
||||
else:
|
||||
return score
|
|
@ -1,279 +0,0 @@
|
|||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from .attention import Attention
|
||||
from .feedforward import FeedForward
|
||||
from .layernorm import LayerNorm
|
||||
from .position_embedding import RotaryEmbedding
|
||||
from .position_embedding import RotaryEmbeddingESM
|
||||
|
||||
|
||||
class SelfAttentionBlock(torch.nn.Module):
|
||||
"""The whole cross-attention block. A sequence of operation. Consists of layernorm, self-attention and residual connection.
|
||||
|
||||
Args:
|
||||
dim_model (int): main dimension of modules in transformer blocks.
|
||||
num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
|
||||
dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
|
||||
dtype (optional): Defaults to torch.half.
|
||||
eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5.
|
||||
dropout_p (float, optional): Defaults to 0.
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_model: int,
|
||||
num_heads: int,
|
||||
dim_head: int,
|
||||
num_kv_heads: int,
|
||||
dtype=torch.half,
|
||||
eps: float = 1e-6,
|
||||
dropout_p: Optional[float] = None,
|
||||
scale: bool = True,
|
||||
use_flash_attn: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.layernorm_before_attention = LayerNorm(
|
||||
dim_model,
|
||||
dtype=dtype,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
self.self_attention = Attention(
|
||||
dim_model=dim_model,
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
dim_head=dim_head,
|
||||
dtype=dtype,
|
||||
dropout_p=dropout_p,
|
||||
scale=scale,
|
||||
use_flash_attn=use_flash_attn,
|
||||
)
|
||||
|
||||
if dropout_p:
|
||||
self.dropout = torch.nn.Dropout(dropout_p)
|
||||
else:
|
||||
self.dropout = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
position_bias: Union[torch.Tensor, RotaryEmbedding, RotaryEmbeddingESM] = None,
|
||||
use_cache: bool = False,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
pos_bias_type: Optional[str] = "relative",
|
||||
length_mask: Optional[torch.Tensor] = None,
|
||||
context_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Input of self-attention block. It can be the embedding of a batch of sequences.
|
||||
attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_self, seq_self)``): Avoid invalid areas to participate in the calculation.
|
||||
position_bias (:obj:`torch.Tensor` of shape ``(num_heads, seq_self, seq_self)``): Provide positional information to self-attention block.
|
||||
|
||||
Return:
|
||||
:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of attention block.
|
||||
|
||||
""" # noqa: E501
|
||||
x = self.layernorm_before_attention(hidden_states)
|
||||
x = self.self_attention(
|
||||
x,
|
||||
x,
|
||||
attention_mask,
|
||||
position_bias,
|
||||
use_cache,
|
||||
past_key_value,
|
||||
pos_bias_type=pos_bias_type,
|
||||
length_mask=length_mask,
|
||||
context_mask=context_mask,
|
||||
)
|
||||
if use_cache:
|
||||
x, current_key_value = x
|
||||
else:
|
||||
current_key_value = None
|
||||
|
||||
if self.dropout is not None:
|
||||
x = self.dropout(x)
|
||||
|
||||
hidden_states = hidden_states + x
|
||||
|
||||
if use_cache:
|
||||
return hidden_states, current_key_value
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FFNBlock(torch.nn.Module):
|
||||
"""The whole feed-forward block. A sequence of operation. Consists of layernorm, feed-forward and residual connection.
|
||||
|
||||
Args:
|
||||
dim_model (int): main dimension of modules in transformer blocks.
|
||||
dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
|
||||
dtype (optional): Defaults to torch.half.
|
||||
eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5.
|
||||
dropout_p (float, optional): Defaults to 0.
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_model: int,
|
||||
dim_ff: int,
|
||||
activate_fn: str,
|
||||
dtype=torch.half,
|
||||
eps: float = 1e-6,
|
||||
dropout_p: Optional[float] = 0,
|
||||
scale: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.layernorm_before_ffn = LayerNorm(
|
||||
dim_model,
|
||||
dtype=dtype,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
self.ffn = FeedForward(
|
||||
dim_model,
|
||||
dim_ff,
|
||||
activate_fn=activate_fn,
|
||||
dtype=dtype,
|
||||
dropout_p=dropout_p,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
if dropout_p:
|
||||
self.dropout = torch.nn.Dropout(dropout_p)
|
||||
else:
|
||||
self.dropout = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Hidden states before feed forward layer.
|
||||
|
||||
Return:
|
||||
:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of feed-forward block
|
||||
|
||||
""" # noqa: E501
|
||||
x = self.layernorm_before_ffn(hidden_states)
|
||||
x = self.ffn(x)
|
||||
if self.dropout is not None:
|
||||
x = self.dropout(x)
|
||||
|
||||
hidden_states = hidden_states + x
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TransformerBlock(torch.nn.Module):
|
||||
"""The whole transformer block. A sequence of operation. Consists of self-attention block[, cross-attention block] and feed-forward block.
|
||||
|
||||
Args:
|
||||
dim_model (int): main dimension of modules in transformer blocks.
|
||||
dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
|
||||
num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
|
||||
dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
|
||||
dtype (optional): Defaults to torch.half.
|
||||
eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5.
|
||||
dropout_p (float, optional): Defaults to 0.
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_model: int,
|
||||
dim_ff: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
dim_head: int,
|
||||
activate_fn: str = "gelu",
|
||||
dtype=torch.half,
|
||||
eps: float = 1e-6,
|
||||
dropout_p: Optional[float] = None,
|
||||
scale: bool = True,
|
||||
mask_att: bool = False,
|
||||
mask_ffn: bool = False,
|
||||
use_flash_attn: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.mask_att = mask_att
|
||||
self.mask_ffn = mask_ffn
|
||||
|
||||
if not self.mask_att:
|
||||
self.self_att = SelfAttentionBlock(
|
||||
dim_model=dim_model,
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
dim_head=dim_head,
|
||||
dtype=dtype,
|
||||
eps=eps,
|
||||
dropout_p=dropout_p,
|
||||
scale=scale,
|
||||
use_flash_attn=use_flash_attn,
|
||||
)
|
||||
|
||||
if not self.mask_ffn:
|
||||
self.ffn = FFNBlock(
|
||||
dim_model=dim_model,
|
||||
dim_ff=dim_ff,
|
||||
activate_fn=activate_fn,
|
||||
dtype=dtype,
|
||||
eps=eps,
|
||||
dropout_p=dropout_p,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
self_hidden_states: torch.Tensor,
|
||||
self_attention_mask: torch.Tensor,
|
||||
self_position_bias: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
pos_bias_type: Optional[str] = "relative",
|
||||
length_mask: Optional[torch.Tensor] = None,
|
||||
context_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
self_hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Input of transformer block(self-attention block). It can be the raw embedding of a batch of sequences.
|
||||
self_attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_self, seq_self)``): Avoid invalid areas to participate in the calculation of self-attention.
|
||||
self_position_bias (:obj:`torch.Tensor` of shape ``(num_heads, seq_self, seq_self)``): Provide positional information to self-attention block.
|
||||
|
||||
Return:
|
||||
:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of transformer block.
|
||||
|
||||
""" # noqa: E501
|
||||
# (batch, dim_model, seq_self)
|
||||
current_key_value = None
|
||||
if not self.mask_att:
|
||||
hidden_states = self.self_att(
|
||||
self_hidden_states,
|
||||
attention_mask=self_attention_mask,
|
||||
position_bias=self_position_bias,
|
||||
use_cache=use_cache,
|
||||
past_key_value=past_key_value,
|
||||
pos_bias_type=pos_bias_type,
|
||||
length_mask=length_mask,
|
||||
context_mask=context_mask,
|
||||
)
|
||||
if use_cache:
|
||||
hidden_states, current_key_value = hidden_states
|
||||
else:
|
||||
hidden_states = self_hidden_states
|
||||
|
||||
# (batch, dim_model, seq_self)
|
||||
if not self.mask_ffn:
|
||||
hidden_states = self.ffn(hidden_states)
|
||||
|
||||
if use_cache:
|
||||
return hidden_states, current_key_value
|
||||
else:
|
||||
return hidden_states
|
|
@ -1,100 +0,0 @@
|
|||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .position_embedding import RotaryEmbedding
|
||||
|
||||
|
||||
class Embedding(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
embedding_size: int,
|
||||
dtype: torch.dtype = torch.half,
|
||||
scale: bool = True,
|
||||
init_mean: float = 0.0,
|
||||
init_std: float = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.dim_model = embedding_size
|
||||
self.weight = torch.nn.parameter.Parameter(torch.empty(vocab_size, embedding_size, dtype=dtype))
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, ids: torch.Tensor):
|
||||
"""
|
||||
Args:
|
||||
ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens.
|
||||
Return:
|
||||
:obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output.
|
||||
""" # noqa: E501
|
||||
|
||||
if self.scale:
|
||||
embeds = F.embedding(ids, self.weight) / math.sqrt(self.dim_model)
|
||||
else:
|
||||
embeds = F.embedding(ids, self.weight)
|
||||
return embeds.clone()
|
||||
|
||||
def projection(self, x: torch.Tensor):
|
||||
"""
|
||||
Projection based on embedding's weight. For example, embedding map vocab_size to embed_size, than projection map embed_size back to vocab_size.
|
||||
Args:
|
||||
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection
|
||||
Returns:
|
||||
:obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output.
|
||||
""" # noqa: E501
|
||||
if self.scale:
|
||||
logits = F.linear(x / math.sqrt(self.dim_model), self.weight)
|
||||
else:
|
||||
logits = F.linear(x, self.weight)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
class EmbeddingExt(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
embedding_size: int,
|
||||
dtype: torch.dtype = torch.half,
|
||||
init_mean: float = 0.0,
|
||||
init_std: float = 1,
|
||||
distance_scale: int = 16,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.dim_model = embedding_size
|
||||
self.rotary_emb = RotaryEmbedding(dim=embedding_size, distance_scale=distance_scale, dtype=dtype)
|
||||
|
||||
self.weight = torch.nn.parameter.Parameter(
|
||||
torch.empty(vocab_size, embedding_size, dtype=dtype),
|
||||
)
|
||||
|
||||
def forward(self, ids: torch.Tensor, ids_sub: torch.Tensor):
|
||||
"""
|
||||
Args:
|
||||
ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens.
|
||||
ids (:obj:`torch.Tensor` of shape ``(batch_size)``): Subscript of input sequence tokens.
|
||||
Return:
|
||||
:obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output.
|
||||
""" # noqa: E501
|
||||
|
||||
embeds = F.embedding(ids, self.weight) / math.sqrt(self.dim_model)
|
||||
return self.rotary_emb(embeds, ids_sub)
|
||||
|
||||
def projection(self, x: torch.Tensor, ext_table: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
Projection based on embedding's weight. For example, embedding map vocab_size to embed_size, than projection map embed_size back to vocab_size.
|
||||
Args:
|
||||
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection
|
||||
ext_table (:obj:`torch.Tensor` of shape ``(ext_table_size, dim_model)``): Ext vocab table.
|
||||
Returns:
|
||||
:obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_size + ext_table_size)``: The projection output.
|
||||
""" # noqa: E501
|
||||
logits = F.linear(x / math.sqrt(self.dim_model), self.weight)
|
||||
if ext_table is not None:
|
||||
logits_ext = F.linear(x, ext_table)
|
||||
logits = torch.cat([logits, logits_ext], dim=-1)
|
||||
return logits
|
|
@ -1,120 +0,0 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .linear import Linear
|
||||
|
||||
|
||||
class DenseGatedACT(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_in: int,
|
||||
dim_ff: int,
|
||||
dtype=torch.half,
|
||||
activate_fn: str = "gelu",
|
||||
scale: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.w_0 = Linear(
|
||||
dim_in=dim_in,
|
||||
dim_out=dim_ff,
|
||||
dtype=dtype,
|
||||
scale=scale,
|
||||
scale_before=False,
|
||||
)
|
||||
|
||||
self.w_1 = Linear(
|
||||
dim_in=dim_in,
|
||||
dim_out=dim_ff,
|
||||
dtype=dtype,
|
||||
scale=scale,
|
||||
scale_before=False,
|
||||
)
|
||||
if activate_fn == "gelu":
|
||||
self.act = torch.nn.GELU()
|
||||
elif activate_fn == "silu":
|
||||
self.act = torch.nn.functional.silu
|
||||
else:
|
||||
raise NotImplementedError(f"{activate_fn} is not supported")
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""Transform an input tensor from one feature space to another via a nonlinear operation
|
||||
|
||||
Args:
|
||||
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): Tensor that will be subject to nonlinear operations.
|
||||
|
||||
Return:
|
||||
out (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_ff)``)
|
||||
|
||||
""" # noqa: E501
|
||||
gate_score = self.act(self.w_0(x))
|
||||
x = self.w_1(x)
|
||||
|
||||
x = gate_score * x
|
||||
return x
|
||||
|
||||
|
||||
class FeedForward(torch.nn.Module):
|
||||
r"""FeedForward module
|
||||
|
||||
Args:
|
||||
dim_in (int): input dimension.
|
||||
dim_ff (int): middle dimension.
|
||||
dim_out (int, optional): output dimension. Defaults to None, which means dim_in = dim_out.
|
||||
dtype (optional): Defaults to torch.half.
|
||||
init_mean (float, optional): mean of :math:`\mathbf{W}\sim\mathcal{N}(\text{mean}, \text{std}^2)` for fully-connected module used in feed-forward layer. Defaults to 0.
|
||||
init_std (float, optional): std of :math:`\mathbf{W}\sim\mathcal{N}(\text{mean}, \text{std}^2)` for fully-connected module used in feed-forward layer. Defaults to 0.02.
|
||||
bias (bool, optional): whether to use bias term in fully-connected layers used in feed-forward module. Defaults to False.
|
||||
activate_fn (str, optional): Defaults to `gated_gelu`.
|
||||
dropout_p (int, optional): Defaults to 0.
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_model: int,
|
||||
dim_ff: int,
|
||||
activate_fn: str = "gelu",
|
||||
dtype=torch.half,
|
||||
dropout_p: Optional[float] = None,
|
||||
scale: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.w_in = DenseGatedACT(
|
||||
dim_in=dim_model,
|
||||
dim_ff=dim_ff,
|
||||
activate_fn=activate_fn,
|
||||
dtype=dtype,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
if dropout_p is not None:
|
||||
self.dropout = torch.nn.Dropout(dropout_p)
|
||||
else:
|
||||
self.dropout = None
|
||||
|
||||
self.w_out = Linear(
|
||||
dim_in=dim_ff,
|
||||
dim_out=dim_model,
|
||||
dtype=dtype,
|
||||
scale=scale,
|
||||
scale_before=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
Args:
|
||||
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of feed-forward module.
|
||||
|
||||
Return:
|
||||
:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of feed-forward module.
|
||||
""" # noqa: E501
|
||||
x = self.w_in(x)
|
||||
|
||||
if self.dropout is not None:
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.w_out(x)
|
||||
|
||||
return x
|
|
@ -1,37 +0,0 @@
|
|||
import torch
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
|
||||
old_dtype = hidden.dtype
|
||||
variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
|
||||
hidden = (hidden * torch.rsqrt(variance + eps)).to(old_dtype)
|
||||
return hidden * weight
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.Module):
|
||||
"""RMS LayerNorm"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_norm: int,
|
||||
dtype: torch.dtype = torch.half,
|
||||
eps: float = 1e-6,
|
||||
init_var: float = 1.0,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.eps = eps
|
||||
self.dim_norm = dim_norm
|
||||
self.weight = torch.nn.parameter.Parameter(torch.full((dim_norm,), init_var, dtype=dtype))
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
Args:
|
||||
x (:obj:`torch.Tensor` of shape ``(batch_size, seq_len, dim_norm)``): Input tensor that need to be normalized.
|
||||
Return:
|
||||
:obj:`torch.Tensor` of shape ``(batch_size, seq_len, dim_norm)``: The layernorm output.
|
||||
"""
|
||||
assert x.size(-1) == self.dim_norm
|
||||
return rms_layernorm(x, self.weight, self.eps)
|
|
@ -1,44 +0,0 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Linear(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_in: int,
|
||||
dim_out: int,
|
||||
dtype: torch.dtype = torch.half,
|
||||
init_mean: float = 0.0,
|
||||
init_std: float = 1,
|
||||
scale: bool = True,
|
||||
scale_before: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim_in = self.in_features = dim_in
|
||||
self.dim_out = self.out_features = dim_out
|
||||
self.scale = scale
|
||||
self.scale_before = scale_before
|
||||
self.weight = torch.nn.parameter.Parameter(torch.empty((dim_out, dim_in), dtype=dtype))
|
||||
torch.nn.init.normal_(self.weight, mean=init_mean, std=init_std)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
Args:
|
||||
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer
|
||||
Returns:
|
||||
:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y.
|
||||
""" # noqa: E501
|
||||
if self.scale:
|
||||
if self.scale_before:
|
||||
x = x / math.sqrt(self.dim_in)
|
||||
|
||||
x = F.linear(x, self.weight)
|
||||
else:
|
||||
x = F.linear(x, self.weight)
|
||||
x = x / math.sqrt(self.dim_in)
|
||||
|
||||
else:
|
||||
x = F.linear(x, self.weight)
|
||||
return x
|
|
@ -1,286 +0,0 @@
|
|||
import math
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class SegmentPositionEmbedding(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
num_segments: int = 1,
|
||||
num_buckets: int = 32,
|
||||
max_distance: int = 128,
|
||||
bidirectional: bool = False,
|
||||
dtype: torch.dtype = torch.half,
|
||||
init_mean: float = 0.0,
|
||||
init_std: float = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.num_buckets = num_buckets
|
||||
self.max_distance = max_distance
|
||||
self.bidirectional = bidirectional
|
||||
self.num_segments = num_segments
|
||||
|
||||
self.relative_attention_bias = torch.nn.parameter.Parameter(
|
||||
torch.empty(num_segments * num_segments + num_buckets, num_heads, dtype=dtype)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
key_pos: torch.Tensor,
|
||||
query_pos: torch.Tensor,
|
||||
key_segment: torch.Tensor,
|
||||
query_segment: torch.Tensor,
|
||||
):
|
||||
with torch.no_grad():
|
||||
batch = key_pos.size(0)
|
||||
keylen = key_pos.size(1)
|
||||
querylen = query_pos.size(1)
|
||||
|
||||
assert key_pos.size(0) == query_pos.size(0)
|
||||
assert keylen == key_segment.size(1) and querylen == query_segment.size(1)
|
||||
|
||||
key_pos = key_pos.view(batch, -1, keylen)
|
||||
query_pos = query_pos.view(batch, querylen, -1)
|
||||
key_segment = key_segment.view(batch, -1, keylen)
|
||||
query_segment = query_segment.view(batch, querylen, -1)
|
||||
|
||||
relative_position_bucket = self._segment_relative_position_bucket(query_segment, key_segment)
|
||||
relative_position_bucket = relative_position_bucket + self.num_buckets # 与相对位置编码区间不重叠
|
||||
|
||||
# b*q*k
|
||||
absolute_position_bucket = self._position_bucket(
|
||||
torch.arange(keylen, dtype=torch.int32, device=relative_position_bucket.device)[None, :]
|
||||
- torch.arange(querylen, dtype=torch.int32, device=relative_position_bucket.device)[:, None],
|
||||
bidirectional=self.bidirectional,
|
||||
num_buckets=self.num_buckets,
|
||||
max_distance=self.max_distance,
|
||||
)
|
||||
relative_position_bucket = torch.where(
|
||||
(key_segment == query_segment),
|
||||
absolute_position_bucket[None, :, :],
|
||||
relative_position_bucket,
|
||||
)
|
||||
# (batch, len_q, len_k)
|
||||
|
||||
# (batch, len_q, len_k, num_heads)
|
||||
embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
|
||||
# (batch, num_heads, len_q, len_k)
|
||||
embeds = embeds.permute(0, 3, 1, 2).contiguous()
|
||||
return embeds
|
||||
|
||||
def _segment_relative_position_bucket(self, query_segment, key_segment):
|
||||
return query_segment * self.num_segments + key_segment
|
||||
|
||||
def _position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128):
|
||||
relative_buckets = 0
|
||||
if bidirectional:
|
||||
num_buckets //= 2
|
||||
relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
|
||||
relative_position = torch.abs(relative_position)
|
||||
else:
|
||||
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
|
||||
max_exact = num_buckets // 2
|
||||
is_small = relative_position < max_exact
|
||||
relative_postion_if_large = max_exact + (
|
||||
torch.log(relative_position.float() / max_exact)
|
||||
/ math.log(max_distance / max_exact)
|
||||
* (num_buckets - max_exact)
|
||||
).to(torch.int32)
|
||||
relative_postion_if_large = torch.min(
|
||||
relative_postion_if_large,
|
||||
torch.full_like(relative_postion_if_large, num_buckets - 1),
|
||||
)
|
||||
relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_postion_if_large)
|
||||
return relative_buckets
|
||||
|
||||
|
||||
class BucketPositionBias(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
num_buckets: int = 32,
|
||||
num_segment_bucket: int = 32,
|
||||
max_distance: int = 128,
|
||||
dtype: torch.dtype = torch.half,
|
||||
init_mean: float = 0.0,
|
||||
init_std: float = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.num_buckets = num_buckets
|
||||
self.num_segment_bucket = num_segment_bucket
|
||||
self.max_distance = max_distance
|
||||
|
||||
self.relative_attention_bias = torch.nn.parameter.Parameter(
|
||||
torch.empty(num_buckets + num_segment_bucket, num_heads, dtype=dtype)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query_pos: torch.Tensor, # (batch, len_q)
|
||||
key_pos: torch.Tensor, # (batch, len_k)
|
||||
rel_buckets: torch.Tensor, # (batch, len_q, len_k)
|
||||
):
|
||||
with torch.no_grad():
|
||||
batch = key_pos.size(0)
|
||||
keylen = key_pos.size(1)
|
||||
querylen = query_pos.size(1)
|
||||
|
||||
assert key_pos.size(0) == query_pos.size(0)
|
||||
assert rel_buckets.size(0) == batch and rel_buckets.size(1) == querylen and rel_buckets.size(2) == keylen
|
||||
|
||||
relative_position_bucket = rel_buckets - 1 + self.num_buckets # 与相对位置编码区间不重叠
|
||||
|
||||
# b*q*k
|
||||
inner_segment_bucket = self._position_bucket(
|
||||
key_pos[..., None, :] - query_pos[..., :, None],
|
||||
num_buckets=self.num_buckets,
|
||||
max_distance=self.max_distance,
|
||||
)
|
||||
relative_position_bucket = torch.where(
|
||||
rel_buckets == 0,
|
||||
inner_segment_bucket,
|
||||
relative_position_bucket,
|
||||
)
|
||||
# (batch, len_q, len_k)
|
||||
|
||||
# (batch, len_q, len_k, num_heads)
|
||||
embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
|
||||
# (batch, num_heads, len_q, len_k)
|
||||
embeds = embeds.permute(0, 3, 1, 2).contiguous()
|
||||
return embeds
|
||||
|
||||
def _position_bucket(self, relative_position, num_buckets=32, max_distance=128):
|
||||
relative_buckets = 0
|
||||
num_buckets //= 2
|
||||
relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
|
||||
relative_position = torch.abs(relative_position)
|
||||
|
||||
max_exact = num_buckets // 2
|
||||
is_small = relative_position < max_exact
|
||||
relative_postion_if_large = max_exact + (
|
||||
torch.log(relative_position.float() / max_exact)
|
||||
/ math.log(max_distance / max_exact)
|
||||
* (num_buckets - max_exact)
|
||||
).to(torch.int32)
|
||||
relative_postion_if_large = torch.min(
|
||||
relative_postion_if_large,
|
||||
torch.full_like(relative_postion_if_large, num_buckets - 1),
|
||||
)
|
||||
relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_postion_if_large)
|
||||
return relative_buckets
|
||||
|
||||
|
||||
class RotaryEmbedding(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
base=10000,
|
||||
distance_scale: Union[int, float] = 1,
|
||||
dtype: torch.dtype = torch.half,
|
||||
):
|
||||
super().__init__()
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device="cuda", dtype=torch.float32) / dim))
|
||||
inv_freq = inv_freq.to(dtype)
|
||||
self.distance_scale = distance_scale
|
||||
self.dtype = dtype
|
||||
self.inv_freq = inv_freq
|
||||
|
||||
def forward(self, x: torch.Tensor, x_pos: torch.Tensor):
|
||||
"""
|
||||
Args:
|
||||
x (:obj:`torch.Tensor` of shape ``(..., dim)``): Inputs.
|
||||
x_pos (:obj:`torch.Tensor` of shape ``(...)``): Positions of inputs.
|
||||
"""
|
||||
x_pos = x_pos * self.distance_scale
|
||||
freqs = x_pos[..., None].to(self.dtype) * self.inv_freq[None, :] # (..., dim/2)
|
||||
|
||||
# the same implementation as sat
|
||||
emb = torch.cat((freqs, freqs), dim=-1) # (..., dim)
|
||||
emb_cos = emb.cos() # (..., dim)
|
||||
emb_sin = emb.sin() # (..., dim)
|
||||
|
||||
rotate_x = torch.cat([-x[..., x.size(-1) // 2 :], x[..., : x.size(-1) // 2]], dim=-1) # (..., dim)
|
||||
|
||||
return x * emb_cos + rotate_x * emb_sin
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(x, cos, sin, seq_dim, offset):
|
||||
if x.size(seq_dim) < cos.size(seq_dim):
|
||||
cos = cos.narrow(seq_dim, offset, x.size(seq_dim))
|
||||
sin = sin.narrow(seq_dim, offset, x.size(seq_dim))
|
||||
return (x * cos) + (rotate_half(x) * sin)
|
||||
|
||||
|
||||
class RotaryEmbeddingESM(torch.nn.Module):
|
||||
"""
|
||||
Rotary position embeddings based on those in
|
||||
[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
|
||||
matrices which depend on their relative positions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
base: Union[int, float] = 10000,
|
||||
distance_scale: Union[int, float] = 1,
|
||||
dtype=torch.half,
|
||||
persistent=True,
|
||||
mixed_precision=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.base = base
|
||||
self.distance_scale = distance_scale
|
||||
self.dtype = dtype
|
||||
|
||||
# Generate and save the inverse frequency buffer (non trainable)
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device="cuda", dtype=torch.float32) / dim))
|
||||
if mixed_precision:
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=persistent)
|
||||
else:
|
||||
self.register_buffer("inv_freq", inv_freq.to(self.dtype), persistent=persistent)
|
||||
|
||||
self._seq_len_cached = -1
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
self.mixed_precision = mixed_precision
|
||||
|
||||
self.apply_rotary_pos_emb = apply_rotary_pos_emb
|
||||
|
||||
def _update_cos_sin_tables(self, x, seq_dim, offset):
|
||||
seq_len = x.size(seq_dim) + offset
|
||||
if seq_len > self._seq_len_cached or self._cos_cached.device != x.device:
|
||||
self._seq_len_cached = seq_len
|
||||
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
|
||||
freqs = torch.outer(t * self.distance_scale, self.inv_freq)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
for i in range(x.dim() - 1):
|
||||
if i != seq_dim:
|
||||
emb = emb.unsqueeze_(i)
|
||||
if self.mixed_precision:
|
||||
self._cos_cached = emb.cos().to(self.dtype)
|
||||
self._sin_cached = emb.sin().to(self.dtype)
|
||||
else:
|
||||
self._cos_cached = emb.cos()
|
||||
self._sin_cached = emb.sin()
|
||||
return self._cos_cached, self._sin_cached
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, seq_dim, offset=0) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
seq_dim = (seq_dim + k.dim()) % k.dim()
|
||||
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dim, offset)
|
||||
return (
|
||||
self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, seq_dim, offset),
|
||||
self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, seq_dim, offset),
|
||||
)
|
|
@ -1,132 +0,0 @@
|
|||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from .blocks import TransformerBlock
|
||||
from .layernorm import LayerNorm
|
||||
|
||||
|
||||
class Encoder(torch.nn.Module):
|
||||
"""Layers of encoder transformer blocks plus an final layernorm.
|
||||
|
||||
Args:
|
||||
num_layers (int): number of layers.
|
||||
dim_model (int): main dimension of modules in transformer blocks.
|
||||
dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
|
||||
num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
|
||||
dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
|
||||
dtype (optional): Defaults to torch.half.
|
||||
eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
|
||||
dropout_p (float, optional): Defaults to 0.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int,
|
||||
dim_model: int,
|
||||
dim_ff: int,
|
||||
num_heads: int,
|
||||
dim_head: int,
|
||||
num_kv_heads: int = -1,
|
||||
activate_fn: str = "gelu",
|
||||
dtype: torch.dtype = torch.half,
|
||||
eps: float = 1e-6,
|
||||
dropout_p: Optional[float] = None,
|
||||
scale: bool = True,
|
||||
mask_modules: Optional[List[Tuple[bool, bool]]] = None,
|
||||
use_flash_attn: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
if num_kv_heads == -1:
|
||||
num_kv_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
|
||||
if mask_modules is not None:
|
||||
assert len(mask_modules) == num_layers, "The total number of masks should equal to num_layers"
|
||||
for mask_module in mask_modules:
|
||||
assert len(mask_module) == 2, "For encoder, each mask should be (mask_att, mask_ffn)"
|
||||
else:
|
||||
mask_modules = [(False, False)] * num_layers
|
||||
|
||||
self.layers = torch.nn.ModuleList(
|
||||
[
|
||||
TransformerBlock(
|
||||
dim_model=dim_model,
|
||||
dim_ff=dim_ff,
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
dim_head=dim_head,
|
||||
activate_fn=activate_fn,
|
||||
dtype=dtype,
|
||||
eps=eps,
|
||||
dropout_p=dropout_p,
|
||||
scale=scale,
|
||||
mask_att=mask_modules[ith][0],
|
||||
mask_ffn=mask_modules[ith][1],
|
||||
use_flash_attn=use_flash_attn,
|
||||
)
|
||||
for ith in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.output_layernorm = LayerNorm(dim_norm=dim_model, dtype=dtype, eps=eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
position_bias: torch.Tensor,
|
||||
use_cache: bool = False,
|
||||
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
||||
pos_bias_type: Optional[str] = "relative",
|
||||
length_mask: Optional[torch.Tensor] = None,
|
||||
context_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden-states (:obj:`torch.Tensor` of shape ``(batch, seq_enc, dim_model)``): Input of encoder, might be the embedding of a batch of sequences.
|
||||
attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_enc, seq_enc)``): Avoid invalid areas to participate in the calculation
|
||||
position_bias(:obj:`torch.Tensor` of shape ``(num_heads, seq_enc, seq_enc)``) Provides position information to attention mechanism.
|
||||
|
||||
Return:
|
||||
:obj:`torch.Tensor` of shape ``(batch, seq_enc, dim_model)``: The encoder output.
|
||||
|
||||
"""
|
||||
if not use_cache:
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_bias,
|
||||
pos_bias_type=pos_bias_type,
|
||||
length_mask=length_mask,
|
||||
context_mask=context_mask,
|
||||
)
|
||||
hidden_states = self.output_layernorm(hidden_states)
|
||||
return hidden_states
|
||||
else:
|
||||
with torch.no_grad():
|
||||
current_key_values = []
|
||||
current_hidden_states = []
|
||||
for i, module in enumerate(self.layers):
|
||||
hidden_states = module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_bias,
|
||||
past_key_value=past_key_values[i] if past_key_values else None,
|
||||
use_cache=use_cache,
|
||||
pos_bias_type=pos_bias_type,
|
||||
length_mask=length_mask,
|
||||
context_mask=context_mask,
|
||||
)
|
||||
if use_cache:
|
||||
current_key_values.append(hidden_states[1])
|
||||
current_hidden_states.append(hidden_states[0])
|
||||
hidden_states = hidden_states[0]
|
||||
hidden_states = self.output_layernorm(hidden_states)
|
||||
if use_cache:
|
||||
return hidden_states, current_key_values, current_hidden_states
|
||||
else:
|
||||
return hidden_states
|
|
@ -1,62 +0,0 @@
|
|||
import torch
|
||||
|
||||
|
||||
def pad(orig_items, key, padding_value=0, padding_side="left"):
|
||||
items = []
|
||||
if isinstance(orig_items[0][key], list):
|
||||
assert isinstance(orig_items[0][key][0], torch.Tensor)
|
||||
for it in orig_items:
|
||||
for tr in it[key]:
|
||||
items.append({key: tr})
|
||||
else:
|
||||
assert isinstance(orig_items[0][key], torch.Tensor)
|
||||
items = orig_items
|
||||
|
||||
batch_size = len(items)
|
||||
shape = items[0][key].shape
|
||||
dim = len(shape)
|
||||
assert dim <= 3
|
||||
max_length = max(item[key].shape[-1] for item in items)
|
||||
min_length = min(item[key].shape[-1] for item in items)
|
||||
dtype = items[0][key].dtype
|
||||
|
||||
if dim == 1:
|
||||
return torch.cat([item[key] for item in items], dim=0)
|
||||
elif dim == 2:
|
||||
if max_length == min_length:
|
||||
return torch.cat([item[key] for item in items], dim=0)
|
||||
tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
|
||||
else:
|
||||
tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
|
||||
|
||||
for i, item in enumerate(items):
|
||||
if dim == 2:
|
||||
if padding_side == "left":
|
||||
tensor[i, -len(item[key][0]) :] = item[key][0].clone()
|
||||
else:
|
||||
tensor[i, : len(item[key][0])] = item[key][0].clone()
|
||||
elif dim == 3:
|
||||
if padding_side == "left":
|
||||
tensor[i, -len(item[key][0]) :, :] = item[key][0].clone()
|
||||
else:
|
||||
tensor[i, : len(item[key][0]), :] = item[key][0].clone()
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
def pad_raw(orig_items, max_length=1024, padding_value=0, padding_side="left"):
|
||||
max_cols = max(tensor.size(1) for tensor in orig_items)
|
||||
|
||||
padded_arrays = []
|
||||
for tensor in orig_items:
|
||||
pad_cols = max_cols - tensor.size(1)
|
||||
if padding_side == "left":
|
||||
padded_tensor = torch.cat([torch.zeros(tensor.size(0), pad_cols), tensor], dim=1)
|
||||
elif padding_side == "right":
|
||||
padded_tensor = torch.cat([tensor, torch.zeros(tensor.size(0), pad_cols)], dim=1)
|
||||
else:
|
||||
raise ValueError("Invalid 'side' parameter. Must be 'left' or 'right'.")
|
||||
padded_arrays.append(padded_tensor)
|
||||
|
||||
padded_tensor = torch.cat(padded_arrays, dim=0).to(dtype=torch.int32)
|
||||
return padded_tensor
|
|
@ -1,130 +0,0 @@
|
|||
import functools
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
import bmtrain as bmt
|
||||
import torch
|
||||
|
||||
from .log import logger
|
||||
|
||||
|
||||
def rename_if_exists(file_path):
|
||||
if not os.path.exists(file_path):
|
||||
return
|
||||
timestamp = time.strftime("%Y%m%d%H%M%S")
|
||||
file_dir, file_name = os.path.split(file_path)
|
||||
file_root, file_ext = os.path.splitext(file_name)
|
||||
new_file_name = f"{file_root}_bak_{timestamp}{file_ext}"
|
||||
new_file_path = os.path.join(file_dir, new_file_name)
|
||||
try:
|
||||
os.rename(file_path, new_file_path)
|
||||
logger.info(f"File '{file_name}' already exists. Renamed to '{new_file_name}'")
|
||||
except Exception as e:
|
||||
logger.warn(
|
||||
"rename file failed,file_path={file_path}, new_file_path={new_file_path},err={err}".format(
|
||||
file_path=file_path, new_file_path=new_file_path, err=str(e)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def rename_if_exists_decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(file_path, *args, **kwargs):
|
||||
rename_if_exists(file_path)
|
||||
return func(file_path, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@rename_if_exists_decorator
|
||||
def bmt_save(file_path: str, model: torch.nn.Module, export_files: List[str] = None):
|
||||
bmt.save(model, file_path)
|
||||
if export_files is not None:
|
||||
export_files.append(file_path)
|
||||
|
||||
|
||||
@rename_if_exists_decorator
|
||||
def torch_save(file_path: str, obj: object, export_files: List[str] = None):
|
||||
torch.save(obj, file_path)
|
||||
if export_files is not None:
|
||||
export_files.append(file_path)
|
||||
|
||||
|
||||
@rename_if_exists_decorator
|
||||
def json_save(file_path: str, obj: object, export_files: List[str] = None):
|
||||
with open(file_path, "w") as data_f:
|
||||
json.dump(obj, data_f)
|
||||
if export_files is not None:
|
||||
export_files.append(file_path)
|
||||
|
||||
|
||||
def export(
|
||||
model: torch.nn.Module, dataloader, optimizer: bmt.optim.AdamOffloadOptimizer, global_step, args, final_save=False
|
||||
):
|
||||
"""
|
||||
一次 ckpt 保存:
|
||||
/{args.save}/
|
||||
├── {save_name}-{global_step}.rank-0.opt
|
||||
├── {save_name}-{global_step}.rank-n.opt
|
||||
├── job_{job_id}_ckpt_{global_step}/ # checkpoint 导出为模型版本时,job_{job_id}_ckpt_{global_step}/ 路径下文件会一起导出,创建一个模型组版本
|
||||
├── config.json
|
||||
├── vocabs.txt
|
||||
├── {args.save_name}-{global_step}.pt
|
||||
├── {args.save_name}-{global_step}.data
|
||||
├── {args.save_name}-{global_step}.data.json
|
||||
└── {args.save_name}-{global_step}.success
|
||||
|
||||
"""
|
||||
export_model_dir = os.path.join(args.save, f"l_{global_step}")
|
||||
os.makedirs(export_model_dir, exist_ok=True)
|
||||
base_file_name = f"{args.save_name}-{global_step}" if global_step > -1 else args.save_name
|
||||
logger.info(f"start to export ckpt, save_dir={export_model_dir}, file prefix={base_file_name}")
|
||||
export_files = []
|
||||
|
||||
# model checkpoint
|
||||
bmt_save(
|
||||
file_path=os.path.join(export_model_dir, base_file_name + ".pt"),
|
||||
model=model,
|
||||
export_files=export_files,
|
||||
)
|
||||
|
||||
# opt is only used for continual pre-training, not the final opt
|
||||
if not final_save:
|
||||
grad_path = os.path.join(
|
||||
args.save,
|
||||
args.save_name + ("-%d.rank-%d.opt" % (global_step % (args.save_iters * 5), bmt.rank())),
|
||||
)
|
||||
torch.save(optimizer.state_dict(), grad_path)
|
||||
logger.info(f"Successfully save grad file: {grad_path}")
|
||||
|
||||
all_states = dataloader.state_dict()
|
||||
if bmt.rank() == 0:
|
||||
# data checkpoint
|
||||
# rank 0 writes the dataloader state
|
||||
torch_save(
|
||||
file_path=os.path.join(export_model_dir, base_file_name + ".data"),
|
||||
obj=all_states,
|
||||
export_files=export_files,
|
||||
)
|
||||
# data checkpoint json
|
||||
# rank 0 writes the dataloader state into the json file
|
||||
data_p_json = {k: v for k, v in all_states.items()}
|
||||
for k in data_p_json:
|
||||
data_p_json[k] = {k_of_v: data_p_json[k][k_of_v].tolist() for k_of_v in data_p_json[k]}
|
||||
json_save(
|
||||
file_path=os.path.join(export_model_dir, base_file_name + ".data.json"),
|
||||
obj=data_p_json,
|
||||
export_files=export_files,
|
||||
)
|
||||
# config 和 vocabs 和模型文件一起存储
|
||||
model_cfg_path = os.path.join(export_model_dir, "config.json")
|
||||
model_vocab_path = os.path.join(export_model_dir, "vocabs.txt")
|
||||
export_files.extend([model_cfg_path, model_vocab_path])
|
||||
shutil.copy(args.model_config, model_cfg_path)
|
||||
shutil.copy(args.vocab, model_vocab_path)
|
||||
logger.info(f"Successfully save model files! {export_files}")
|
||||
del all_states
|
||||
return export_model_dir
|
|
@ -0,0 +1,4 @@
|
|||
# !/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2024, QiYuan Inc
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,20 @@
|
|||
|
||||
import random
|
||||
|
||||
|
||||
def rand(n: int, r: random.Random):
|
||||
return int(r.random() * n)
|
||||
|
||||
def transform(data, num_sample: int, r: random.Random):
|
||||
if 'input' in data:
|
||||
_input = "<用户>"+data['input']+"<AI>"
|
||||
else:
|
||||
_input = ""
|
||||
|
||||
if 'output' in data:
|
||||
_output = data['output']
|
||||
else:
|
||||
_output = ""
|
||||
return {"input": _input,
|
||||
"output": _output,
|
||||
}
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,20 @@
|
|||
|
||||
import random
|
||||
|
||||
|
||||
def rand(n: int, r: random.Random):
|
||||
return int(r.random() * n)
|
||||
|
||||
def transform(data, num_sample: int, r: random.Random):
|
||||
if 'input' in data:
|
||||
_input = data['input']
|
||||
else:
|
||||
_input = ""
|
||||
|
||||
if 'output' in data:
|
||||
_output = data['output']
|
||||
else:
|
||||
_output = ""
|
||||
return {"input": _input,
|
||||
"output": _output,
|
||||
}
|
|
@ -0,0 +1,134 @@
|
|||
[
|
||||
{
|
||||
"dataset_name": "humanevallike_clean_dedup",
|
||||
"task_name": "humanevallike_clean_dedup",
|
||||
"abs_weight": 0.2,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/humanevallike_clean_dedup",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 995339,
|
||||
"ave_tokens_per_line": 100,
|
||||
"total_tokens": 0.1
|
||||
},
|
||||
{
|
||||
"dataset_name": "leetcode_pass_code_0125",
|
||||
"task_name": "leetcode_pass_code_0125",
|
||||
"abs_weight": 0.006,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/leetcode_pass_code_0125",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 10724,
|
||||
"ave_tokens_per_line": 200,
|
||||
"total_tokens": 0.002
|
||||
},
|
||||
{
|
||||
"dataset_name": "logiv2Annotate",
|
||||
"task_name": "logiv2Annotate",
|
||||
"abs_weight": 0.004,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logiv2Annotate",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 12566,
|
||||
"ave_tokens_per_line": 512,
|
||||
"total_tokens": 0.006
|
||||
},
|
||||
{
|
||||
"dataset_name": "mmlu_enhance",
|
||||
"task_name": "mmlu_enhance",
|
||||
"abs_weight": 0.1,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mmlu_enhance",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 169771,
|
||||
"ave_tokens_per_line": 300,
|
||||
"total_tokens": 0.05
|
||||
},
|
||||
{
|
||||
"dataset_name": "mtbench_like",
|
||||
"task_name": "mtbench_like",
|
||||
"abs_weight": 0.2,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mtbench_like",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 319080,
|
||||
"ave_tokens_per_line": 500,
|
||||
"total_tokens": 0.15
|
||||
},
|
||||
{
|
||||
"dataset_name": "ultra_dataset_new",
|
||||
"task_name": "ultra_dataset_new",
|
||||
"abs_weight": 2.0,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/ultra_dataset_new",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 385045,
|
||||
"ave_tokens_per_line": 200.296266559615,
|
||||
"total_tokens": 2.0
|
||||
},
|
||||
{
|
||||
"dataset_name": "sft_data_zh_wowru",
|
||||
"task_name": "sft_data_zh_wowru",
|
||||
"abs_weight": 1.0,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/sft_data_zh_wowru",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 2963260,
|
||||
"ave_tokens_per_line": 200.296266559615,
|
||||
"total_tokens": 1
|
||||
},
|
||||
{
|
||||
"dataset_name": "math_data",
|
||||
"task_name": "math_data",
|
||||
"abs_weight": 0.003,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/math_data",
|
||||
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 2963260,
|
||||
"ave_tokens_per_line": 200.296266559615,
|
||||
"total_tokens": 0.005
|
||||
},
|
||||
{
|
||||
"dataset_name": "t0",
|
||||
"task_name": "t0",
|
||||
"abs_weight": 0.1,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/t0",
|
||||
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 1650309,
|
||||
"ave_tokens_per_line": 500.296266559615,
|
||||
"total_tokens": 0.82
|
||||
},
|
||||
{
|
||||
"dataset_name": "wikihow",
|
||||
"task_name": "wikihow",
|
||||
"abs_weight": 0.1,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/wikihow",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 180128,
|
||||
"ave_tokens_per_line": 900.296266559615,
|
||||
"total_tokens": 0.16
|
||||
},
|
||||
{
|
||||
"dataset_name": "reclor",
|
||||
"task_name": "reclor",
|
||||
"abs_weight": 0.002,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/reclor",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 4174,
|
||||
"ave_tokens_per_line": 700.296266559615,
|
||||
"total_tokens": 0.003
|
||||
},
|
||||
{
|
||||
"dataset_name": "logic_test_lx_0127",
|
||||
"task_name": "logic_test_lx_0127",
|
||||
"abs_weight": 0.001,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logic_test_lx_0127",
|
||||
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 2800,
|
||||
"ave_tokens_per_line": 200.96266559615,
|
||||
"total_tokens": 0.0004
|
||||
}
|
||||
]
|
|
@ -0,0 +1,28 @@
|
|||
{
|
||||
"vocab_size": 122753,
|
||||
"dropout_p": 0.0,
|
||||
"eps": 1e-05,
|
||||
"half": true,
|
||||
"half_type": "bf16",
|
||||
"use_flash_attn": true,
|
||||
"flash_attn_mask_shape": "2d",
|
||||
"dim_model": 2304,
|
||||
"dim_ff": 5760,
|
||||
"dim_head": 64,
|
||||
"num_heads": 36,
|
||||
"num_kv_heads": 36,
|
||||
"num_layers": 40,
|
||||
"activate_fn": "silu",
|
||||
"init_std": 0.10,
|
||||
"scale": true,
|
||||
"scale_emb": 12,
|
||||
"scale_depth": 1.4,
|
||||
"dim_model_base": 256,
|
||||
"model_type": "fm9g",
|
||||
"architectures": [
|
||||
"FM9GForCausalLM"
|
||||
],
|
||||
"qk_norm": false,
|
||||
"tie_lm_head": true,
|
||||
"ffn_gated": true
|
||||
}
|
|
@ -0,0 +1,548 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2024 QiYuan Inc.
|
||||
import inspect
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Union
|
||||
|
||||
import bmtrain as bmt
|
||||
import numpy as np
|
||||
import torch
|
||||
from bmtrain import nccl
|
||||
from bmtrain.global_var import config as bmt_config
|
||||
|
||||
sys.path.append("../../")
|
||||
from fm9g.arguments import get_args
|
||||
from fm9g.dragonfly.modeling_dragonfly import Dragonfly
|
||||
from fm9g.dragonfly.modeling_dragonfly import DragonflyConfig
|
||||
from fm9g.dragonfly.training_tasks.pretrain_indexed import CudaPrefetcher
|
||||
from fm9g.dragonfly.training_tasks.pretrain_indexed import MixedIndexedDataset
|
||||
from fm9g.dragonfly.training_tasks.pretrain_indexed import UnpadBatchedMixedDataset
|
||||
from fm9g.utils import exporter
|
||||
from fm9g.utils import logger
|
||||
from fm9g.utils.exporter import save_every_step_stats
|
||||
from fm9g.utils.training_stats import num_non_embedding_parameters
|
||||
from fm9g.utils.training_stats import num_parameters
|
||||
|
||||
|
||||
def get_tokenizer(args):
|
||||
from transformers import LlamaTokenizerFast
|
||||
|
||||
tokenizer = LlamaTokenizerFast(vocab_file=args.tokenizer_path)
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_model(args):
|
||||
config = DragonflyConfig.from_json_file(args.model_config)
|
||||
config.tp = 1 if args.tp_size != 1 else 0 # TODO
|
||||
config.pose_prob = args.pose_prob
|
||||
config.pose_scaling_factor = args.pose_scaling_factor
|
||||
config.rope_scaling_type = args.rope_scaling_type
|
||||
config.rope_scaling_factor = args.rope_scaling_factor
|
||||
config.orig_max_length = args.orig_max_length
|
||||
|
||||
bmt.print_rank("model config: {}".format(config))
|
||||
bmt.print_rank("bmt config: {}".format(bmt.config))
|
||||
|
||||
model = Dragonfly(config)
|
||||
if args.load is not None:
|
||||
bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
|
||||
exporter.load_model_ckpt(args, model)
|
||||
else:
|
||||
bmt.print_rank("args.load is None, start to initialize parameters")
|
||||
bmt.init_parameters(model)
|
||||
return model
|
||||
|
||||
|
||||
def get_optimizer(args, model):
|
||||
scale_lr_group = []
|
||||
normal_group = []
|
||||
scale_lr_group_name, normal_group_name = [], []
|
||||
for n, p in model.named_parameters():
|
||||
if n.endswith(".weight") and "layernorm" not in n and "embedding" not in n and "lm_head" not in n:
|
||||
scale_lr_group.append(p)
|
||||
scale_lr_group_name.append(n)
|
||||
else:
|
||||
normal_group.append(p)
|
||||
normal_group_name.append(n)
|
||||
bmt.print_rank(scale_lr_group_name, normal_group_name)
|
||||
param_groups = [
|
||||
{"params": scale_lr_group, "lr": args.lr / model.config.scale_width},
|
||||
{"params": normal_group, "lr": args.lr},
|
||||
]
|
||||
|
||||
if args.offload:
|
||||
optimizer = bmt.optim.AdamOffloadOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
||||
else:
|
||||
optimizer = bmt.optim.AdamOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
||||
if args.load is not None and args.load_grad:
|
||||
exporter.load_optimizer_ckpt(args, optimizer)
|
||||
bmt.print_rank("optimizer is loaded!")
|
||||
return optimizer
|
||||
|
||||
|
||||
def get_learning_rate_scheduler(args, optimizer):
|
||||
from fm9g.training_utils.lr_scheduler import Cosine
|
||||
from fm9g.training_utils.lr_scheduler import WarmupStableDrop
|
||||
|
||||
end_iter = args.train_iters
|
||||
if 0 < args.warmup_iters < 1: # 需要支持按固定比例step用来做warmup的
|
||||
warmup_iters = int(end_iter * args.warmup_iters)
|
||||
else:
|
||||
warmup_iters = int(args.warmup_iters)
|
||||
|
||||
if 0 < args.drop_iters < 1: # 需要支持按固定比例step用来做drop的
|
||||
drop_iters = int(end_iter * args.drop_iters)
|
||||
else:
|
||||
drop_iters = int(args.drop_iters)
|
||||
|
||||
if args.lr_scheduler == "cosine":
|
||||
lr_scheduler = Cosine(
|
||||
optimizer,
|
||||
start_lr=args.lr,
|
||||
warmup_iter=warmup_iters,
|
||||
end_iter=end_iter, # 原来是lr_decay_iter
|
||||
num_iter=args.start_step,
|
||||
#lr_end_restart=args.lr_end_restart,
|
||||
#resume_no_optimze=args.resume_no_optimze,
|
||||
)
|
||||
elif args.lr_scheduler == "warmupstabledrop":
|
||||
lr_scheduler = WarmupStableDrop(
|
||||
optimizer,
|
||||
start_lr=args.lr,
|
||||
warmup_iter=warmup_iters,
|
||||
end_iter=end_iter, # 原来是lr_decay_iter
|
||||
drop_iter=drop_iters,
|
||||
num_iter=args.start_step,
|
||||
resume_no_optimze=args.resume_no_optimze,
|
||||
)
|
||||
return lr_scheduler
|
||||
|
||||
|
||||
def setup_model_and_optimizer(args):
|
||||
start = time.time()
|
||||
tokenizer = get_tokenizer(args)
|
||||
bmt.synchronize()
|
||||
logger.info("load tokenizer in {:.2f}s".format(time.time() - start))
|
||||
|
||||
start = time.time()
|
||||
model = get_model(args)
|
||||
logger.info("load model in {:.2f}s".format(time.time() - start))
|
||||
|
||||
start = time.time()
|
||||
optimizer = get_optimizer(args, model)
|
||||
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
|
||||
bmt.synchronize()
|
||||
logger.info("load lr_scheduler in {:.2f}s".format(time.time() - start))
|
||||
|
||||
return tokenizer, model, optimizer, lr_scheduler
|
||||
|
||||
|
||||
def resume_training(args):
|
||||
ckpts = sorted(
|
||||
[z for z in chain(*[[os.path.join(x[0], y) for y in x[2]] for x in os.walk(args.save)]) if z.endswith(".pt")],
|
||||
reverse=True,
|
||||
key=lambda x: (int)(re.search("(\d+).pt", x)[1]),
|
||||
)
|
||||
# find newest job
|
||||
ckpts = sorted(
|
||||
ckpts,
|
||||
reverse=True,
|
||||
key=lambda x: (int)(re.search("job_(\d+)_ckpt", x)[1]),
|
||||
)
|
||||
|
||||
if len(ckpts) > 0:
|
||||
bmt.print_rank(f"resuming with last checkpoint: {ckpts[0]}")
|
||||
args.load = ckpts[0]
|
||||
# by default, do not load grad file
|
||||
args.load_grad = False
|
||||
args.start_step = 0
|
||||
else:
|
||||
# no ckpts, nothing we can do
|
||||
os._exit(1)
|
||||
|
||||
|
||||
def initialize():
|
||||
args = get_args(pretrain=True)
|
||||
bmt.init_distributed(seed=args.seed, tp_size=args.tp_size)
|
||||
|
||||
if args.save is not None:
|
||||
os.makedirs(args.save, exist_ok=True)
|
||||
if args.load is not None:
|
||||
if args.only_load_model == 0:
|
||||
if args.start_step == 0:
|
||||
log_ckpt = exporter.load_log_ckpt(args)
|
||||
if "iteration" in log_ckpt:
|
||||
args.start_step = log_ckpt["iteration"]
|
||||
else:
|
||||
args.start_step = (int)(re.findall("(\d+)", args.load)[-1])
|
||||
logger.info("Start from step {}".format(args.start_step))
|
||||
elif args.only_load_model == 1:
|
||||
logger.info("You load model ckpt, and choose to completely start the 0 step.")
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
logger.info("You do not load model")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def see_memory(detail=False):
|
||||
if detail:
|
||||
res = torch.cuda.memory_summary()
|
||||
else:
|
||||
res = (
|
||||
round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2),
|
||||
round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024), 2),
|
||||
round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024), 2),
|
||||
)
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
return res
|
||||
|
||||
|
||||
def add_mem_time(info, mem_usage, tim_usage):
|
||||
torch.cuda.synchronize()
|
||||
bmt.synchronize()
|
||||
mem_usage[info] = see_memory()
|
||||
tim_usage[info] = time.time()
|
||||
return mem_usage, tim_usage
|
||||
|
||||
|
||||
def get_task_loss_and_token(loss, task_ids, task_num, targets):
|
||||
# task_ids 可能有-1 来代表无效token
|
||||
_task_num = task_num + 1
|
||||
_task_ids = (task_ids.clone() + 1).to(torch.int64) # [batch_size, seq_len]
|
||||
# gen masks
|
||||
_task_mask = torch.zeros((_task_num, *_task_ids.shape), device=_task_ids.device)
|
||||
_task_mask.scatter_(0, _task_ids.unsqueeze(0), 1) # [task_num, batch_size, seq_len]
|
||||
_loss_mask = torch.ne(targets, -100).to(torch.int32)
|
||||
_mask = _task_mask * _loss_mask.unsqueeze(0) # [task_num, batch_size, seq_len]
|
||||
# calc loss and tokens
|
||||
_task_losses = (loss.unsqueeze(0) * _mask).view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
|
||||
_task_tokens = _mask.view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
|
||||
# return token-wise avg losses and tokens
|
||||
return torch.nan_to_num(_task_losses / _task_tokens, nan=0.0), _task_tokens
|
||||
|
||||
|
||||
class ChunkAve:
|
||||
def __init__(self, chunk_size=100):
|
||||
self.ave_list = []
|
||||
self.chunk_size = chunk_size
|
||||
|
||||
def record(self, time):
|
||||
self.ave_list.append(time)
|
||||
self.ave_list = self.ave_list[-self.chunk_size :]
|
||||
|
||||
def get(self):
|
||||
return sum(self.ave_list) / len(self.ave_list)
|
||||
|
||||
|
||||
def pretrain(
|
||||
args,
|
||||
tokenizer,
|
||||
model: Dragonfly,
|
||||
optimizer,
|
||||
lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
|
||||
):
|
||||
ave_model_time = ChunkAve(chunk_size=100)
|
||||
ave_iter_time = ChunkAve(chunk_size=100)
|
||||
|
||||
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, reduction="none")
|
||||
optim_manager = bmt.optim.OptimManager(
|
||||
loss_scale=None,
|
||||
loss_scale_steps=args.loss_scale_steps,
|
||||
loss_scale_factor=2,
|
||||
max_loss_scale=args.max_loss_scale,
|
||||
min_loss_scale=args.min_loss_scale,
|
||||
)
|
||||
optim_manager.add_optimizer(optimizer, lr_scheduler)
|
||||
|
||||
start_step = args.start_step
|
||||
|
||||
if args.tensorboard is not None and bmt.rank() == 0:
|
||||
import distutils.version # noqa: F401
|
||||
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
if not os.path.exists(args.tensorboard):
|
||||
os.makedirs(args.tensorboard)
|
||||
writer = SummaryWriter(log_dir=args.tensorboard)
|
||||
|
||||
if args.load is not None:
|
||||
log_ckpt = exporter.load_log_ckpt(args)
|
||||
else:
|
||||
log_ckpt = {}
|
||||
global_token_pass = log_ckpt.get("global_token_pass", 0.0)
|
||||
global_total_task_token = defaultdict(int, log_ckpt.get("global_total_task_token", {})) # token by task
|
||||
|
||||
global_world_size = bmt.world_size()
|
||||
bmt.print_rank("Begin preparing dataset")
|
||||
if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
|
||||
mixed_indexed_dataset = MixedIndexedDataset(
|
||||
cfg_path=args.dataset,
|
||||
cfg_json_str=None,
|
||||
tokenizer=tokenizer,
|
||||
max_length=args.max_length,
|
||||
nthreads=args.dataloader_num_threads,
|
||||
prefetch_slice=args.dataloader_prefetch,
|
||||
weight_by_size=True,
|
||||
)
|
||||
|
||||
if args.load is not None and args.only_load_model == 0 and args.load_dataloader_ckpt == 1:
|
||||
exporter.load_dataloader_ckpt(args, mixed_indexed_dataset)
|
||||
|
||||
batched_dataset = UnpadBatchedMixedDataset(mixed_indexed_dataset, args.batch_size, args.max_length)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
batched_dataset,
|
||||
batch_size=None,
|
||||
collate_fn=lambda x: x,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
prefetch_factor=args.dataloader_prefetch_factor,
|
||||
)
|
||||
else:
|
||||
|
||||
def dummy_generator():
|
||||
while True:
|
||||
yield None
|
||||
|
||||
mixed_indexed_dataset = dummy_generator()
|
||||
dataloader = mixed_indexed_dataset
|
||||
|
||||
DataIterator = CudaPrefetcher(dataloader, tp_size=args.tp_size, tp_rank=bmt.config["tp_rank"])
|
||||
|
||||
bmt.print_rank("Preparing dataset done.")
|
||||
|
||||
# inspect at init
|
||||
model_inspect = bmt.inspect.inspect_model(model, "*")
|
||||
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
||||
|
||||
try:
|
||||
mem_usage, tim_usage = {}, {}
|
||||
mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)
|
||||
|
||||
for iteration, data in enumerate(DataIterator, start=start_step + 1):
|
||||
if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
|
||||
mixed_indexed_dataset.update_states(data["task_ids"], data["indexes"])
|
||||
|
||||
mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)
|
||||
|
||||
logits = model(
|
||||
input=data["inputs"],
|
||||
cu_seqlens=data["cu_seqlens"],
|
||||
max_seqlen=data["max_seqlen"],
|
||||
position_ids=data["position_ids"],
|
||||
)
|
||||
|
||||
# chunk targets and task_ids
|
||||
data["targets"] = (
|
||||
data["targets"]
|
||||
.view(-1)
|
||||
.chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
|
||||
.view(data["targets"].shape[0], -1)
|
||||
)
|
||||
data["task_ids"] = (
|
||||
data["task_ids"]
|
||||
.view(-1)
|
||||
.chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
|
||||
.view(data["task_ids"].shape[0], -1)
|
||||
)
|
||||
|
||||
_target = data["targets"].view(-1)
|
||||
non_reduced_loss = loss_func(logits.view(-1, logits.size(-1)), _target)
|
||||
_w = (_target != -100).int()
|
||||
loss = non_reduced_loss.sum() / _w.sum().float()
|
||||
|
||||
global_loss = bmt.sum_loss(loss).item()
|
||||
mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)
|
||||
|
||||
optim_manager.backward(loss)
|
||||
mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)
|
||||
|
||||
if iteration % args.grad_accum == 0 or iteration == args.train_iters:
|
||||
grad_accum_init_time = tim_usage["init"]
|
||||
|
||||
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
|
||||
optim_manager.step()
|
||||
optim_manager.zero_grad()
|
||||
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
|
||||
model_time = tim_usage["optim"] - grad_accum_init_time
|
||||
ave_model_time.record(model_time)
|
||||
else:
|
||||
# dummy optim step
|
||||
grad_norm = torch.Tensor([0.0]).cuda()
|
||||
tim_usage["optim"] = tim_usage["backward"]
|
||||
mem_usage["optim"] = mem_usage["backward"]
|
||||
|
||||
with torch.no_grad():
|
||||
task_num = len(data["task_names"])
|
||||
task_loss, task_token = get_task_loss_and_token(
|
||||
non_reduced_loss, data["task_ids"], task_num, data["targets"]
|
||||
)
|
||||
task_loss_map: Dict[str, float] = {}
|
||||
gatherd_task_loss_map = bmt.distributed.all_gather(task_loss)
|
||||
gatherd_task_token_map = bmt.distributed.all_gather(task_token)
|
||||
gatherd_task_loss_token_map = gatherd_task_loss_map * gatherd_task_token_map
|
||||
sum_task_loss = gatherd_task_loss_token_map.sum(dim=0)
|
||||
tot_task_token = gatherd_task_token_map.sum(dim=0)
|
||||
ave_task_loss = sum_task_loss / tot_task_token
|
||||
for i in range(task_num):
|
||||
task_loss_map[data["task_names"][i]] = ave_task_loss[i].item()
|
||||
global_total_task_token[data["task_names"][i]] += tot_task_token[i].item()
|
||||
|
||||
local_total_rate = torch.Tensor([data["lengths"].float().mean() / args.max_length]).cuda()
|
||||
local_total_rate = bmt.sum_loss(local_total_rate).item()
|
||||
global_token_pass += (
|
||||
(global_world_size // args.tp_size) * local_total_rate * args.max_length * args.batch_size
|
||||
)
|
||||
|
||||
bmt.print_rank(
|
||||
"=========================================" + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
)
|
||||
last_before_log_time = tim_usage["before_log"]
|
||||
mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)
|
||||
|
||||
iter_time = tim_usage["before_log"] - last_before_log_time
|
||||
|
||||
ave_iter_time.record(iter_time)
|
||||
|
||||
train_info = {
|
||||
"time": iter_time,
|
||||
"iteration": iteration,
|
||||
"loss": global_loss,
|
||||
"lr": lr_scheduler.current_lr,
|
||||
"token_max": local_total_rate,
|
||||
"token_pass": global_token_pass,
|
||||
"throughout": args.max_length * args.batch_size * local_total_rate / ave_iter_time.get() / args.tp_size,
|
||||
"grad_norm": grad_norm.item(),
|
||||
"mask_max": ((data["targets"] >= 0).sum(-1).float().mean() / args.max_length).item(),
|
||||
"task_loss": task_loss_map,
|
||||
"total_task_token": global_total_task_token,
|
||||
}
|
||||
global_token_pass_str = convert_to_k_and_b(global_token_pass)
|
||||
|
||||
bmt.print_rank(
|
||||
(
|
||||
"| Iter: {iteration:6d} | loss: {loss:.4f} | lr: {lr:.4e} | model_time: {model_time:.2f} | iter_time: {iter_time:.2f}| chunk_ave_time: {chunk_ave_time:.2f}"
|
||||
+ " token/max: {tokenrate:.4f} | mask/max: {maskrate:.4f} | grad_norm: {grad_norm:.4f} | global_token_pass (B):"
|
||||
+ "{global_token_pass} | mem_usage {mem_usage} | "
|
||||
).format(
|
||||
iteration=iteration,
|
||||
loss=global_loss,
|
||||
lr=lr_scheduler.current_lr,
|
||||
model_time=model_time,
|
||||
iter_time=iter_time,
|
||||
chunk_ave_time=ave_iter_time.get(),
|
||||
tokenrate=data["lengths"].float().mean() / args.max_length / args.batch_size,
|
||||
maskrate=(data["targets"] >= 0).sum(-1).float().mean() / args.max_length / args.batch_size,
|
||||
grad_norm=grad_norm.item(),
|
||||
global_token_pass=global_token_pass_str,
|
||||
mem_usage=max([value for key, value in mem_usage.items()]),
|
||||
)
|
||||
)
|
||||
|
||||
bmt.print_rank(
|
||||
"task_loss:\t| "
|
||||
+ " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
|
||||
+ " |"
|
||||
)
|
||||
|
||||
if iteration % 10 == 0:
|
||||
bmt.print_rank(
|
||||
"task_tokens (B):\t| "
|
||||
+ " | ".join(
|
||||
[
|
||||
"{}: {:.4f}".format(task_name, task_token / 10**9)
|
||||
for task_name, task_token in global_total_task_token.items()
|
||||
]
|
||||
)
|
||||
+ " |"
|
||||
)
|
||||
|
||||
if iteration % args.inspect_iters == 0:
|
||||
model_inspect = bmt.inspect.inspect_model(model, "*")
|
||||
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
||||
|
||||
if args.log_dir is not None and bmt.rank() == 0:
|
||||
if args.save is not None:
|
||||
save_every_step_stats(train_info, args.save)
|
||||
|
||||
if args.tensorboard is not None and bmt.rank() == 0:
|
||||
writer.add_scalar("Loss/train", global_loss, iteration)
|
||||
writer.add_scalar("Optimizer/lr", lr_scheduler.current_lr, iteration)
|
||||
writer.add_scalar("Optimizer/scale", optim_manager.loss_scale, iteration)
|
||||
writer.add_scalar("Optimizer/grad_norm", grad_norm.item(), iteration)
|
||||
for task_name, loss in task_loss_map.items():
|
||||
if not math.isnan(loss):
|
||||
writer.add_scalar("Loss/train/{}".format(task_name), loss, iteration)
|
||||
|
||||
# -------- save file. If need to backup by Klara platform, use export.xx_save --------
|
||||
log_ckpt = {
|
||||
"global_total_task_token": global_total_task_token,
|
||||
"global_token_pass": global_token_pass,
|
||||
"iteration": iteration,
|
||||
}
|
||||
|
||||
if args.save is not None and iteration % args.save_iters == 0:
|
||||
exporter.export(
|
||||
model,
|
||||
mixed_indexed_dataset,
|
||||
tokenizer,
|
||||
optimizer,
|
||||
iteration,
|
||||
args,
|
||||
log_ckpt=log_ckpt,
|
||||
final_save=False,
|
||||
)
|
||||
|
||||
if iteration == args.train_iters and args.stop_when_end == 1:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print(f"train loop err: {e}")
|
||||
raise e
|
||||
finally:
|
||||
pass
|
||||
|
||||
exporter.export(model, mixed_indexed_dataset, tokenizer, optimizer, -1, args, final_save=False)
|
||||
|
||||
|
||||
def convert_to_k_and_b(number):
|
||||
if number >= 1e9: # 大于或等于10亿
|
||||
b_number = number / 1e9
|
||||
return f"{b_number:.2f}B"
|
||||
elif number >= 1e6: # 大于或等于1百万
|
||||
k_number = number / 1e6
|
||||
return f"{k_number:.2f}M"
|
||||
elif number >= 1e3:
|
||||
k_number = number / 1e3
|
||||
return f"{k_number:.2f}K"
|
||||
else:
|
||||
return str(number)
|
||||
|
||||
|
||||
def main():
|
||||
args = initialize()
|
||||
bmt.synchronize()
|
||||
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
|
||||
bmt.print_rank("finish loading")
|
||||
bmt.print_rank(
|
||||
"Number of parameter {}, Number of non-e parameter {}".format(
|
||||
num_parameters(model), num_non_embedding_parameters(model)
|
||||
)
|
||||
)
|
||||
bmt.print_rank("args: {}".format(args))
|
||||
|
||||
pretrain(args, tokenizer, model, optimizer, lr_scheduler)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,234 @@
|
|||
#!/bin/bash
|
||||
|
||||
#export OMP_NUM_THREADS=16
|
||||
|
||||
declare -A args # Declare an associative array to store arguments and values
|
||||
|
||||
args["model_unique"]="2b_0701"
|
||||
args["resume_ckpt"]=""
|
||||
args["config"]="2.4b"
|
||||
args["flash"]="cuda"
|
||||
args["batch_size"]="1"
|
||||
args["max_length"]="4096"
|
||||
args["save_iters"]="500"
|
||||
args["train_iters"]="10"
|
||||
args["dataset_config"]="fm9g_sft"
|
||||
args["local"]="False"
|
||||
args["dataloader"]="indexed"
|
||||
args["save"]="True"
|
||||
args["dataloader_num_threads"]=1
|
||||
args["dataloader_prefetch"]=1
|
||||
args["dataloader_prefetch_factor"]=1
|
||||
args["dataloader_num_workers"]=1
|
||||
args["lr"]="1e-5"
|
||||
args["warmup_iters"]="20"
|
||||
args["drop_iters"]="0.1"
|
||||
args["tokenizer_path"]="./tokenizer/tokenizer.model" # /user/tc_agi/klara/baichuan2/baichuan2.tokenizer.model
|
||||
args["load_grad"]="False"
|
||||
args["grad_ckpt_num"]="160"
|
||||
args["exp_group"]=""
|
||||
args["ignore_cuda_oom"]="1"
|
||||
args["tensorboard_all_tasks"]="0"
|
||||
args["stop_when_end"]="0"
|
||||
args["only_run_dataloader"]="0"
|
||||
args["eps"]="1e-6"
|
||||
args["inspect_iters"]="100"
|
||||
args["strict_state_dict"]="1"
|
||||
args["only_load_model"]="1"
|
||||
args["lr_scheduler"]="cosine"
|
||||
args["resume_no_optimze"]="0"
|
||||
args["tp_size"]="1"
|
||||
args["parallel_load_datastate"]="8"
|
||||
args["async_save"]="False"
|
||||
args["load_dataloader_ckpt"]="0"
|
||||
args["drop_begin"]="-1"
|
||||
args["drop_rate"]="0.5"
|
||||
args["use_checkpoint"]="0"
|
||||
|
||||
|
||||
# Loop through the arguments
|
||||
for ((i=1; i<=$#; i++)); do
|
||||
arg="${!i}"
|
||||
# Check if the argument starts with "--"
|
||||
if [[ "$arg" == --* ]]; then
|
||||
arg_name="${arg:2}" # Remove leading "--"
|
||||
valueid=$((i+1))
|
||||
# Get the value of the argument if it exists
|
||||
if ((i+1 <= $#)); then
|
||||
args["$arg_name"]="${!valueid}"
|
||||
i=$((i+1)) # Skip the next argument (its value)
|
||||
else
|
||||
args["$arg_name"]="" # Set empty value if no value provided
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
# 使用 Python 读取 JSON 文件并更新 Bash 字典
|
||||
while read -r key value; do
|
||||
args["$key"]="$value"
|
||||
done < <(python -c 'import json, sys; obj = json.load(open("train_configs/'${args['config']}'.json"))["pretrain"]; print("\n".join(["{} {}".format(k, v) for k, v in obj.items()]))')
|
||||
|
||||
|
||||
|
||||
# 用cmd arg 再更新一次
|
||||
# Loop through the arguments
|
||||
for ((i=1; i<=$#; i++)); do
|
||||
arg="${!i}"
|
||||
# Check if the argument starts with "--"
|
||||
if [[ "$arg" == --* ]]; then
|
||||
arg_name="${arg:2}" # Remove leading "--"
|
||||
valueid=$((i+1))
|
||||
|
||||
# Get the value of the argument if it exists
|
||||
if ((i+1 <= $#)); then
|
||||
args["$arg_name"]="${!valueid}"
|
||||
i=$((i+1)) # Skip the next argument (its value)
|
||||
else
|
||||
args["$arg_name"]="" # Set empty value if no value provided
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
# Print the values of the arguments
|
||||
echo "----------- CMD args ----------"
|
||||
for key in "${!args[@]}"; do
|
||||
echo "$key: ${args[$key]}"
|
||||
done
|
||||
echo "--------- END CMD args --------"
|
||||
|
||||
|
||||
if [[ ${args["flash"]} == "triton" ]]; then
|
||||
sudo cp /usr/local/cuda-11.6/compat/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03
|
||||
sudo ln /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so
|
||||
echo "triton flash"
|
||||
fi
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
GPUS_PER_NODE=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader | wc -l)
|
||||
# GPUS_PER_NODE=1
|
||||
echo "Using ${GPUS_PER_NODE} GPU each machine"
|
||||
|
||||
|
||||
if [[ ${args["model_unique"]} == "" ]]; then
|
||||
MODEL_UNIQUE=${JEEVES_JOB_ID} # 写入的位置,没传的话自动构造
|
||||
# JOBID+CreateTime, 本次run的唯一标识符。在白箱里可以通过/projects/${PROJECTID}-${PROJECTNAME}/checkpoints/${MODEL_UNIQUE} 拿到 checkpoint
|
||||
# 通过/projects/${PROJECTID}-${PROJECTNAME}/tensorboard/${MODEL_UNIQUE} 拿到 tensorboard
|
||||
else
|
||||
MODEL_UNIQUE=${args["model_unique"]} # 给了写入的位置
|
||||
fi
|
||||
echo "model_unique: "$MODEL_UNIQUE
|
||||
|
||||
# --------------- 运行参数 ---------------
|
||||
|
||||
OPTS+=" --model-config model_configs/"${args['config']}".json" # [CHANGE]
|
||||
OPTS+=" --batch-size ${args["batch_size"]}"
|
||||
OPTS+=" --train-iters ${args["train_iters"]}"
|
||||
OPTS+=" --save-iters ${args["save_iters"]}"
|
||||
OPTS+=" --save-name fm9g_live_checkpoint"
|
||||
OPTS+=" --max-length ${args["max_length"]}"
|
||||
OPTS+=" --lr ${args["lr"]}"
|
||||
OPTS+=" --inspect-iters ${args["inspect_iters"]}"
|
||||
OPTS+=" --warmup-iters ${args["warmup_iters"]}"
|
||||
OPTS+=" --drop-iters ${args["drop_iters"]}"
|
||||
OPTS+=" --lr_scheduler ${args["lr_scheduler"]}"
|
||||
OPTS+=" --offload"
|
||||
#OPTS+=" --vocab ./tokenizer/vocab.txt"
|
||||
OPTS+=" --flash ${args["flash"]}"
|
||||
OPTS+=" --tensorboard_all_tasks ${args["tensorboard_all_tasks"]}"
|
||||
OPTS+=" --ignore_cuda_oom ${args["ignore_cuda_oom"]}"
|
||||
OPTS+=" --stop_when_end ${args["stop_when_end"]}"
|
||||
OPTS+=" --only_run_dataloader ${args["only_run_dataloader"]}"
|
||||
OPTS+=" --eps ${args["eps"]}"
|
||||
OPTS+=" --strict_state_dict ${args["strict_state_dict"]}"
|
||||
OPTS+=" --only_load_model ${args["only_load_model"]}"
|
||||
OPTS+=" --resume_no_optimze ${args["resume_no_optimze"]}"
|
||||
OPTS+=" --tokenizer_path ${args["tokenizer_path"]}"
|
||||
OPTS+=" --weight-decay 0.1"
|
||||
OPTS+=" --tp-size ${args["tp_size"]}"
|
||||
OPTS+=" --parallel_load_datastate ${args["parallel_load_datastate"]}"
|
||||
OPTS+=" --load_dataloader_ckpt ${args["load_dataloader_ckpt"]}"
|
||||
OPTS+=" --drop_begin ${args["drop_begin"]}"
|
||||
OPTS+=" --drop_rate ${args["drop_rate"]}"
|
||||
OPTS+=" --use_checkpoint ${args["use_checkpoint"]}"
|
||||
|
||||
if [[ ${args["load_grad"]} == "True" ]]; then
|
||||
OPTS+=" --load-grad"
|
||||
OPTS+=" --grad-ckpt-num ${args["grad_ckpt_num"]}"
|
||||
fi
|
||||
|
||||
|
||||
if [[ ${args["async_save"]} == "True" ]]; then
|
||||
OPTS+=" --async_save"
|
||||
fi
|
||||
|
||||
|
||||
if [[ ${args["dataloader"]} == "indexed" ]]; then
|
||||
OPTS+=" --dataloader_num_threads ${args["dataloader_num_threads"]}"
|
||||
OPTS+=" --dataloader_prefetch ${args["dataloader_prefetch"]}"
|
||||
OPTS+=" --dataloader_num_workers ${args["dataloader_num_workers"]}"
|
||||
OPTS+=" --dataloader_prefetch_factor ${args["dataloader_prefetch_factor"]}"
|
||||
fi
|
||||
|
||||
|
||||
# --------------- 写文件路径 ---------------
|
||||
## checkpoint
|
||||
if [[ ${args["save"]} == "True" ]]; then
|
||||
|
||||
OPTS+=" --save ./data/checkpoints/${MODEL_UNIQUE}/"
|
||||
OPTS+=" --save-model ./not_exist/${MODEL_UNIQUE}/"
|
||||
else
|
||||
echo "won't save model"
|
||||
fi
|
||||
|
||||
|
||||
## logs,/local/logs 等价于 ./datalogs(软链)
|
||||
mkdir -p ./data/checkpoints/logs/${MODEL_UNIQUE}
|
||||
OPTS+=" --log-dir ./data/checkpoints/logs/${MODEL_UNIQUE}"
|
||||
OPTS+=" --tensorboard ./data/tensorboard/${args["exp_group"]}${MODEL_UNIQUE}/"
|
||||
|
||||
|
||||
|
||||
if [[ ${args["local"]} == "True" ]]; then
|
||||
current_dir=$(pwd)
|
||||
OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
|
||||
else
|
||||
current_dir=$(pwd)
|
||||
OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
|
||||
echo "Platform config:"${PLATFORM_CONFIG_PATH}
|
||||
fi
|
||||
|
||||
|
||||
## checkpoint,兼容 CHECKPOINT 和 LATEST_CHECKPOINT。debug 时建议不加载 checkpoint,启动会比较快
|
||||
if [ "${args["resume_ckpt"]}" != "" ]; then
|
||||
OPTS+=" --load ./data/checkpoints/${MODEL_UNIQUE}/${args["resume_ckpt"]}"
|
||||
else
|
||||
echo "No checkpoint to load"
|
||||
fi
|
||||
|
||||
|
||||
filename="pretrain_dragonfly"
|
||||
|
||||
if [[ ${args["local"]} == "True" ]]; then
|
||||
PRETRAIN_ENTRY="$filename.py"
|
||||
else
|
||||
PRETRAIN_ENTRY="$filename.py"
|
||||
fi
|
||||
|
||||
|
||||
GPUS_PER_NODE=1
|
||||
NNODES=1
|
||||
RANK=0
|
||||
MASTER_ENDPOINT=g3006
|
||||
MASTER_PORT=23456
|
||||
#CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --master_addr=${MASTER_ENDPOINT} --master_port=${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
|
||||
CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ENDPOINT}:${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
|
||||
|
||||
echo "-------final CMD is------"
|
||||
echo "${CMD}"
|
||||
echo "-------final CMD end------"
|
||||
|
||||
$CMD
|
File diff suppressed because it is too large
Load Diff
Binary file not shown.
|
@ -0,0 +1,9 @@
|
|||
{
|
||||
"pretrain": {
|
||||
"train_iters": 1000000000,
|
||||
"batch_size": 1,
|
||||
"max_length": 4096,
|
||||
"n_gpus": 8,
|
||||
"lr": 0.01
|
||||
}
|
||||
}
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,20 @@
|
|||
|
||||
import random
|
||||
|
||||
|
||||
def rand(n: int, r: random.Random):
|
||||
return int(r.random() * n)
|
||||
|
||||
def transform(data, num_sample: int, r: random.Random):
|
||||
if 'input' in data:
|
||||
_input = "<用户>"+data['input']+"<AI>"
|
||||
else:
|
||||
_input = ""
|
||||
|
||||
if 'output' in data:
|
||||
_output = data['output']
|
||||
else:
|
||||
_output = ""
|
||||
return {"input": _input,
|
||||
"output": _output,
|
||||
}
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,20 @@
|
|||
|
||||
import random
|
||||
|
||||
|
||||
def rand(n: int, r: random.Random):
|
||||
return int(r.random() * n)
|
||||
|
||||
def transform(data, num_sample: int, r: random.Random):
|
||||
if 'input' in data:
|
||||
_input = data['input']
|
||||
else:
|
||||
_input = ""
|
||||
|
||||
if 'output' in data:
|
||||
_output = data['output']
|
||||
else:
|
||||
_output = ""
|
||||
return {"input": _input,
|
||||
"output": _output,
|
||||
}
|
|
@ -0,0 +1,134 @@
|
|||
[
|
||||
{
|
||||
"dataset_name": "humanevallike_clean_dedup",
|
||||
"task_name": "humanevallike_clean_dedup",
|
||||
"abs_weight": 0.2,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/humanevallike_clean_dedup",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 995339,
|
||||
"ave_tokens_per_line": 100,
|
||||
"total_tokens": 0.1
|
||||
},
|
||||
{
|
||||
"dataset_name": "leetcode_pass_code_0125",
|
||||
"task_name": "leetcode_pass_code_0125",
|
||||
"abs_weight": 0.006,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/leetcode_pass_code_0125",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 10724,
|
||||
"ave_tokens_per_line": 200,
|
||||
"total_tokens": 0.002
|
||||
},
|
||||
{
|
||||
"dataset_name": "logiv2Annotate",
|
||||
"task_name": "logiv2Annotate",
|
||||
"abs_weight": 0.004,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logiv2Annotate",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 12566,
|
||||
"ave_tokens_per_line": 512,
|
||||
"total_tokens": 0.006
|
||||
},
|
||||
{
|
||||
"dataset_name": "mmlu_enhance",
|
||||
"task_name": "mmlu_enhance",
|
||||
"abs_weight": 0.1,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mmlu_enhance",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 169771,
|
||||
"ave_tokens_per_line": 300,
|
||||
"total_tokens": 0.05
|
||||
},
|
||||
{
|
||||
"dataset_name": "mtbench_like",
|
||||
"task_name": "mtbench_like",
|
||||
"abs_weight": 0.2,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mtbench_like",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 319080,
|
||||
"ave_tokens_per_line": 500,
|
||||
"total_tokens": 0.15
|
||||
},
|
||||
{
|
||||
"dataset_name": "ultra_dataset_new",
|
||||
"task_name": "ultra_dataset_new",
|
||||
"abs_weight": 2.0,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/ultra_dataset_new",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 385045,
|
||||
"ave_tokens_per_line": 200.296266559615,
|
||||
"total_tokens": 2.0
|
||||
},
|
||||
{
|
||||
"dataset_name": "sft_data_zh_wowru",
|
||||
"task_name": "sft_data_zh_wowru",
|
||||
"abs_weight": 1.0,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/sft_data_zh_wowru",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 2963260,
|
||||
"ave_tokens_per_line": 200.296266559615,
|
||||
"total_tokens": 1
|
||||
},
|
||||
{
|
||||
"dataset_name": "math_data",
|
||||
"task_name": "math_data",
|
||||
"abs_weight": 0.003,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/math_data",
|
||||
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 2963260,
|
||||
"ave_tokens_per_line": 200.296266559615,
|
||||
"total_tokens": 0.005
|
||||
},
|
||||
{
|
||||
"dataset_name": "t0",
|
||||
"task_name": "t0",
|
||||
"abs_weight": 0.1,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/t0",
|
||||
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 1650309,
|
||||
"ave_tokens_per_line": 500.296266559615,
|
||||
"total_tokens": 0.82
|
||||
},
|
||||
{
|
||||
"dataset_name": "wikihow",
|
||||
"task_name": "wikihow",
|
||||
"abs_weight": 0.1,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/wikihow",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 180128,
|
||||
"ave_tokens_per_line": 900.296266559615,
|
||||
"total_tokens": 0.16
|
||||
},
|
||||
{
|
||||
"dataset_name": "reclor",
|
||||
"task_name": "reclor",
|
||||
"abs_weight": 0.002,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/reclor",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 4174,
|
||||
"ave_tokens_per_line": 700.296266559615,
|
||||
"total_tokens": 0.003
|
||||
},
|
||||
{
|
||||
"dataset_name": "logic_test_lx_0127",
|
||||
"task_name": "logic_test_lx_0127",
|
||||
"abs_weight": 0.001,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logic_test_lx_0127",
|
||||
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 2800,
|
||||
"ave_tokens_per_line": 200.96266559615,
|
||||
"total_tokens": 0.0004
|
||||
}
|
||||
]
|
|
@ -0,0 +1,27 @@
|
|||
{
|
||||
"vocab_size": 119696,
|
||||
"dropout_p": 0.0,
|
||||
"eps": 1e-05,
|
||||
"half": true,
|
||||
"half_type": "bf16",
|
||||
"use_flash_attn": true,
|
||||
"flash_attn_mask_shape": "2d",
|
||||
"dim_model": 4096,
|
||||
"dim_ff": 14336,
|
||||
"dim_head": 128,
|
||||
"num_heads": 32,
|
||||
"num_kv_heads": 32,
|
||||
"num_layers": 32,
|
||||
"activate_fn": "silu",
|
||||
"init_std": 0.10,
|
||||
"scale": false,
|
||||
"scale_emb": 12,
|
||||
"scale_depth": -1,
|
||||
"model_type": "fm9g",
|
||||
"architectures": [
|
||||
"FM9GForCausalLM"
|
||||
],
|
||||
"qk_norm": false,
|
||||
"tie_lm_head": false,
|
||||
"ffn_gated": true
|
||||
}
|
|
@ -0,0 +1,568 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 ModelBest Inc.
|
||||
import inspect
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Union
|
||||
|
||||
import bmtrain as bmt
|
||||
import numpy as np
|
||||
import torch
|
||||
from bmtrain import nccl
|
||||
from bmtrain.global_var import config as bmt_config
|
||||
|
||||
sys.path.append("../../")
|
||||
from fm9g.arguments import get_args
|
||||
from fm9g.dragonfly.modeling_dragonfly import Dragonfly
|
||||
from fm9g.dragonfly.modeling_dragonfly import DragonflyConfig
|
||||
from fm9g.dragonfly.training_tasks.pretrain_indexed_9g import CudaPrefetcher
|
||||
from fm9g.dragonfly.training_tasks.pretrain_indexed_9g import MixedIndexedDataset
|
||||
from fm9g.dragonfly.training_tasks.pretrain_indexed_9g import UnpadBatchedMixedDataset
|
||||
from fm9g.utils import exporter
|
||||
from fm9g.utils import logger
|
||||
from fm9g.utils.exporter import save_every_step_stats
|
||||
from fm9g.utils.training_stats import num_non_embedding_parameters
|
||||
from fm9g.utils.training_stats import num_parameters
|
||||
|
||||
|
||||
def get_tokenizer(args):
|
||||
from fm9g.tokenizer import FM9GTokenizer
|
||||
tokenizer = FM9GTokenizer(path=args.vocab)
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_model(args):
|
||||
config = DragonflyConfig.from_json_file(args.model_config)
|
||||
config.tp = 1 if args.tp_size != 1 else 0 # TODO
|
||||
config.pose_prob = args.pose_prob
|
||||
config.pose_scaling_factor = args.pose_scaling_factor
|
||||
config.rope_scaling_type = args.rope_scaling_type
|
||||
config.rope_scaling_factor = args.rope_scaling_factor
|
||||
config.orig_max_length = args.orig_max_length
|
||||
config.use_checkpoint = True if args.use_checkpoint == 1 else False
|
||||
|
||||
bmt.print_rank("model config: {}".format(config))
|
||||
bmt.print_rank("bmt config: {}".format(bmt.config))
|
||||
|
||||
model = Dragonfly(config)
|
||||
if args.load is not None:
|
||||
bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
|
||||
exporter.load_model_ckpt(args, model)
|
||||
else:
|
||||
bmt.print_rank("args.load is None, start to initialize parameters")
|
||||
bmt.init_parameters(model)
|
||||
return model
|
||||
|
||||
|
||||
def get_optimizer(args, model):
|
||||
scale_lr_group = []
|
||||
normal_group = []
|
||||
scale_lr_group_name, normal_group_name = [], []
|
||||
for n, p in model.named_parameters():
|
||||
if n.endswith(".weight") and "layernorm" not in n and "embedding" not in n and "lm_head" not in n:
|
||||
scale_lr_group.append(p)
|
||||
scale_lr_group_name.append(n)
|
||||
else:
|
||||
normal_group.append(p)
|
||||
normal_group_name.append(n)
|
||||
bmt.print_rank(scale_lr_group_name, normal_group_name)
|
||||
param_groups = [
|
||||
{"params": scale_lr_group, "lr": args.lr / model.config.scale_width},
|
||||
{"params": normal_group, "lr": args.lr},
|
||||
]
|
||||
|
||||
if args.offload:
|
||||
optimizer = bmt.optim.AdamOffloadOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
||||
else:
|
||||
optimizer = bmt.optim.AdamOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
||||
if args.load is not None and args.load_grad:
|
||||
exporter.load_optimizer_ckpt(args, optimizer)
|
||||
bmt.print_rank("optimizer is loaded!")
|
||||
return optimizer
|
||||
|
||||
|
||||
def get_learning_rate_scheduler(args, optimizer):
|
||||
from fm9g.training_utils.lr_scheduler import Cosine
|
||||
from fm9g.training_utils.lr_scheduler import WarmupStableDrop
|
||||
from fm9g.training_utils.lr_scheduler import WarmupStableExp
|
||||
|
||||
end_iter = args.train_iters
|
||||
if 0 < args.warmup_iters < 1: # 需要支持按固定比例step用来做warmup的
|
||||
warmup_iters = int(end_iter * args.warmup_iters)
|
||||
else:
|
||||
warmup_iters = int(args.warmup_iters)
|
||||
|
||||
if 0 < args.drop_iters < 1: # 需要支持按固定比例step用来做drop的
|
||||
drop_iters = int(end_iter * args.drop_iters)
|
||||
else:
|
||||
drop_iters = int(args.drop_iters)
|
||||
|
||||
if args.lr_scheduler == "cosine":
|
||||
lr_scheduler = Cosine(
|
||||
optimizer,
|
||||
start_lr=args.lr,
|
||||
warmup_iter=warmup_iters,
|
||||
end_iter=end_iter, # 原来是lr_decay_iter
|
||||
num_iter=args.start_step,
|
||||
)
|
||||
# lr_end_restart=args.lr_end_restart,
|
||||
# resume_no_optimze=args.resume_no_optimze,
|
||||
#)
|
||||
elif args.lr_scheduler == "warmupstabledrop":
|
||||
lr_scheduler = WarmupStableDrop(
|
||||
optimizer,
|
||||
start_lr=args.lr,
|
||||
warmup_iter=warmup_iters,
|
||||
end_iter=end_iter, # 原来是lr_decay_iter
|
||||
drop_iter=drop_iters,
|
||||
num_iter=args.start_step,
|
||||
resume_no_optimze=args.resume_no_optimze,
|
||||
)
|
||||
elif args.lr_scheduler == "warmupstableexp":
|
||||
lr_scheduler = WarmupStableExp(
|
||||
optimizer,
|
||||
start_lr=args.lr,
|
||||
warmup_iter=warmup_iters,
|
||||
drop_begin=args.drop_begin, # 原来是lr_decay_iter
|
||||
drop_iter=drop_iters,
|
||||
drop_rate=args.drop_rate,
|
||||
num_iter=args.start_step,
|
||||
resume_no_optimze=args.resume_no_optimze,
|
||||
)
|
||||
return lr_scheduler
|
||||
|
||||
|
||||
def setup_model_and_optimizer(args):
|
||||
start = time.time()
|
||||
tokenizer = get_tokenizer(args)
|
||||
bmt.synchronize()
|
||||
logger.info("load tokenizer in {:.2f}s".format(time.time() - start))
|
||||
|
||||
start = time.time()
|
||||
model = get_model(args)
|
||||
logger.info("load model in {:.2f}s".format(time.time() - start))
|
||||
|
||||
start = time.time()
|
||||
optimizer = get_optimizer(args, model)
|
||||
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
|
||||
bmt.synchronize()
|
||||
logger.info("load lr_scheduler in {:.2f}s".format(time.time() - start))
|
||||
|
||||
return tokenizer, model, optimizer, lr_scheduler
|
||||
|
||||
|
||||
def resume_training(args):
|
||||
ckpts = sorted(
|
||||
[z for z in chain(*[[os.path.join(x[0], y) for y in x[2]] for x in os.walk(args.save)]) if z.endswith(".pt")],
|
||||
reverse=True,
|
||||
key=lambda x: (int)(re.search("(\d+).pt", x)[1]),
|
||||
)
|
||||
# find newest job
|
||||
ckpts = sorted(
|
||||
ckpts,
|
||||
reverse=True,
|
||||
key=lambda x: (int)(re.search("job_(\d+)_ckpt", x)[1]),
|
||||
)
|
||||
|
||||
if len(ckpts) > 0:
|
||||
bmt.print_rank(f"resuming with last checkpoint: {ckpts[0]}")
|
||||
args.load = ckpts[0]
|
||||
# by default, do not load grad file
|
||||
args.load_grad = False
|
||||
args.start_step = 0
|
||||
else:
|
||||
# no ckpts, nothing we can do
|
||||
os._exit(1)
|
||||
|
||||
|
||||
def initialize():
|
||||
args = get_args(pretrain=True)
|
||||
bmt.init_distributed(seed=args.seed, tp_size=args.tp_size)
|
||||
|
||||
if args.save is not None:
|
||||
os.makedirs(args.save, exist_ok=True)
|
||||
if args.load is not None:
|
||||
if args.only_load_model == 0:
|
||||
if args.start_step == 0:
|
||||
log_ckpt = exporter.load_log_ckpt(args)
|
||||
if "iteration" in log_ckpt:
|
||||
args.start_step = log_ckpt["iteration"]
|
||||
else:
|
||||
args.start_step = (int)(re.findall("(\d+)", args.load)[-1])
|
||||
logger.info("Start from step {}".format(args.start_step))
|
||||
elif args.only_load_model == 1:
|
||||
logger.info("You load model ckpt, and choose to completely start the 0 step.")
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
logger.info("You do not load model")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def see_memory(detail=False):
|
||||
if detail:
|
||||
res = torch.cuda.memory_summary()
|
||||
else:
|
||||
res = (
|
||||
round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2),
|
||||
round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024), 2),
|
||||
round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024), 2),
|
||||
)
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
return res
|
||||
|
||||
|
||||
def add_mem_time(info, mem_usage, tim_usage):
|
||||
torch.cuda.synchronize()
|
||||
bmt.synchronize()
|
||||
mem_usage[info] = see_memory()
|
||||
tim_usage[info] = time.time()
|
||||
return mem_usage, tim_usage
|
||||
|
||||
|
||||
def get_task_loss_and_token(loss, task_ids, task_num, targets):
|
||||
# task_ids 可能有-1 来代表无效token
|
||||
_task_num = task_num + 1
|
||||
_task_ids = (task_ids.clone() + 1).to(torch.int64) # [batch_size, seq_len]
|
||||
# gen masks
|
||||
_task_mask = torch.zeros((_task_num, *_task_ids.shape), device=_task_ids.device)
|
||||
_task_mask.scatter_(0, _task_ids.unsqueeze(0), 1) # [task_num, batch_size, seq_len]
|
||||
_loss_mask = torch.ne(targets, -100).to(torch.int32)
|
||||
_mask = _task_mask * _loss_mask.unsqueeze(0) # [task_num, batch_size, seq_len]
|
||||
# calc loss and tokens
|
||||
_task_losses = (loss.unsqueeze(0) * _mask).view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
|
||||
_task_tokens = _mask.view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
|
||||
# return token-wise avg losses and tokens
|
||||
return torch.nan_to_num(_task_losses / _task_tokens, nan=0.0), _task_tokens
|
||||
|
||||
|
||||
class ChunkAve:
|
||||
def __init__(self, chunk_size=100):
|
||||
self.ave_list = []
|
||||
self.chunk_size = chunk_size
|
||||
|
||||
def record(self, time):
|
||||
self.ave_list.append(time)
|
||||
self.ave_list = self.ave_list[-self.chunk_size :]
|
||||
|
||||
def get(self):
|
||||
return sum(self.ave_list) / len(self.ave_list)
|
||||
|
||||
|
||||
def pretrain(
|
||||
args,
|
||||
tokenizer,
|
||||
model: Dragonfly,
|
||||
optimizer,
|
||||
lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
|
||||
):
|
||||
ave_model_time = ChunkAve(chunk_size=100)
|
||||
ave_iter_time = ChunkAve(chunk_size=100)
|
||||
|
||||
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, reduction="none")
|
||||
optim_manager = bmt.optim.OptimManager(
|
||||
loss_scale=bmt.world_size(),
|
||||
loss_scale_steps=args.loss_scale_steps,
|
||||
loss_scale_factor=2,
|
||||
max_loss_scale=bmt.world_size(),
|
||||
min_loss_scale=bmt.world_size(),
|
||||
)
|
||||
optim_manager.add_optimizer(optimizer, lr_scheduler)
|
||||
|
||||
start_step = args.start_step
|
||||
|
||||
if args.tensorboard is not None and bmt.rank() == 0:
|
||||
import distutils.version # noqa: F401
|
||||
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
if not os.path.exists(args.tensorboard):
|
||||
os.makedirs(args.tensorboard)
|
||||
writer = SummaryWriter(log_dir=args.tensorboard)
|
||||
|
||||
if args.load is not None:
|
||||
log_ckpt = exporter.load_log_ckpt(args)
|
||||
else:
|
||||
log_ckpt = {}
|
||||
|
||||
global_token_pass = log_ckpt.get("global_token_pass", 0.0)
|
||||
global_total_task_token = defaultdict(int, log_ckpt.get("global_total_task_token", {})) # token by task
|
||||
|
||||
global_world_size = bmt.world_size()
|
||||
if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
|
||||
mixed_indexed_dataset = MixedIndexedDataset(
|
||||
cfg_path=args.dataset,
|
||||
cfg_json_str=None,
|
||||
tokenizer=tokenizer,
|
||||
max_length=args.max_length,
|
||||
nthreads=args.dataloader_num_threads,
|
||||
prefetch_slice=args.dataloader_prefetch,
|
||||
weight_by_size=True,
|
||||
)
|
||||
|
||||
if args.load is not None and args.only_load_model == 0 and args.load_dataloader_ckpt == 1:
|
||||
exporter.load_dataloader_ckpt(args, mixed_indexed_dataset)
|
||||
|
||||
batched_dataset = UnpadBatchedMixedDataset(mixed_indexed_dataset, args.batch_size, args.max_length)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
batched_dataset,
|
||||
batch_size=None,
|
||||
collate_fn=lambda x: x,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
prefetch_factor=args.dataloader_prefetch_factor,
|
||||
)
|
||||
else:
|
||||
|
||||
def dummy_generator():
|
||||
while True:
|
||||
yield None
|
||||
|
||||
mixed_indexed_dataset = dummy_generator()
|
||||
dataloader = mixed_indexed_dataset
|
||||
|
||||
DataIterator = CudaPrefetcher(dataloader, tp_size=args.tp_size, tp_rank=bmt.config["tp_rank"])
|
||||
|
||||
bmt.print_rank("Preparing dataset done.")
|
||||
|
||||
# inspect at init
|
||||
model_inspect = bmt.inspect.inspect_model(model, "*")
|
||||
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
||||
|
||||
try:
|
||||
mem_usage, tim_usage = {}, {}
|
||||
mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)
|
||||
|
||||
for iteration, data in enumerate(DataIterator, start=start_step + 1):
|
||||
if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
|
||||
mixed_indexed_dataset.update_states(data["task_ids"], data["indexes"])
|
||||
|
||||
mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)
|
||||
|
||||
logits = model(
|
||||
input=data["inputs"],
|
||||
cu_seqlens=data["cu_seqlens"],
|
||||
max_seqlen=data["max_seqlen"],
|
||||
position_ids=data["position_ids"],
|
||||
)
|
||||
#print("logits: ", logits)
|
||||
|
||||
# chunk targets and task_ids
|
||||
data["targets"] = (
|
||||
data["targets"]
|
||||
.view(-1)
|
||||
.chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
|
||||
.view(data["targets"].shape[0], -1)
|
||||
)
|
||||
data["task_ids"] = (
|
||||
data["task_ids"]
|
||||
.view(-1)
|
||||
.chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
|
||||
.view(data["task_ids"].shape[0], -1)
|
||||
)
|
||||
|
||||
_target = data["targets"].view(-1)
|
||||
non_reduced_loss = loss_func(logits.view(-1, logits.size(-1)), _target)
|
||||
_w = (_target != -100).int()
|
||||
loss = non_reduced_loss.sum() / _w.sum().float()
|
||||
|
||||
global_loss = bmt.sum_loss(loss).item()
|
||||
mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)
|
||||
|
||||
optim_manager.backward(loss)
|
||||
mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)
|
||||
|
||||
if iteration % args.grad_accum == 0 or iteration == args.train_iters:
|
||||
grad_accum_init_time = tim_usage["init"]
|
||||
|
||||
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
|
||||
optim_manager.step()
|
||||
optim_manager.zero_grad()
|
||||
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
|
||||
model_time = tim_usage["optim"] - grad_accum_init_time
|
||||
ave_model_time.record(model_time)
|
||||
else:
|
||||
# dummy optim step
|
||||
grad_norm = torch.Tensor([0.0]).cuda()
|
||||
tim_usage["optim"] = tim_usage["backward"]
|
||||
mem_usage["optim"] = mem_usage["backward"]
|
||||
model_time = tim_usage["optim"] - tim_usage['init']
|
||||
|
||||
with torch.no_grad():
|
||||
task_num = len(data["task_names"])
|
||||
task_loss, task_token = get_task_loss_and_token(
|
||||
non_reduced_loss, data["task_ids"], task_num, data["targets"]
|
||||
)
|
||||
task_loss_map: Dict[str, float] = {}
|
||||
gatherd_task_loss_map = bmt.distributed.all_gather(task_loss)
|
||||
gatherd_task_token_map = bmt.distributed.all_gather(task_token)
|
||||
gatherd_task_loss_token_map = gatherd_task_loss_map * gatherd_task_token_map
|
||||
sum_task_loss = gatherd_task_loss_token_map.sum(dim=0)
|
||||
tot_task_token = gatherd_task_token_map.sum(dim=0)
|
||||
ave_task_loss = sum_task_loss / tot_task_token
|
||||
for i in range(task_num):
|
||||
task_loss_map[data["task_names"][i]] = ave_task_loss[i].item()
|
||||
global_total_task_token[data["task_names"][i]] += tot_task_token[i].item()
|
||||
|
||||
local_total_rate = torch.Tensor(
|
||||
[data["lengths"].float().mean() / (args.max_length * args.batch_size)]
|
||||
).cuda()
|
||||
local_total_rate = bmt.sum_loss(local_total_rate).item()
|
||||
global_token_pass += (
|
||||
(global_world_size // args.tp_size) * local_total_rate * args.max_length * args.batch_size
|
||||
)
|
||||
|
||||
bmt.print_rank(
|
||||
"=========================================" + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
)
|
||||
last_before_log_time = tim_usage["before_log"]
|
||||
mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)
|
||||
|
||||
iter_time = tim_usage["before_log"] - last_before_log_time
|
||||
|
||||
ave_iter_time.record(iter_time)
|
||||
|
||||
train_info = {
|
||||
"time": iter_time,
|
||||
"iteration": iteration,
|
||||
"loss": global_loss,
|
||||
"lr": lr_scheduler.current_lr,
|
||||
"token_max": local_total_rate,
|
||||
"token_pass": global_token_pass,
|
||||
"throughout": args.max_length * args.batch_size * local_total_rate / ave_iter_time.get() / args.tp_size,
|
||||
"grad_norm": grad_norm.item(),
|
||||
"mask_max": ((data["targets"] >= 0).sum(-1).float().mean() / args.max_length).item(),
|
||||
"task_loss": task_loss_map,
|
||||
"total_task_token": global_total_task_token,
|
||||
}
|
||||
global_token_pass_str = convert_to_k_and_b(global_token_pass)
|
||||
|
||||
time_report_str = "{model_time:.2f}={forward_time:.2f}+{backward_time:.2f}+{optim_time:.2f}".format(model_time=model_time, forward_time=tim_usage['forward']-tim_usage['init'], backward_time=tim_usage['backward']-tim_usage['forward'], optim_time=tim_usage['optim'] - tim_usage['backward'])
|
||||
bmt.print_rank(
|
||||
(
|
||||
"| Iter: {iteration:6d} | loss: {loss:.4f} | lr: {lr:.4e} | model_time: {model_time} | iter_time: {iter_time:.2f}| chunk_ave_time: {chunk_ave_time:.2f}"
|
||||
+ " token/max: {tokenrate:.4f} | mask/max: {maskrate:.4f} | grad_norm: {grad_norm:.4f} | global_token_pass (B):"
|
||||
+ "{global_token_pass} | mem_usage {mem_usage} | "
|
||||
).format(
|
||||
iteration=iteration,
|
||||
loss=global_loss,
|
||||
lr=lr_scheduler.current_lr,
|
||||
model_time=time_report_str,
|
||||
iter_time=iter_time,
|
||||
chunk_ave_time=ave_iter_time.get(),
|
||||
tokenrate=data["lengths"].float().mean() / args.max_length / args.batch_size,
|
||||
maskrate=(data["targets"] >= 0).sum(-1).float().mean() / args.max_length / args.batch_size,
|
||||
grad_norm=grad_norm.item(),
|
||||
global_token_pass=global_token_pass_str,
|
||||
mem_usage=max([value for key, value in mem_usage.items()]),
|
||||
)
|
||||
)
|
||||
|
||||
bmt.print_rank(
|
||||
"task_loss:\t| "
|
||||
+ " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
|
||||
+ " |"
|
||||
)
|
||||
|
||||
if iteration % 10 == 0:
|
||||
bmt.print_rank(
|
||||
"task_tokens (B):\t| "
|
||||
+ " | ".join(
|
||||
[
|
||||
"{}: {:.4f}".format(task_name, task_token / 10**9)
|
||||
for task_name, task_token in global_total_task_token.items()
|
||||
]
|
||||
)
|
||||
+ " |"
|
||||
)
|
||||
|
||||
if iteration % args.inspect_iters == 0:
|
||||
model_inspect = bmt.inspect.inspect_model(model, "*")
|
||||
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
||||
|
||||
if args.log_dir is not None and bmt.rank() == 0:
|
||||
if args.save is not None:
|
||||
save_every_step_stats(train_info, args.save)
|
||||
|
||||
if args.tensorboard is not None and bmt.rank() == 0:
|
||||
writer.add_scalar("Loss/train", global_loss, iteration)
|
||||
writer.add_scalar("Optimizer/lr", lr_scheduler.current_lr, iteration)
|
||||
writer.add_scalar("Optimizer/scale", optim_manager.loss_scale, iteration)
|
||||
writer.add_scalar("Optimizer/grad_norm", grad_norm.item(), iteration)
|
||||
for task_name, loss in task_loss_map.items():
|
||||
if not math.isnan(loss):
|
||||
writer.add_scalar("Loss/train/{}".format(task_name), loss, iteration)
|
||||
|
||||
# -------- save file. If need to backup by Klara platform, use export.xx_save --------
|
||||
log_ckpt = {
|
||||
"global_total_task_token": global_total_task_token,
|
||||
"global_token_pass": global_token_pass,
|
||||
"iteration": iteration,
|
||||
}
|
||||
|
||||
if args.save is not None and iteration % args.save_iters == 0:
|
||||
exporter.export(
|
||||
model,
|
||||
mixed_indexed_dataset,
|
||||
tokenizer,
|
||||
optimizer,
|
||||
iteration,
|
||||
args,
|
||||
log_ckpt=log_ckpt,
|
||||
final_save=False,
|
||||
async_save=args.async_save,
|
||||
)
|
||||
|
||||
if iteration == args.train_iters and args.stop_when_end == 1:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print(f"train loop err: {e}")
|
||||
raise e
|
||||
finally:
|
||||
pass
|
||||
|
||||
exporter.export(model, mixed_indexed_dataset, tokenizer, optimizer, -1, args, final_save=False)
|
||||
|
||||
|
||||
def convert_to_k_and_b(number):
|
||||
if number >= 1e9: # 大于或等于10亿
|
||||
b_number = number / 1e9
|
||||
return f"{b_number:.2f}B"
|
||||
elif number >= 1e6: # 大于或等于1百万
|
||||
k_number = number / 1e6
|
||||
return f"{k_number:.2f}M"
|
||||
elif number >= 1e3:
|
||||
k_number = number / 1e3
|
||||
return f"{k_number:.2f}K"
|
||||
else:
|
||||
return str(number)
|
||||
|
||||
|
||||
def main():
|
||||
args = initialize()
|
||||
bmt.synchronize()
|
||||
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
|
||||
bmt.print_rank("finish loading")
|
||||
bmt.print_rank(
|
||||
"Number of parameter {}, Number of non-e parameter {}".format(
|
||||
num_parameters(model), num_non_embedding_parameters(model)
|
||||
)
|
||||
)
|
||||
bmt.print_rank("args: {}".format(args))
|
||||
|
||||
print("begining training")
|
||||
pretrain(args, tokenizer, model, optimizer, lr_scheduler)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,234 @@
|
|||
#!/bin/bash
|
||||
|
||||
#export OMP_NUM_THREADS=16
|
||||
|
||||
declare -A args # Declare an associative array to store arguments and values
|
||||
|
||||
args["model_unique"]="8b_0702"
|
||||
args["resume_ckpt"]=""
|
||||
args["config"]="8b"
|
||||
args["flash"]="cuda"
|
||||
args["batch_size"]="1"
|
||||
args["max_length"]="4096"
|
||||
args["save_iters"]="500"
|
||||
args["train_iters"]="10"
|
||||
args["dataset_config"]="fm9g_sft"
|
||||
args["local"]="False"
|
||||
args["dataloader"]="indexed"
|
||||
args["save"]="True"
|
||||
args["dataloader_num_threads"]=1
|
||||
args["dataloader_prefetch"]=2
|
||||
args["dataloader_prefetch_factor"]=32
|
||||
args["dataloader_num_workers"]=2
|
||||
args["lr"]="1e-5"
|
||||
args["warmup_iters"]="20"
|
||||
args["drop_iters"]="0.1"
|
||||
args["tokenizer_path"]="./tokenizer/tokenizer.model" # /user/tc_agi/klara/baichuan2/baichuan2.tokenizer.model
|
||||
args["load_grad"]="False"
|
||||
args["grad_ckpt_num"]="160"
|
||||
args["exp_group"]=""
|
||||
args["ignore_cuda_oom"]="1"
|
||||
args["tensorboard_all_tasks"]="0"
|
||||
args["stop_when_end"]="0"
|
||||
args["only_run_dataloader"]="0"
|
||||
args["eps"]="1e-6"
|
||||
args["inspect_iters"]="100"
|
||||
args["strict_state_dict"]="1"
|
||||
args["only_load_model"]="1"
|
||||
args["lr_scheduler"]="cosine"
|
||||
args["resume_no_optimze"]="0"
|
||||
args["tp_size"]="1"
|
||||
args["parallel_load_datastate"]="16"
|
||||
args["async_save"]="False"
|
||||
args["load_dataloader_ckpt"]="0"
|
||||
args["drop_begin"]="-1"
|
||||
args["drop_rate"]="0.5"
|
||||
args["use_checkpoint"]="1"
|
||||
|
||||
|
||||
# Loop through the arguments
|
||||
for ((i=1; i<=$#; i++)); do
|
||||
arg="${!i}"
|
||||
# Check if the argument starts with "--"
|
||||
if [[ "$arg" == --* ]]; then
|
||||
arg_name="${arg:2}" # Remove leading "--"
|
||||
valueid=$((i+1))
|
||||
# Get the value of the argument if it exists
|
||||
if ((i+1 <= $#)); then
|
||||
args["$arg_name"]="${!valueid}"
|
||||
i=$((i+1)) # Skip the next argument (its value)
|
||||
else
|
||||
args["$arg_name"]="" # Set empty value if no value provided
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
# 使用 Python 读取 JSON 文件并更新 Bash 字典
|
||||
while read -r key value; do
|
||||
args["$key"]="$value"
|
||||
done < <(python -c 'import json, sys; obj = json.load(open("train_configs/'${args['config']}'.json"))["pretrain"]; print("\n".join(["{} {}".format(k, v) for k, v in obj.items()]))')
|
||||
|
||||
|
||||
|
||||
# 用cmd arg 再更新一次
|
||||
# Loop through the arguments
|
||||
for ((i=1; i<=$#; i++)); do
|
||||
arg="${!i}"
|
||||
# Check if the argument starts with "--"
|
||||
if [[ "$arg" == --* ]]; then
|
||||
arg_name="${arg:2}" # Remove leading "--"
|
||||
valueid=$((i+1))
|
||||
|
||||
# Get the value of the argument if it exists
|
||||
if ((i+1 <= $#)); then
|
||||
args["$arg_name"]="${!valueid}"
|
||||
i=$((i+1)) # Skip the next argument (its value)
|
||||
else
|
||||
args["$arg_name"]="" # Set empty value if no value provided
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
# Print the values of the arguments
|
||||
echo "----------- CMD args ----------"
|
||||
for key in "${!args[@]}"; do
|
||||
echo "$key: ${args[$key]}"
|
||||
done
|
||||
echo "--------- END CMD args --------"
|
||||
|
||||
|
||||
if [[ ${args["flash"]} == "triton" ]]; then
|
||||
sudo cp /usr/local/cuda-11.6/compat/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03
|
||||
sudo ln /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so
|
||||
echo "triton flash"
|
||||
fi
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
GPUS_PER_NODE=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader | wc -l)
|
||||
# GPUS_PER_NODE=1
|
||||
echo "Using ${GPUS_PER_NODE} GPU each machine"
|
||||
|
||||
|
||||
if [[ ${args["model_unique"]} == "" ]]; then
|
||||
MODEL_UNIQUE=${JEEVES_JOB_ID} # 写入的位置,没传的话自动构造
|
||||
# JOBID+CreateTime, 本次run的唯一标识符。在白箱里可以通过/projects/${PROJECTID}-${PROJECTNAME}/checkpoints/${MODEL_UNIQUE} 拿到 checkpoint
|
||||
# 通过/projects/${PROJECTID}-${PROJECTNAME}/tensorboard/${MODEL_UNIQUE} 拿到 tensorboard
|
||||
else
|
||||
MODEL_UNIQUE=${args["model_unique"]} # 给了写入的位置
|
||||
fi
|
||||
echo "model_unique: "$MODEL_UNIQUE
|
||||
|
||||
# --------------- 运行参数 ---------------
|
||||
|
||||
OPTS+=" --model-config model_configs/"${args['config']}".json" # [CHANGE]
|
||||
OPTS+=" --batch-size ${args["batch_size"]}"
|
||||
OPTS+=" --train-iters ${args["train_iters"]}"
|
||||
OPTS+=" --save-iters ${args["save_iters"]}"
|
||||
OPTS+=" --save-name fm9g_live_checkpoint"
|
||||
OPTS+=" --max-length ${args["max_length"]}"
|
||||
OPTS+=" --lr ${args["lr"]}"
|
||||
OPTS+=" --inspect-iters ${args["inspect_iters"]}"
|
||||
OPTS+=" --warmup-iters ${args["warmup_iters"]}"
|
||||
OPTS+=" --drop-iters ${args["drop_iters"]}"
|
||||
OPTS+=" --lr_scheduler ${args["lr_scheduler"]}"
|
||||
OPTS+=" --offload"
|
||||
OPTS+=" --vocab ./tokenizer/vocab.txt"
|
||||
OPTS+=" --flash ${args["flash"]}"
|
||||
OPTS+=" --tensorboard_all_tasks ${args["tensorboard_all_tasks"]}"
|
||||
OPTS+=" --ignore_cuda_oom ${args["ignore_cuda_oom"]}"
|
||||
OPTS+=" --stop_when_end ${args["stop_when_end"]}"
|
||||
OPTS+=" --only_run_dataloader ${args["only_run_dataloader"]}"
|
||||
OPTS+=" --eps ${args["eps"]}"
|
||||
OPTS+=" --strict_state_dict ${args["strict_state_dict"]}"
|
||||
OPTS+=" --only_load_model ${args["only_load_model"]}"
|
||||
OPTS+=" --resume_no_optimze ${args["resume_no_optimze"]}"
|
||||
OPTS+=" --tokenizer_path ${args["tokenizer_path"]}"
|
||||
OPTS+=" --weight-decay 0.1"
|
||||
OPTS+=" --tp-size ${args["tp_size"]}"
|
||||
OPTS+=" --parallel_load_datastate ${args["parallel_load_datastate"]}"
|
||||
OPTS+=" --load_dataloader_ckpt ${args["load_dataloader_ckpt"]}"
|
||||
OPTS+=" --drop_begin ${args["drop_begin"]}"
|
||||
OPTS+=" --drop_rate ${args["drop_rate"]}"
|
||||
OPTS+=" --use_checkpoint ${args["use_checkpoint"]}"
|
||||
|
||||
if [[ ${args["load_grad"]} == "True" ]]; then
|
||||
OPTS+=" --load-grad"
|
||||
OPTS+=" --grad-ckpt-num ${args["grad_ckpt_num"]}"
|
||||
fi
|
||||
|
||||
|
||||
if [[ ${args["async_save"]} == "True" ]]; then
|
||||
OPTS+=" --async_save"
|
||||
fi
|
||||
|
||||
|
||||
if [[ ${args["dataloader"]} == "indexed" ]]; then
|
||||
OPTS+=" --dataloader_num_threads ${args["dataloader_num_threads"]}"
|
||||
OPTS+=" --dataloader_prefetch ${args["dataloader_prefetch"]}"
|
||||
OPTS+=" --dataloader_num_workers ${args["dataloader_num_workers"]}"
|
||||
OPTS+=" --dataloader_prefetch_factor ${args["dataloader_prefetch_factor"]}"
|
||||
fi
|
||||
|
||||
|
||||
# --------------- 写文件路径 ---------------
|
||||
## checkpoint
|
||||
if [[ ${args["save"]} == "True" ]]; then
|
||||
|
||||
OPTS+=" --save ./data/checkpoints/${MODEL_UNIQUE}/"
|
||||
OPTS+=" --save-model ./not_exist/${MODEL_UNIQUE}/"
|
||||
else
|
||||
echo "won't save model"
|
||||
fi
|
||||
|
||||
|
||||
## logs,/local/logs 等价于 ./datalogs(软链)
|
||||
mkdir -p ./data/checkpoints/logs/${MODEL_UNIQUE}
|
||||
OPTS+=" --log-dir ./data/checkpoints/logs/${MODEL_UNIQUE}"
|
||||
OPTS+=" --tensorboard ./data/tensorboard/${args["exp_group"]}${MODEL_UNIQUE}/"
|
||||
|
||||
|
||||
|
||||
if [[ ${args["local"]} == "True" ]]; then
|
||||
current_dir=$(pwd)
|
||||
OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
|
||||
else
|
||||
current_dir=$(pwd)
|
||||
OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
|
||||
echo "Platform config:"${PLATFORM_CONFIG_PATH}
|
||||
fi
|
||||
|
||||
|
||||
## checkpoint,兼容 CHECKPOINT 和 LATEST_CHECKPOINT。debug 时建议不加载 checkpoint,启动会比较快
|
||||
if [ "${args["resume_ckpt"]}" != "" ]; then
|
||||
OPTS+=" --load ./data/checkpoints/${MODEL_UNIQUE}/${args["resume_ckpt"]}"
|
||||
else
|
||||
echo "No checkpoint to load"
|
||||
fi
|
||||
|
||||
|
||||
filename="pretrain_dragonfly"
|
||||
|
||||
if [[ ${args["local"]} == "True" ]]; then
|
||||
PRETRAIN_ENTRY="$filename.py"
|
||||
else
|
||||
PRETRAIN_ENTRY="$filename.py"
|
||||
fi
|
||||
|
||||
|
||||
GPUS_PER_NODE=8
|
||||
NNODES=1
|
||||
RANK=0
|
||||
MASTER_ENDPOINT=g3006
|
||||
MASTER_PORT=12345
|
||||
#CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --master_addr=${MASTER_ENDPOINT} --master_port=${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
|
||||
CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ENDPOINT}:${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
|
||||
|
||||
echo "-------final CMD is------"
|
||||
echo "${CMD}"
|
||||
echo "-------final CMD end------"
|
||||
|
||||
$CMD
|
|
@ -119687,8 +119687,8 @@
|
|||
"𠳐"
|
||||
"𥻗"
|
||||
"𬉼"
|
||||
"<pad_0>"
|
||||
"<pad_1>"
|
||||
"<|im_start|>"
|
||||
"<|im_end|>"
|
||||
"<pad_2>"
|
||||
"<pad_3>"
|
||||
"<pad_4>"
|
|
@ -0,0 +1,9 @@
|
|||
{
|
||||
"pretrain": {
|
||||
"train_iters": 20000,
|
||||
"batch_size": 1,
|
||||
"max_length": 4096,
|
||||
"n_gpus": 8,
|
||||
"lr": 1e-5
|
||||
}
|
||||
}
|
|
@ -1,3 +1,18 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The OpenBMB team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
|
||||
|
||||
|
@ -7,6 +22,8 @@ def add_model_config_args(parser: argparse.ArgumentParser):
|
|||
group = parser.add_argument_group("model", "model configuration")
|
||||
group.add_argument("--model-config", type=str, help="model configuration file")
|
||||
group.add_argument("--vocab", type=str, default=None, help="model vocabulary file")
|
||||
group.add_argument("--eps", type=float, default=1e-5, help="eps in layernorm")
|
||||
# group.add_argument("--qk_norm", action="store_true", default=False, help="qk layernorm")
|
||||
return parser
|
||||
|
||||
|
||||
|
@ -31,6 +48,13 @@ def add_training_args(parser: argparse.ArgumentParser):
|
|||
help="Load the gradient states",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--grad-ckpt-num",
|
||||
type=int,
|
||||
default=0,
|
||||
help="grad file num (only work when --load-grad from files less than world-size )",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--load-start-step",
|
||||
action="store_true",
|
||||
|
@ -66,7 +90,9 @@ def add_training_args(parser: argparse.ArgumentParser):
|
|||
|
||||
group.add_argument("--inspect-iters", type=int, default=1000, help="number of inspecting")
|
||||
group.add_argument("--batch-size", type=int, default=32, help="Data Loader batch size")
|
||||
group.add_argument("--num-micro-batches", type=int, default=16)
|
||||
group.add_argument("--clip-grad", type=float, default=1.0, help="gradient clipping")
|
||||
group.add_argument("--grad-accum", type=int, default=1, help="gradient accum steps")
|
||||
group.add_argument(
|
||||
"--train-iters",
|
||||
type=int,
|
||||
|
@ -74,11 +100,14 @@ def add_training_args(parser: argparse.ArgumentParser):
|
|||
help="total number of iterations to train over all training runs",
|
||||
)
|
||||
group.add_argument("--max-length", type=int, default=512, help="max length of input")
|
||||
group.add_argument("--min-length", type=int, default=None, help="only for speed test")
|
||||
|
||||
group.add_argument("--seed", type=int, default=1234, help="random seed for reproducibility")
|
||||
|
||||
# Learning rate.
|
||||
group.add_argument("--lr", type=float, default=1.0e-4, help="initial learning rate")
|
||||
group.add_argument("--lr_scheduler", type=str, default="cosine", help=" learning rate scheduler")
|
||||
|
||||
group.add_argument("--weight-decay", type=float, default=1.0e-2, help="weight decay rate")
|
||||
group.add_argument("--loss-scale", type=float, default=65536, help="loss scale")
|
||||
group.add_argument("--max-loss-scale", type=float, default=float("inf"), help="loss scale")
|
||||
|
@ -92,21 +121,85 @@ def add_training_args(parser: argparse.ArgumentParser):
|
|||
help="percentage of data to warmup on (.01 = 1% of all " "training iters). Default 0.01",
|
||||
)
|
||||
group.add_argument(
|
||||
"--lr-decay-style",
|
||||
type=str,
|
||||
default="noam",
|
||||
choices=["constant", "linear", "cosine", "exponential", "noam"],
|
||||
help="learning rate decay function",
|
||||
"--drop-iters",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="percentage of data to warmup on (.01 = 1% of all " "training iters). Default 0.01",
|
||||
)
|
||||
|
||||
group.add_argument("--lr-decay-iters", type=int, default=None, help="lr decay steps")
|
||||
group.add_argument("--start-step", type=int, default=0, help="step to start or continue training")
|
||||
group.add_argument("--concat-data", action="store_true", help="whether we concatenate the dialogues")
|
||||
group.add_argument("--offload", action="store_true", help="whether we use offload_adam")
|
||||
group.add_argument("--new-bmt", action="store_true", help="new bmt without ckpt")
|
||||
group.add_argument("--flash", default="none", choices=["none", "1d", "triton", "cuda"])
|
||||
group.add_argument("--tp", default=1, type=int, help="whether we use tensor parallelism")
|
||||
group.add_argument("--use-jfs-data", action="store_true", help="whether we use juicefs dataset")
|
||||
group.add_argument("--tp-size", default=1, type=int)
|
||||
group.add_argument("--pp-size", default=1, type=int)
|
||||
group.add_argument("--bf16", action="store_true", help="whether we use bf16")
|
||||
group.add_argument("--gradient-accumulation-steps", type=int, default=1, help="gradient accumulation steps")
|
||||
group.add_argument("--dataloader_num_threads", default=3, type=int, help="Only useful in indexed dataest.")
|
||||
group.add_argument("--dataloader_prefetch", default=200, type=int, help="Only useful in indexed dataest.")
|
||||
group.add_argument("--dataloader_num_workers", default=4, type=int, help="Only useful in indexed dataest.")
|
||||
group.add_argument("--dataloader_prefetch_factor", default=50, type=int, help="Only useful in indexed dataest.")
|
||||
group.add_argument(
|
||||
"--dataloader",
|
||||
default="indexed",
|
||||
type=str,
|
||||
help="dataloader type, 'indexed' for indexed dataset, 'normal' for normal dataset",
|
||||
)
|
||||
group.add_argument("--stop_when_end", default=0, type=int, help="Whether to stop training when we reach end_iter")
|
||||
group.add_argument(
|
||||
"--data_len_threshold",
|
||||
default=512,
|
||||
type=int,
|
||||
help="If the average length of a sequence is less than this int, mean the sample is biased. ",
|
||||
)
|
||||
group.add_argument(
|
||||
"--only_run_dataloader", default=0, type=int, help="Whether to only run dataloader to check data. "
|
||||
)
|
||||
group.add_argument(
|
||||
"--only_load_model", default=0, type=int, help="Whether to only load a model ckpt, without anything else."
|
||||
)
|
||||
group.add_argument(
|
||||
"--load_dataloader_ckpt", default=1, type=int, help="Whether to only load a model ckpt, without anything else."
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--resume_no_optimze",
|
||||
default=0,
|
||||
type=int,
|
||||
help="The number of steps that does not add optimization after resume",
|
||||
)
|
||||
group.add_argument(
|
||||
"--parallel_load_datastate",
|
||||
default=256,
|
||||
type=int,
|
||||
help="The number of parallel workers to load dataset state",
|
||||
)
|
||||
group.add_argument(
|
||||
"--async_save",
|
||||
action="store_true",
|
||||
help="whether to save artifacts asynchronously",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop_begin",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="The number of steps that starts to drop lr"
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop_rate",
|
||||
default=0.5,
|
||||
type=float,
|
||||
help="The number rate"
|
||||
)
|
||||
group.add_argument(
|
||||
"--use_checkpoint",
|
||||
default=1,
|
||||
type=int,
|
||||
help="Whether to use checkpointing."
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
@ -133,6 +226,17 @@ def add_pretrain_args(parser: argparse.ArgumentParser):
|
|||
return parser
|
||||
|
||||
|
||||
def add_tokenizer_args(parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group("tokenizer", "tokenizer configurations")
|
||||
group.add_argument(
|
||||
"--tokenizer_path",
|
||||
type=str,
|
||||
default="",
|
||||
help="tokenizer_path",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def add_finetune_args(parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group("finetune", "finetune configurations")
|
||||
group.add_argument("--epoch", type=int, default=1, help="number of training epochs")
|
||||
|
@ -204,14 +308,53 @@ def add_feedback_learning_args(parser: argparse.ArgumentParser):
|
|||
return parser
|
||||
|
||||
|
||||
def add_delta_args(parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group("LoRA","LoRA configurations")
|
||||
group.add_argument("--delta-type", type=str, default=None, help="delta-tuning-type")
|
||||
group.add_argument("--lora-r", type=int, default=8, help="lora-rank")
|
||||
group.add_argument("--lora-alpha", type=int, default=8, help="lora-alpha")
|
||||
group.add_argument("--lora-dropout", type=float, default=0.0, help="lora-dropout")
|
||||
group.add_argument("--lora-layer", nargs='+', default=['project_q','project_k'], help="lora-layer")
|
||||
group.add_argument("--save-origin-model", action="store_true", default=False)
|
||||
def add_model_change_args(parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group("model_change", "model change during pretraining")
|
||||
group.add_argument("--strict_state_dict", type=int, default=1, help="strict_state_dict")
|
||||
##
|
||||
return parser
|
||||
|
||||
|
||||
def add_log_args(parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group("log", "log configurations")
|
||||
group.add_argument("--tensorboard_all_tasks", type=int, default=0, help="log")
|
||||
return parser
|
||||
|
||||
|
||||
def add_error_handle_args(parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group("error_handle", "error_handle configurations")
|
||||
group.add_argument(
|
||||
"--ignore_cuda_oom", type=int, default=1, help="continue training by ingore the batch that causes oom"
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def add_runtime_eval_args(parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group("runtime eval args", "runtime evaluation by submitting a job")
|
||||
group.add_argument(
|
||||
"--runtime_eval",
|
||||
action="store_true",
|
||||
help="whether to use runtime_eval. Only if this is set to True, the following variables will be useful",
|
||||
)
|
||||
group.add_argument("--eval_jeeves_auth", type=str, default="", help="auth, press f12 on jeeves platform to get")
|
||||
group.add_argument("--eval_project_id", type=str, default=None, help="project id")
|
||||
group.add_argument("--eval_run_cmd", type=str, default="", help="cmd for eval")
|
||||
group.add_argument(
|
||||
"--eval_git_path",
|
||||
type=str,
|
||||
default="git@git.in.zhihu.com:luca/llm-bench.git",
|
||||
help="git path of evaluation code",
|
||||
)
|
||||
group.add_argument("--eval_git_branch", type=str, default="master", help="git branch of evaluation code")
|
||||
group.add_argument("--eval_node_num", type=int, default=1, help="using 1 node to evaluate")
|
||||
group.add_argument("--eval_gpu_num", type=int, default=1, help="using 1 gpu per node to evaluate")
|
||||
group.add_argument("--eval_tasks_config", type=str, default="", help="evaluate tasks' config")
|
||||
group.add_argument("--eval_model_backend", default="torch", type=str, help="model_backend")
|
||||
|
||||
group.add_argument(
|
||||
"--eval_at_start", action="store_true", help="whether to eval at the first epoch, default to false"
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
@ -222,6 +365,30 @@ def add_reward_args(parser: argparse.ArgumentParser):
|
|||
return parser
|
||||
|
||||
|
||||
def add_long_context_extend_args(parser: argparse.ArgumentParser):
|
||||
"""long context extending arguments."""
|
||||
group = parser.add_argument_group("long_context_extend", "long context extend configurations")
|
||||
group.add_argument("--pose-prob", default=0.0, type=float, help="Sample-level PoSE probability")
|
||||
group.add_argument(
|
||||
"--pose-scaling-factor",
|
||||
default=1.0,
|
||||
type=float,
|
||||
help="PoSE scaling factor, simulate input length = max_length * pose_scaling_factor",
|
||||
)
|
||||
group.add_argument(
|
||||
"--rope-scaling-type",
|
||||
default="",
|
||||
type=str,
|
||||
choices=["Linear", "NTK-aware", "Dynamic NTK", "NTK-by-parts", "YaRN", ""],
|
||||
help="Context scaling type",
|
||||
)
|
||||
group.add_argument("--rope-scaling-factor", default=1, type=int, help="Context scaling factor")
|
||||
group.add_argument(
|
||||
"--orig-max-length", default=8192, type=int, help="Original context length before context extending"
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def get_args(
|
||||
pretrain: bool = False,
|
||||
finetune: bool = False,
|
||||
|
@ -235,9 +402,14 @@ def get_args(
|
|||
parser = add_training_args(parser)
|
||||
if pretrain:
|
||||
parser = add_pretrain_args(parser)
|
||||
parser = add_runtime_eval_args(parser)
|
||||
parser = add_tokenizer_args(parser)
|
||||
parser = add_log_args(parser)
|
||||
parser = add_error_handle_args(parser)
|
||||
parser = add_model_change_args(parser)
|
||||
|
||||
if finetune:
|
||||
parser = add_finetune_args(parser)
|
||||
parser = add_delta_args(parser)
|
||||
if rhlf:
|
||||
parser = add_rhlf_args(parser)
|
||||
if simple_rlhf:
|
||||
|
@ -246,6 +418,7 @@ def get_args(
|
|||
parser = add_feedback_learning_args(parser)
|
||||
if reward:
|
||||
parser = add_reward_args(parser)
|
||||
parser = add_long_context_extend_args(parser)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
|
@ -4,7 +4,8 @@ from .distributed_dataset import SimpleDataset
|
|||
from .indexed_dataset import IndexedDataset
|
||||
from .indexed_dataset import IndexedDatasetBuilder
|
||||
from .indexed_dataset import PrefetchDecodeDataset
|
||||
from .list_dataset import ListDataset
|
||||
|
||||
# from .list_dataset import ListDataset
|
||||
from .utils import compact_dataset
|
||||
from .utils import CudaPrefetcher
|
||||
from .utils import mask_dataset
|
|
@ -1,3 +1,18 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The OpenBMB team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import bisect
|
||||
import io
|
||||
import json
|
||||
|
@ -281,7 +296,6 @@ class DistributedDataset:
|
|||
info: List[FileInfo] = []
|
||||
if os.path.exists(meta_path):
|
||||
info = _read_info_list(meta_path)
|
||||
|
||||
old_len = len(self._file_info)
|
||||
if old_len > len(info):
|
||||
raise RuntimeError("Dataset meta file: changed unexpectly")
|
||||
|
@ -443,7 +457,11 @@ class DistributedDataset:
|
|||
with torch.no_grad():
|
||||
if self._world_size > 1:
|
||||
gpu_num_unused_block = torch.tensor([num_unused_block], dtype=torch.long).cuda()
|
||||
max_unused_blocks = bmt.distributed.all_reduce(gpu_num_unused_block, op="max").cpu().item()
|
||||
max_unused_blocks = (
|
||||
bmt.distributed.all_reduce(gpu_num_unused_block, op="max", comm=bmt.config["tp_zero_comm"])
|
||||
.cpu()
|
||||
.item()
|
||||
)
|
||||
gpu_states = torch.full((max_unused_blocks,), -1, dtype=torch.long).cuda()
|
||||
gpu_states[:num_unused_block] = torch.tensor(self._unused_block, dtype=torch.long).cuda()
|
||||
gpu_offset = torch.full((max_unused_blocks,), 0, dtype=torch.long).cuda()
|
||||
|
@ -452,9 +470,15 @@ class DistributedDataset:
|
|||
[curr_block, inblock_offset, num_unused_block, self._repeat_times],
|
||||
dtype=torch.long,
|
||||
).cuda()
|
||||
global_states = bmt.distributed.all_gather(gpu_states).cpu() # (world_size, max_unused_blocks)
|
||||
global_offset = bmt.distributed.all_gather(gpu_offset).cpu() # (world_size, max_unused_blocks)
|
||||
global_block = bmt.distributed.all_gather(gpu_block).cpu() # (world_size, 4)
|
||||
global_states = bmt.distributed.all_gather(
|
||||
gpu_states, comm=bmt.config["tp_zero_comm"]
|
||||
).cpu() # (world_size, max_unused_blocks)
|
||||
global_offset = bmt.distributed.all_gather(
|
||||
gpu_offset, comm=bmt.config["tp_zero_comm"]
|
||||
).cpu() # (world_size, max_unused_blocks)
|
||||
global_block = bmt.distributed.all_gather(
|
||||
gpu_block, comm=bmt.config["tp_zero_comm"]
|
||||
).cpu() # (world_size, 4)
|
||||
return {"states": global_states, "offset": global_offset, "block": global_block}
|
||||
else:
|
||||
return {
|
|
@ -1,13 +1,41 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
||||
#
|
||||
# @author: ouzebin <ouzebin@zhihu.com>
|
||||
# @date: 2023/09/27
|
||||
|
||||
"""
|
||||
使用 IndexedDataset 前需按指定格式构建或者转换已有数据集
|
||||
数据集文件结构:
|
||||
- <dataset name>
|
||||
- data.jsonl # jsonl 格式的数据,每一行一条样本
|
||||
- index # 记录每一行 json 数据的起始 byte-offset
|
||||
|
||||
从头构建:直接使用 IndexedDatasetBuilder 这个 context manager
|
||||
>>> with IndexedDatasetBuilder("swear", overwrite=True) as builder:
|
||||
>>> for data in [{"input": f"screw it {i}", "output": f"for god's sake {i}"} for i in range(100)]:
|
||||
>>> builder.put(data)
|
||||
转换:
|
||||
从 fm9g distributed_dataset 转换:使用 `fm9g.dataset.tools.distributed_to_indexed`
|
||||
$ python -m fm9g.dataset.tools.distributed_to_indexed -i <原数据集文件夹> -o <新数据集文件夹>
|
||||
已有 jsonl 数据:使用 `fm9g.dataset.tools.jsonl_to_index` 构建 index 文件。需提前先把 jsonl 文件命名为
|
||||
$ python -m fm9g.dataset.tools.jsonl_to_index -p <数据集文件夹路径>
|
||||
"""
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
import queue
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
|
||||
import bmtrain as bmt
|
||||
import h5py
|
||||
import numpy
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
import msgspec
|
||||
|
@ -22,13 +50,83 @@ except ModuleNotFoundError:
|
|||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from .utils import Range
|
||||
from fm9g.utils.bitset import BitSet
|
||||
from fm9g.utils.bitset import bitset_diff
|
||||
|
||||
print_lock = threading.Lock()
|
||||
|
||||
|
||||
def random_range(start, stop=None, step=None):
|
||||
"""
|
||||
Generator of non-repeated random permutation with the same inteface of python
|
||||
`range`. Obtained from https://stackoverflow.com/a/53551417
|
||||
The random.shuffle(list) and random.sample(list, len(list)) require
|
||||
materialize the lists, which result in a long initalization period.
|
||||
"""
|
||||
if stop is None:
|
||||
start, stop = 0, start
|
||||
if step is None:
|
||||
step = 1
|
||||
# Use a mapping to convert a standard range into the desired range.
|
||||
mapping = lambda i: (i * step) + start
|
||||
# Compute the number of numbers in this range.
|
||||
maximum = int(math.ceil((stop - start) / step))
|
||||
if maximum == 0:
|
||||
# early return with empty range
|
||||
yield from ()
|
||||
return
|
||||
# Seed range with a random integer.
|
||||
value = random.randint(0, maximum)
|
||||
# Construct an offset, multiplier, and modulus for a linear
|
||||
# congruential generator. These generators are cyclic and
|
||||
# non-repeating when they maintain the properties:
|
||||
#
|
||||
# 1) "modulus" and "offset" are relatively prime.
|
||||
# 2) ["multiplier" - 1] is divisible by all prime factors of "modulus".
|
||||
# 3) ["multiplier" - 1] is divisible by 4 if "modulus" is divisible by 4.
|
||||
|
||||
# Pick a random odd-valued offset.
|
||||
offset = random.randint(0, maximum) * 2 + 1
|
||||
# Pick a multiplier 1 greater than a multiple of 4.
|
||||
multiplier = 4 * (maximum // 4) + 1
|
||||
# Pick a modulus just big enough to generate all numbers (power of 2).
|
||||
modulus = int(2 ** math.ceil(math.log2(maximum)))
|
||||
# Track how many random numbers have been returned.
|
||||
found = 0
|
||||
while found < maximum:
|
||||
# If this is a valid value, yield it in generator fashion.
|
||||
if value < maximum:
|
||||
found += 1
|
||||
yield mapping(value)
|
||||
# Calculate the next value in the sequence.
|
||||
value = (value * multiplier + offset) % modulus
|
||||
|
||||
|
||||
class Range(object):
|
||||
def __init__(self, start, stop, step):
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.step = step
|
||||
|
||||
def __repr__(self):
|
||||
return f"Range({self.start}, {self.stop}, {self.step})"
|
||||
|
||||
def iterate(self):
|
||||
yield from range(self.start, self.stop, self.step)
|
||||
|
||||
def list(self):
|
||||
return list(range(self.start, self.stop, self.step))
|
||||
|
||||
def subrange(self, split, nsplits):
|
||||
# strided spliting range params
|
||||
# e.g., [0, 3, 5, 7, 9] can be split into [0, 5, 9] and [3, 7]
|
||||
return Range(self.start + self.step * split, self.stop, self.step * nsplits)
|
||||
|
||||
def random_iterate(self):
|
||||
yield from random_range(self.start, self.stop, self.step)
|
||||
|
||||
|
||||
def safe_print(*args, **kargs):
|
||||
if "flush" in kargs:
|
||||
flush = kargs["flush"]
|
||||
|
@ -40,12 +138,15 @@ def safe_print(*args, **kargs):
|
|||
|
||||
|
||||
def concurrent_info():
|
||||
world_size, rank = bmt.world_size(), bmt.rank()
|
||||
# world_size, rank = bmt.world_size(), bmt.rank()
|
||||
world_size = bmt.config["world_size"] // bmt.config["tp_size"]
|
||||
rank = bmt.config["topology"].tp_idx
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
if worker_info is None:
|
||||
nworkers, worker_id = 1, 1
|
||||
else:
|
||||
nworkers, worker_id = worker_info.num_workers, worker_info.id
|
||||
# print("concurrent_info: (world_size, rank, nworkers, worker_id): {}".format((world_size, rank, nworkers, worker_id)))
|
||||
return world_size, rank, nworkers, worker_id
|
||||
|
||||
|
||||
|
@ -56,14 +157,45 @@ class IndexedDataset(Dataset):
|
|||
self.max_retry = max_retry
|
||||
self.retry_sleep = retry_sleep
|
||||
self.bounds = None
|
||||
self.h5file = None
|
||||
self.build_index()
|
||||
|
||||
def size(self):
|
||||
return self.bounds[-1]
|
||||
|
||||
def _build_index_h5(self):
|
||||
index_path = os.path.join(self.path, "index.h5")
|
||||
if os.path.getsize(index_path) > 104857600:
|
||||
self.h5file = h5py.File(os.path.join(self.path, "index.h5"), "r")
|
||||
self.bounds = self.h5file["index"]
|
||||
else:
|
||||
# only load index into memory when it is small (< 100 Mb)
|
||||
# to avoid keeping to many file handlers
|
||||
self.h5file = None
|
||||
with h5py.File(index_path, "r") as hf:
|
||||
self.bounds = np.array(hf["index"])
|
||||
|
||||
def __del__(self):
|
||||
if self.h5file is not None:
|
||||
self.h5file.close()
|
||||
|
||||
def build_index(self):
|
||||
s = time.time()
|
||||
|
||||
txt_size = os.path.getsize(os.path.join(self.path, "index"))
|
||||
if txt_size > 0.5 * 1024**3 and os.path.exists(os.path.join(self.path, "index.h5")):
|
||||
source = "h5"
|
||||
self._build_index_h5()
|
||||
else:
|
||||
source = "txt"
|
||||
self._build_index_txt()
|
||||
e = time.time()
|
||||
bmt.print_rank("build_index_{} from {}, using {:.2f}s".format(source, self.path, e - s))
|
||||
|
||||
def _build_index_txt(self):
|
||||
with open(os.path.join(self.path, "index"), "r") as fin:
|
||||
self.bounds = [int(line) for line in fin]
|
||||
self.nlines = len(self.bounds)
|
||||
|
||||
def safe_read(self, i_or_s, offset, size):
|
||||
for retry in itertools.count():
|
||||
|
@ -138,39 +270,10 @@ class IndexedDataset(Dataset):
|
|||
class PrefetchDecodeDataset(IndexedDataset):
|
||||
# Add prefetched sampled iterator and state_dict tracking upon the simple IndexedDataset
|
||||
# Add safe decoding in iterator
|
||||
def __init__(self, *args, decode=json_decode, **kargs):
|
||||
def __init__(self, *args, decode=json_decode, allow_repeat=False, **kargs):
|
||||
super().__init__(*args, **kargs)
|
||||
self.decode = decode
|
||||
self.lock = threading.Lock()
|
||||
self.prev_used = set() # store previously used index in the checkpoint
|
||||
self.used = set() # track locally used index
|
||||
|
||||
def state_dict(self, gathered=True):
|
||||
if not self.prev_used and not self.used:
|
||||
return {"prev_used": set()}
|
||||
if gathered:
|
||||
used = torch.tensor(list(self.used)).cuda()
|
||||
size = torch.tensor(used.numel()).cuda()
|
||||
max_size = bmt.distributed.all_reduce(size, op="max")
|
||||
# allgather requires tensors having the same size
|
||||
used = torch.cat([used, torch.full((max_size - size,), -100, device=used.device)], dim=-1)
|
||||
all_used = bmt.distributed.all_gather(used).unique()
|
||||
all_used = set(all_used.tolist())
|
||||
if -100 in all_used:
|
||||
all_used.remove(-100) # remove the padding value
|
||||
all_used.union(self.prev_used)
|
||||
return {"prev_used": all_used}
|
||||
else:
|
||||
return {"prev_used": self.prev_used.union(self.used)}
|
||||
|
||||
def load_state_dict(self, state):
|
||||
with self.lock:
|
||||
self.used = state.get("prev_used", set())
|
||||
|
||||
def reset(self):
|
||||
with self.lock:
|
||||
self.used = set()
|
||||
self.prev_used = set()
|
||||
self.allow_repeat = allow_repeat
|
||||
|
||||
def safe_decode(self, i, raw):
|
||||
if raw is None:
|
||||
|
@ -191,19 +294,23 @@ class PrefetchDecodeDataset(IndexedDataset):
|
|||
else:
|
||||
return self.safe_decode(key, raw)
|
||||
|
||||
def loader(self, q, lid, keys, stop):
|
||||
def loader(self, q, lid, keys, stop, used=None):
|
||||
# concurrent prefetching worker
|
||||
if used is None:
|
||||
used = BitSet()
|
||||
try:
|
||||
for key in keys:
|
||||
if stop.is_set():
|
||||
break
|
||||
# key is either a slice or an integer index
|
||||
index = range(key.start, key.stop) if isinstance(key, slice) else [key]
|
||||
with self.lock:
|
||||
unused = set(index) - self.used - self.prev_used
|
||||
unused = bitset_diff(set(index), used)
|
||||
if not unused:
|
||||
# skip used slice / item
|
||||
continue
|
||||
if not q.empty():
|
||||
# avoid breaking the distributed file system with large io load
|
||||
time.sleep(random.random() * 2)
|
||||
# read raw data with IndexedDataset.__getitem__, suspend decoding util we really need it
|
||||
raw = super().__getitem__(key)
|
||||
if raw is None:
|
||||
|
@ -217,14 +324,14 @@ class PrefetchDecodeDataset(IndexedDataset):
|
|||
# signaling the end of iteration to the main thread
|
||||
q.put(StopIteration(lid))
|
||||
|
||||
def _iterate(self, key_groups, nprefetch=1000):
|
||||
def _iterate(self, key_groups, nprefetch=1000, used=None):
|
||||
# helper function for concurrent prefetching
|
||||
q = queue.Queue(maxsize=nprefetch)
|
||||
stop = threading.Event()
|
||||
alive = set()
|
||||
try:
|
||||
for lid, keys in enumerate(key_groups):
|
||||
loader = threading.Thread(target=self.loader, args=(q, lid, keys, stop), daemon=True)
|
||||
loader = threading.Thread(target=self.loader, args=(q, lid, keys, stop, used), daemon=True)
|
||||
loader.start()
|
||||
alive.add(lid)
|
||||
while True:
|
||||
|
@ -236,7 +343,7 @@ class PrefetchDecodeDataset(IndexedDataset):
|
|||
break
|
||||
else:
|
||||
# new item will be put later, wait for a while
|
||||
time.sleep(0.3)
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
if isinstance(item, StopIteration):
|
||||
alive.remove(item.value)
|
||||
|
@ -245,16 +352,13 @@ class PrefetchDecodeDataset(IndexedDataset):
|
|||
data = self.safe_decode(i, raw)
|
||||
if data is None:
|
||||
continue
|
||||
self.used.add(i)
|
||||
yield data
|
||||
# automatically reset states with graceful ends.
|
||||
self.reset()
|
||||
yield i, data
|
||||
finally:
|
||||
# ask daemon loaders to stop
|
||||
stop.set()
|
||||
|
||||
def iterate(self, nthreads=3, prefetch_sample=100):
|
||||
world_size, rank, nworkers, worker_id = concurrent_info()
|
||||
def iterate(self, nthreads=3, prefetch_sample=100, used=None, process_group=None):
|
||||
world_size, rank, nworkers, worker_id = concurrent_info(process_group)
|
||||
nloaders = world_size * nworkers * nthreads
|
||||
if len(self) < nloaders:
|
||||
raise ValueError(
|
||||
|
@ -269,51 +373,61 @@ class PrefetchDecodeDataset(IndexedDataset):
|
|||
r = r.subrange(split=worker_id, nsplits=nworkers)
|
||||
# split index among multi-threaded loaders
|
||||
id_groups = [r.subrange(split=tid, nsplits=nthreads).random_iterate() for tid in range(nthreads)]
|
||||
for data in self._iterate(id_groups, nprefetch=prefetch_sample):
|
||||
yield data
|
||||
return self._iterate(id_groups, nprefetch=prefetch_sample, used=used)
|
||||
|
||||
def sliced_iterate(self, nthreads=1, prefetch_slice=3, slice_size=1000):
|
||||
def sliced_iterate(self, nthreads=1, prefetch_slice=3, slice_size=500, used=None):
|
||||
world_size, rank, nworkers, worker_id = concurrent_info()
|
||||
nloaders = world_size * nworkers * nthreads
|
||||
if len(self) < nloaders:
|
||||
raise ValueError(
|
||||
f"more concurrent loaders ({nloaders}) than data entries ({len(self)}) in '{self.path}', "
|
||||
f"please constrain either "
|
||||
f"world_size={world_size}, num_workers={nworkers} or num_threads={nthreads}."
|
||||
)
|
||||
nslices = int(math.ceil(len(self) / slice_size))
|
||||
if not self.allow_repeat:
|
||||
raise ValueError(
|
||||
f"more concurrent loaders ({nloaders}) than data entries ({len(self)}) in '{self.path}', "
|
||||
f"please constrain either "
|
||||
f"world_size={world_size}, num_workers={nworkers} or num_threads={nthreads}."
|
||||
)
|
||||
else:
|
||||
duplicated_factor = math.ceil(nloaders / len(self))
|
||||
# In this case, slice size is 1
|
||||
r = Range(0, len(self), 1)
|
||||
# split index among grouped multi-gpu workers
|
||||
r = r.subrange(split=rank // duplicated_factor, nsplits=math.ceil(world_size / duplicated_factor))
|
||||
# # split index among multi-threaded loaders
|
||||
r = r.subrange(split=worker_id, nsplits=nworkers)
|
||||
else:
|
||||
nslices = int(math.ceil(len(self) / slice_size))
|
||||
|
||||
if nslices < nloaders:
|
||||
safe_print(
|
||||
f"fail to distribute {nslices} slices from '{self.path}' to {nloaders} concurrent loaders, "
|
||||
f"reduce slice_size from {slice_size} to {len(self) // nloaders}."
|
||||
)
|
||||
slice_size = len(self) // nloaders
|
||||
if nslices < nloaders:
|
||||
safe_print(
|
||||
f"fail to distribute {nslices} slices from '{self.path}' to {nloaders} concurrent loaders, "
|
||||
f"reduce slice_size from {slice_size} to {len(self) // nloaders}."
|
||||
)
|
||||
slice_size = len(self) // nloaders
|
||||
|
||||
# we only iteratre through start ids as they uniquely mark each slice
|
||||
r = Range(0, len(self), slice_size)
|
||||
# split index among multi-gpu workers
|
||||
r = r.subrange(split=rank, nsplits=world_size)
|
||||
# split index among multi-process dataloader workers
|
||||
r = r.subrange(split=worker_id, nsplits=nworkers)
|
||||
# split index among multi-threaded loaders
|
||||
# we only iteratre through start ids as they uniquely mark each slice
|
||||
r = Range(0, len(self), slice_size)
|
||||
# split index among multi-gpu workers
|
||||
r = r.subrange(split=rank, nsplits=world_size)
|
||||
# split index among multi-process dataloader workers
|
||||
r = r.subrange(split=worker_id, nsplits=nworkers)
|
||||
# split index among multi-threaded loaders
|
||||
slice_groups = [
|
||||
(slice(s, s + slice_size) for s in r.subrange(tid, nthreads).random_iterate()) for tid in range(nthreads)
|
||||
]
|
||||
for data in self._iterate(slice_groups, nprefetch=prefetch_slice * slice_size):
|
||||
yield data
|
||||
return self._iterate(slice_groups, nprefetch=prefetch_slice * slice_size, used=used)
|
||||
|
||||
|
||||
class IndexedDatasetBuilder:
|
||||
def __init__(self, path, overwrite=False):
|
||||
self.path = path
|
||||
self.index_path = os.path.join(self.path, "index")
|
||||
self.index_path = os.path.join(self.path, "index.h5")
|
||||
self.index_path_txt = os.path.join(self.path, "index")
|
||||
self.data_path = os.path.join(self.path, "data.jsonl")
|
||||
if not overwrite:
|
||||
assert not os.path.exists(self.data_path)
|
||||
assert not os.path.exists(self.index_path)
|
||||
assert not os.path.exists(self.index_path_txt)
|
||||
self.fout = None
|
||||
self.starts = []
|
||||
self.bounds = []
|
||||
self.offset = 0
|
||||
|
||||
def __enter__(self):
|
||||
|
@ -322,15 +436,17 @@ class IndexedDatasetBuilder:
|
|||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.starts.append(self.offset)
|
||||
with open(self.index_path, "w") as fout:
|
||||
for s in self.starts:
|
||||
fout.write(f"{s}\n")
|
||||
self.bounds.append(self.offset)
|
||||
with h5py.File(os.path.join(self.index_path), "w") as hf:
|
||||
hf.create_dataset("index", data=self.bounds)
|
||||
with open(self.index_path_txt, "w") as fout_txt:
|
||||
for s in self.bounds:
|
||||
fout_txt.write(f"{s}\n")
|
||||
self.fout.close()
|
||||
|
||||
def put(self, data: dict):
|
||||
s = json_encode(data) + b"\n"
|
||||
self.starts.append(self.offset)
|
||||
self.bounds.append(self.offset)
|
||||
self.offset += len(s)
|
||||
self.fout.write(s)
|
||||
|
|
@ -1,3 +1,18 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The OpenBMB team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import pickle
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
||||
#
|
||||
# @author: ouzebin <ouzebin@zhihu.com>
|
||||
# @date: 2023/08/07
|
|
@ -1,14 +1,23 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
||||
#
|
||||
# @author: ouzebin <ouzebin@zhihu.com>
|
||||
# @date: 2023/07/27
|
||||
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from cpm.dataset import SimpleDataset
|
||||
from cpm.dataset.indexed_dataset import IndexedDatasetBuilder
|
||||
from fm9g.dataset import SimpleDataset
|
||||
from fm9g.dataset.indexed_dataset import IndexedDatasetBuilder
|
||||
|
||||
|
||||
def convert_cpm_data(cpm_path, out_path):
|
||||
dataset = SimpleDataset(cpm_path, shuffle=False)
|
||||
def convert_fm9g_data(fm9g_path, out_path):
|
||||
dataset = SimpleDataset(fm9g_path, shuffle=False)
|
||||
with IndexedDatasetBuilder(out_path, overwrite=True) as builder:
|
||||
for _ in tqdm(range(dataset._nlines), total=dataset._nlines):
|
||||
builder.put(dataset.read())
|
||||
|
@ -16,7 +25,7 @@ def convert_cpm_data(cpm_path, out_path):
|
|||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input", "-i", required=True, help="Data path in CPM format.")
|
||||
parser.add_argument("--input", "-i", required=True, help="Data path in fm9g format.")
|
||||
parser.add_argument("--output", "-o", required=True, help="Output data path in indexed jsonline format.")
|
||||
args = parser.parse_args()
|
||||
convert_cpm_data(args.input, args.output)
|
||||
convert_fm9g_data(args.input, args.output)
|
|
@ -1,3 +1,10 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
||||
#
|
||||
# @author: ouzebin <ouzebin@zhihu.com>
|
||||
# @date: 2023/08/07
|
||||
import argparse
|
||||
import os
|
||||
|
|
@ -1,3 +1,18 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The OpenBMB team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
|
@ -251,7 +266,7 @@ def merge_dataset(dst: str, src: str):
|
|||
_write_info_list(meta_path_dst, nw_info)
|
||||
|
||||
|
||||
def to_cpm(src_data, dst_path, dst_name):
|
||||
def to_fm9g(src_data, dst_path, dst_name):
|
||||
if not os.path.exists(dst_path):
|
||||
os.makedirs(dst_path)
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
{
|
||||
"folders": [
|
||||
{
|
||||
"path": "../.."
|
||||
}
|
||||
],
|
||||
"settings": {}
|
||||
}
|
|
@ -0,0 +1,105 @@
|
|||
import torch
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class DragonflyConfig(PretrainedConfig):
|
||||
model_type = "fm9g_dragonfly"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
attribute_map = {
|
||||
"num_key_value_heads": "num_kv_heads",
|
||||
"hidden_act": "activate_fn",
|
||||
"hidden_size": "dim_model",
|
||||
"num_attention_heads": "num_heads",
|
||||
"intermediate_size": "dim_ff",
|
||||
"num_hidden_layers": "num_layers",
|
||||
"vocab_size": "vocab_size",
|
||||
"rms_norm_eps": "eps",
|
||||
"scale_emb": "scale_emb",
|
||||
"scale_depth": "scale_depth",
|
||||
"scale": "scale",
|
||||
"attention_scale": "attention_scale",
|
||||
"qk_norm": "qk_norm",
|
||||
"ffn_gated": "ffn_gated",
|
||||
} # model specific to common
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=122753, # TODO: do we need to change to 122880 = 960 * 128?
|
||||
dim_model=4096,
|
||||
num_heads=32,
|
||||
num_kv_heads=32,
|
||||
dim_head=128,
|
||||
dim_ff=11008,
|
||||
num_layers=32,
|
||||
dropout_p=0.0,
|
||||
activate_fn="silu",
|
||||
scale=False,
|
||||
scale_emb: float = 1.0,
|
||||
scale_depth: float = -1,
|
||||
dim_model_base: int = 256,
|
||||
eps=1e-5,
|
||||
init_std=0.02,
|
||||
dtype="bf16",
|
||||
base=10000,
|
||||
qk_norm=False,
|
||||
tie_lm_head=False,
|
||||
max_length=8192,
|
||||
pose_prob=0.0,
|
||||
pose_scaling_factor=1,
|
||||
rope_scaling_type="",
|
||||
rope_scaling_factor=1,
|
||||
orig_max_length=8192,
|
||||
tp=0,
|
||||
use_checkpoint=True,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.dim_model = dim_model
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.dim_head = dim_head
|
||||
self.dim_ff = dim_ff
|
||||
self.num_layers = num_layers
|
||||
self.dropout_p = dropout_p
|
||||
self.activate_fn = activate_fn
|
||||
self.scale = scale
|
||||
self.scale_emb = scale_emb
|
||||
self._dtype = dtype
|
||||
self.dim_model_base = dim_model_base
|
||||
self.scale_depth = scale_depth
|
||||
self.eps = eps
|
||||
self.init_std = init_std
|
||||
self.base = base
|
||||
self.qk_norm = qk_norm
|
||||
self.tie_lm_head = tie_lm_head
|
||||
self.use_bfloat16 = True if self._dtype == "bf16" else False
|
||||
self.pose_prob = pose_prob
|
||||
self.pose_scaling_factor = pose_scaling_factor
|
||||
self.rope_scaling_type = rope_scaling_type
|
||||
self.rope_scaling_factor = rope_scaling_factor
|
||||
self.max_length = max_length
|
||||
self.orig_max_length = orig_max_length
|
||||
self.use_checkpoint = use_checkpoint
|
||||
print("use_checkpoint", self.use_checkpoint)
|
||||
self.tp = tp
|
||||
super().__init__(architectures=["fm9gDragonflyForCausalLM"])
|
||||
|
||||
@property
|
||||
def scale_width(
|
||||
self,
|
||||
):
|
||||
if self.scale:
|
||||
return self.dim_model / self.dim_model_base
|
||||
else:
|
||||
return 1.0
|
||||
|
||||
@property
|
||||
def dtype(
|
||||
self,
|
||||
): # -> Any | None:
|
||||
if self._dtype == "bf16":
|
||||
return torch.bfloat16
|
||||
elif self._dtype == "fp16":
|
||||
return torch.half
|
||||
elif self._dtype == "float32":
|
||||
return torch.float
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1 @@
|
|||
from .pretrain_indexed import MixedIndexedDataset
|
|
@ -0,0 +1,74 @@
|
|||
import logging
|
||||
from multiprocessing import Lock
|
||||
|
||||
from flask import Flask
|
||||
from flask import jsonify
|
||||
from flask import request
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
# 获取 Werkzeug 日志记录器并设置日志级别
|
||||
log = logging.getLogger("werkzeug")
|
||||
log.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class GlobalAvgTokensStat(object):
|
||||
def __init__(self, decay_factor: float = 0.98):
|
||||
self._avg_tokens = {}
|
||||
self.decay_factor = decay_factor
|
||||
self.lock = Lock()
|
||||
self.task_locks = {}
|
||||
|
||||
def set_avg_tokens(self, task_name, avg_tokens):
|
||||
self._register_task_lock_helper(task_name)
|
||||
with self.task_locks[task_name]:
|
||||
self._avg_tokens[task_name] = avg_tokens
|
||||
|
||||
def update_avg_tokens_by_ema(self, task_name, length):
|
||||
self._register_task_lock_helper(task_name)
|
||||
with self.task_locks[task_name]:
|
||||
if task_name in self._avg_tokens and self._avg_tokens[task_name] > 0:
|
||||
self._avg_tokens[task_name] = self._avg_tokens[task_name] * self.decay_factor + length * (
|
||||
1 - self.decay_factor
|
||||
)
|
||||
else:
|
||||
self._avg_tokens[task_name] = length
|
||||
|
||||
def get_avg_tokens(self, task_name):
|
||||
self._register_task_lock_helper(task_name)
|
||||
with self.task_locks[task_name]:
|
||||
return self._avg_tokens.get(task_name, -1)
|
||||
|
||||
def _register_task_lock_helper(self, task_name):
|
||||
if task_name not in self.task_locks:
|
||||
with self.lock:
|
||||
if task_name not in self.task_locks:
|
||||
self.task_locks[task_name] = Lock()
|
||||
|
||||
|
||||
global_avg_tokens_stat = GlobalAvgTokensStat()
|
||||
|
||||
|
||||
@app.route("/avg_tokens/<path:task_name>", methods=["GET"])
|
||||
def get_avg_tokens(task_name):
|
||||
global global_avg_tokens_stat
|
||||
avg_tokens = global_avg_tokens_stat.get_avg_tokens(task_name)
|
||||
return jsonify({"avg_tokens": avg_tokens})
|
||||
|
||||
|
||||
@app.route("/avg_tokens/<path:task_name>", methods=["POST"])
|
||||
def set_avg_tokens(task_name):
|
||||
global global_avg_tokens_stat
|
||||
action = request.args.get("action", "update", type=str)
|
||||
length = request.args.get("length", -1, type=int)
|
||||
if action == "set":
|
||||
global_avg_tokens_stat.set_avg_tokens(task_name, length)
|
||||
elif action == "update":
|
||||
global_avg_tokens_stat.update_avg_tokens_by_ema(task_name, length)
|
||||
else:
|
||||
raise ValueError(f"Unknown action: {action}")
|
||||
return jsonify({"status": "ok"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(port=5000, debug=True)
|
|
@ -0,0 +1,826 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
||||
#
|
||||
# @author: ouzebin <ouzebin@zhihu.com>
|
||||
# @date: 2023/09/27
|
||||
|
||||
|
||||
import copy
|
||||
import ctypes
|
||||
import functools
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from collections import OrderedDict
|
||||
from multiprocessing import Lock
|
||||
from multiprocessing import Process
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import Iterable
|
||||
from typing import Iterator
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Set
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import bmtrain as bmt
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from fm9g.dataset import PrefetchDecodeDataset
|
||||
from fm9g.utils.bitset import BitSet
|
||||
from fm9g.utils.vdc_sampling import van_der_corput
|
||||
from fm9g.utils.vdc_sampling import van_der_corput_sampling_gen
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
IGNORE_TGT = -100
|
||||
|
||||
|
||||
def load_dataset_cfgs(cfg_path, cfg_json_str=None):
|
||||
if cfg_json_str is not None:
|
||||
cfgs = json.loads(cfg_json_str)
|
||||
else:
|
||||
with open(cfg_path, "r", encoding="utf-8") as fin:
|
||||
cfgs = json.load(fin)
|
||||
transform_basedir = os.path.dirname(os.path.abspath(cfg_path))
|
||||
|
||||
path_dict = None
|
||||
platform_config_path = os.getenv("PLATFORM_CONFIG_PATH")
|
||||
try:
|
||||
with open(platform_config_path, "r") as f:
|
||||
platform_cfg = json.load(f)
|
||||
path_dict = platform_cfg["dataset_map"]
|
||||
if bmt.rank() == 0:
|
||||
logger.info(f"Loaded jeeves platform config from '{platform_config_path}', update dataset paths...")
|
||||
except Exception as e:
|
||||
if bmt.rank() == 0:
|
||||
logger.info(f"Failing to load jeeves platform config '{platform_config_path}', error message:\n{str(e)}")
|
||||
|
||||
task_name2dataset_name = dict()
|
||||
for idx, cfg in enumerate(cfgs):
|
||||
assert "dataset_name" in cfg and isinstance(cfg["dataset_name"], str)
|
||||
assert "task_name" in cfg and isinstance(cfg["task_name"], str)
|
||||
# to be delibrately annoying :)
|
||||
if cfg["task_name"] in task_name2dataset_name:
|
||||
raise ValueError(
|
||||
f"task_name '{cfg['task_name']}' in dataset '{cfg['dataset_name']}'"
|
||||
f"has already been used in '{task_name2dataset_name[cfg['task_name']]}'."
|
||||
)
|
||||
task_name2dataset_name[cfg["task_name"]] = cfg["dataset_name"]
|
||||
|
||||
assert "path" in cfg and isinstance(cfg["path"], str)
|
||||
# if path_dict is not None:
|
||||
# cfg["path"] = os.path.join(path_dict[cfg["dataset_name"]], cfg["path"])
|
||||
|
||||
# dealing with optional configs
|
||||
if "weight" in cfg:
|
||||
assert isinstance(cfg["weight"], (float, int))
|
||||
else:
|
||||
cfg["weight"] = 1.0
|
||||
|
||||
if "oversize_rule" in cfg:
|
||||
assert cfg["oversize_rule"] in ("drop", "head", "segment")
|
||||
else:
|
||||
cfg["oversize_rule"] = "segment"
|
||||
|
||||
if "transforms" in cfg:
|
||||
assert isinstance(cfg["transforms"], str)
|
||||
# dealing with relative path
|
||||
if not cfg["transforms"].startswith("/"):
|
||||
cfg["transforms"] = os.path.join(transform_basedir, cfg["transforms"])
|
||||
if not cfg["transforms"]:
|
||||
cfg["transforms"] = None
|
||||
else:
|
||||
cfg["transforms"] = None
|
||||
|
||||
if "incontext_weight" in cfg:
|
||||
assert isinstance(cfg["incontext_weight"], (list, tuple))
|
||||
else:
|
||||
cfg["incontext_weight"] = [1.0]
|
||||
cfg["id"] = idx
|
||||
# dataset and iterator will be built
|
||||
return cfgs
|
||||
|
||||
|
||||
def data2ids(data, tokenizer, max_length):
|
||||
text = "\n".join(
|
||||
[
|
||||
data.get("title", "").strip(),
|
||||
data.get("question", "").strip(),
|
||||
data.get("answer", "").strip(),
|
||||
data.get("abstract", "").strip(),
|
||||
data.get("text", "").strip(),
|
||||
data.get("code", "").strip(),
|
||||
]
|
||||
).strip()
|
||||
|
||||
if not text:
|
||||
logger.warning(f"Warning: skip invalid sample without valid fields: {data}")
|
||||
yield from ()
|
||||
return
|
||||
# suppress the annoying warning from tokenizer
|
||||
ids = (
|
||||
[tokenizer.bos_token_id]
|
||||
+ tokenizer.encode(text, max_length=int(1e12), truncation=True)
|
||||
+ [tokenizer.eos_token_id]
|
||||
)
|
||||
src_ids = ids[0:-1]
|
||||
tgt_ids = ids[0:-1] # do not shift because it'll be shifted during loss calculation.
|
||||
|
||||
if len(src_ids) > max_length:
|
||||
for st in range(0, len(src_ids), max_length):
|
||||
yield src_ids[st : st + max_length], tgt_ids[st : st + max_length]
|
||||
else:
|
||||
yield src_ids, tgt_ids
|
||||
|
||||
|
||||
def cricket_data2ids(data, tokenizer, max_length: int, oversize_rule="segment", do_compact=False):
|
||||
assert oversize_rule in ("drop", "head", "segment")
|
||||
if data is None:
|
||||
yield from ()
|
||||
return
|
||||
if "output" not in data or not data["output"]:
|
||||
yield from ()
|
||||
return
|
||||
if "input" not in data or data["input"] is None:
|
||||
data["input"] = ""
|
||||
|
||||
src_ids = [tokenizer.bos_token_id]
|
||||
tgt_ids = []
|
||||
has_input = False
|
||||
is_segment_reenter = False
|
||||
|
||||
# Use incremental tokenization to avoid waiting for a long document
|
||||
MAX_CHUNK_LENGTH = max_length * 10
|
||||
for part in ("input", "output"):
|
||||
l, r = 0, min(MAX_CHUNK_LENGTH, len(data[part]))
|
||||
while l < len(data[part]):
|
||||
try:
|
||||
current_slice = data[part][l:r]
|
||||
if not current_slice:
|
||||
break
|
||||
token_ids = tokenizer.encode(current_slice, add_special_tokens=False)
|
||||
except:
|
||||
print("Error in data[part][l:r] {}".format(data))
|
||||
yield from ()
|
||||
return
|
||||
|
||||
if part == "input":
|
||||
if len(token_ids) > 0:
|
||||
has_input = True
|
||||
if len(token_ids) >= max_length - 2: # input len must < max_length
|
||||
yield from ()
|
||||
return
|
||||
src_ids.extend(token_ids)
|
||||
tgt_ids.extend([IGNORE_TGT] * len(token_ids))
|
||||
l = r
|
||||
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
|
||||
else:
|
||||
if len(token_ids) + len(tgt_ids) >= max_length:
|
||||
if oversize_rule == "drop":
|
||||
yield from ()
|
||||
return
|
||||
elif oversize_rule == "head":
|
||||
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||
src_ids.extend(selected_token_ids[:-1])
|
||||
tgt_ids.extend(selected_token_ids)
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||
return
|
||||
elif oversize_rule == "segment":
|
||||
instruction_rest_space = max_length - 1 - len(token_ids)
|
||||
if has_input: # is instruction data
|
||||
if (
|
||||
do_compact
|
||||
and len(src_ids) >= 128 # avoid too short instruction info lost
|
||||
and instruction_rest_space / len(src_ids) > 0.8
|
||||
): # can be squeezed into max length
|
||||
inputs_len = len(src_ids)
|
||||
keep_len = instruction_rest_space // 2
|
||||
src_ids = src_ids[:keep_len] + src_ids[inputs_len - keep_len :]
|
||||
tgt_ids = [IGNORE_TGT] * (len(src_ids) - 1)
|
||||
src_ids.extend(token_ids)
|
||||
tgt_ids.extend(token_ids)
|
||||
tgt_ids.append(tokenizer.eos_token_id)
|
||||
assert len(src_ids) < max_length, f"len src_ids: {len(src_ids)}"
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids, tgt_ids
|
||||
else: # else use head rule
|
||||
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||
src_ids.extend(selected_token_ids[:-1])
|
||||
tgt_ids.extend(selected_token_ids)
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||
return
|
||||
else: # normal segment
|
||||
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||
src_ids.extend(selected_token_ids)
|
||||
tgt_ids.extend(selected_token_ids)
|
||||
assert len(src_ids) == max_length + 1, f"len src_ids: {len(src_ids)}"
|
||||
assert len(tgt_ids) == max_length, f"len tgt_ids: {len(tgt_ids)}"
|
||||
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||
src_ids = src_ids[max_length:]
|
||||
tgt_ids = tgt_ids[max_length:]
|
||||
# sliding input str window
|
||||
consumed_str = tokenizer.decode(selected_token_ids)
|
||||
l += len(consumed_str)
|
||||
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
|
||||
is_segment_reenter = True
|
||||
else:
|
||||
if (is_segment_reenter and len(token_ids) > 8) or (
|
||||
not is_segment_reenter and len(token_ids) > 0
|
||||
): # is segmented LM data
|
||||
src_ids.extend(token_ids)
|
||||
tgt_ids.extend(token_ids)
|
||||
tgt_ids.append(tokenizer.eos_token_id)
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids, tgt_ids
|
||||
else:
|
||||
yield from ()
|
||||
return
|
||||
|
||||
|
||||
class SegmentedDataset(torch.utils.data.IterableDataset):
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
tokenizer,
|
||||
max_length=1024,
|
||||
transform_func=None,
|
||||
nthreads=1,
|
||||
prefetch_slice=3,
|
||||
slice_size=500,
|
||||
do_compact=False,
|
||||
):
|
||||
super(SegmentedDataset, self).__init__()
|
||||
self.segment = functools.partial(
|
||||
cricket_data2ids, tokenizer=tokenizer, max_length=max_length, do_compact=do_compact
|
||||
)
|
||||
self.cfg = cfg
|
||||
self.max_length = max_length
|
||||
self.nthreads = nthreads
|
||||
self.transform_func = transform_func
|
||||
self.prefetch_slice = prefetch_slice
|
||||
self.slice_size = slice_size
|
||||
self.abs_weight = cfg.get("abs_weight", None)
|
||||
self.task_name = cfg["task_name"]
|
||||
self.dataset_name = cfg["dataset_name"]
|
||||
self.oversize_rule = cfg["oversize_rule"]
|
||||
self.dataset = PrefetchDecodeDataset(path=cfg["path"], allow_repeat=cfg.get("allow_repeat", True))
|
||||
self.exhausted = False
|
||||
self.iterator = None
|
||||
|
||||
self.counter = 0
|
||||
self.allow_repeat = cfg.get("allow_repeat", True)
|
||||
self.used = BitSet()
|
||||
self.init_ave_tokens()
|
||||
|
||||
def init_ave_tokens(
|
||||
self,
|
||||
):
|
||||
try:
|
||||
shm = SharedMemory(name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
|
||||
except FileNotFoundError:
|
||||
bmt.print_rank(
|
||||
"Create Shared Memory {}".format(f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
|
||||
)
|
||||
shm = SharedMemory(
|
||||
create=True,
|
||||
size=ctypes.sizeof(ctypes.c_float),
|
||||
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}',
|
||||
)
|
||||
|
||||
# 使用共享内存
|
||||
shared_value = ctypes.c_float.from_buffer(shm.buf)
|
||||
_ave_tokens = self.cfg.get(
|
||||
"avg_tokens", self.cfg.get("ave_tokens", self.cfg.get("ave_tokens_per_line", -1))
|
||||
)
|
||||
|
||||
if _ave_tokens > self.max_length:
|
||||
_ave_tokens = self.max_length
|
||||
bmt.print_rank(
|
||||
"Warning: avg_tokens {} is larger than max_length {}, set to max_length".format(
|
||||
_ave_tokens, self.max_length
|
||||
)
|
||||
)
|
||||
|
||||
shared_value.value = _ave_tokens
|
||||
# 不再需要 shared_value 时,删除引用
|
||||
del shared_value
|
||||
|
||||
# 现在可以安全地关闭共享内存
|
||||
shm.close()
|
||||
bmt.print_rank("Init ave_tokens for task {}: {}".format(self.task_name, self.ave_tokens))
|
||||
|
||||
@property
|
||||
def ave_tokens(
|
||||
self,
|
||||
):
|
||||
existing_shm = SharedMemory(
|
||||
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}'
|
||||
) # -1 # default length
|
||||
shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
|
||||
tmp = shared_value.value
|
||||
del shared_value
|
||||
existing_shm.close()
|
||||
return tmp
|
||||
|
||||
def ave_tokens_update(self, length):
|
||||
existing_shm = SharedMemory(
|
||||
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}'
|
||||
) # -1 # default length
|
||||
shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
|
||||
if shared_value.value < 0:
|
||||
shared_value.value = float(length)
|
||||
else:
|
||||
shared_value.value = 0.98 * shared_value.value + 0.02 * length
|
||||
del shared_value
|
||||
existing_shm.close()
|
||||
|
||||
def size(self):
|
||||
return self.dataset.size()
|
||||
|
||||
def __iter__(self):
|
||||
self.iterate()
|
||||
return self
|
||||
|
||||
def reset(self):
|
||||
self.exhausted = False
|
||||
if self.iterator is not None:
|
||||
self.iterator.close()
|
||||
self.iterator = None
|
||||
self.used = BitSet()
|
||||
print("Rank {}, Reset dataset:{} done.".format(bmt.rank(), self.dataset_name))
|
||||
|
||||
def transform(self, data: dict) -> dict:
|
||||
weight = np.array(self.cfg["incontext_weight"], dtype=np.float32)
|
||||
weight = weight / weight.sum()
|
||||
num_incontext = np.random.choice(weight.shape[0], p=weight)
|
||||
return self.transform_func(data, num_incontext, random.Random())
|
||||
|
||||
def segment_iterate(self, sample_iter):
|
||||
for index, data in self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used):
|
||||
for src_ids, tgt_ids in self.segment(self.transform(data)):
|
||||
self.ave_tokens_update(len(src_ids)) # 0 for input ids
|
||||
yield src_ids, tgt_ids, index
|
||||
|
||||
def iterate(self):
|
||||
# make the dataset itself an iterator
|
||||
sample_iter = self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used)
|
||||
self.iterator = self.segment_iterate(sample_iter)
|
||||
|
||||
def __next__(self):
|
||||
# advance the task iterator
|
||||
if self.iterator is None:
|
||||
self.iterate()
|
||||
try:
|
||||
return next(self.iterator)
|
||||
except StopIteration:
|
||||
self.exhausted = True
|
||||
return None
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
if state_dict.get("exhausted", False):
|
||||
self.exhausted = True
|
||||
self.used = BitSet()
|
||||
else:
|
||||
used = state_dict.get("used", BitSet())
|
||||
if len(used) == len(self.dataset):
|
||||
self.exhausted = True
|
||||
self.used = BitSet()
|
||||
else:
|
||||
self.exhausted = False
|
||||
self.used = used
|
||||
self.ave_tokens_update(state_dict.get("ave_tokens", -1))
|
||||
|
||||
def state_dict(self):
|
||||
if len(self.used) == len(self.dataset):
|
||||
return dict(exhausted=True, used=BitSet(), ave_tokens=self.ave_tokens)
|
||||
else:
|
||||
return dict(exhausted=False, used=self.used, ave_tokens=self.ave_tokens)
|
||||
|
||||
def update_state(self, indice):
|
||||
self.used.update(indice)
|
||||
|
||||
|
||||
class MixedIndexedDataset(torch.utils.data.IterableDataset):
|
||||
def __init__(
|
||||
self,
|
||||
cfg_path: str,
|
||||
cfg_json_str,
|
||||
tokenizer,
|
||||
max_length,
|
||||
weight_by_size: bool = True,
|
||||
nthreads=5,
|
||||
prefetch_slice=100,
|
||||
parallel_loading=False,
|
||||
vdc_sampling=False,
|
||||
update_weights_frequency=1,
|
||||
seed=42,
|
||||
):
|
||||
super(MixedIndexedDataset, self).__init__()
|
||||
self.set_seed(seed + bmt.rank())
|
||||
self.weight_by_size = weight_by_size
|
||||
self.tokenizer = tokenizer
|
||||
self.eos_token_id = self.tokenizer.eos_token_id
|
||||
self.bos_token_id = self.tokenizer.bos_token_id
|
||||
self.path2transform = dict()
|
||||
self.task_dict = OrderedDict()
|
||||
self.nthreads = nthreads
|
||||
self.prefetch_slice = prefetch_slice
|
||||
# useful for indexing
|
||||
self.tasks = []
|
||||
self.names = []
|
||||
# ending of iteration
|
||||
self.remain = 0
|
||||
self.max_length = max_length
|
||||
self.vdc_sampling = vdc_sampling
|
||||
if self.vdc_sampling:
|
||||
self._vdc_values = [van_der_corput(i) for i in range(10**6)] # 精度提高 10^{-6}
|
||||
self.vdc_gen = van_der_corput_sampling_gen(self._vdc_values)
|
||||
|
||||
self.update_weights_frequency = update_weights_frequency
|
||||
|
||||
self.path2transform = dict()
|
||||
|
||||
cfgs = load_dataset_cfgs(cfg_path, cfg_json_str)
|
||||
_sum_weight = sum([cfg["abs_weight"] for cfg in cfgs])
|
||||
_weights = {cfg["task_name"]: cfg["abs_weight"] / _sum_weight for cfg in cfgs}
|
||||
bmt.print_rank("Absolute Weight of DataSet {}".format(_weights))
|
||||
|
||||
if parallel_loading:
|
||||
self.parallel_load(cfgs, max_workers=None)
|
||||
else:
|
||||
self.sequential_load(cfgs)
|
||||
|
||||
self.weights = None
|
||||
self.update_weights()
|
||||
|
||||
def set_seed(self, seed):
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
def load_task(self, cfg):
|
||||
logger.info(f"Loading {cfg['path']}")
|
||||
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
|
||||
task = SegmentedDataset(
|
||||
cfg,
|
||||
self.tokenizer,
|
||||
self.max_length,
|
||||
transform_func=transform_func,
|
||||
nthreads=self.nthreads,
|
||||
prefetch_slice=self.prefetch_slice,
|
||||
do_compact=cfg.get("do_compact", False), # dataset level do_compact
|
||||
)
|
||||
return task
|
||||
|
||||
def sequential_load(self, cfgs):
|
||||
self.cfgs = cfgs
|
||||
for cfg in cfgs:
|
||||
# python3.7 and later preserves insertion order to dictionary
|
||||
logger.info(f"Loading {cfg['path']}")
|
||||
|
||||
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
|
||||
task = SegmentedDataset(
|
||||
cfg,
|
||||
self.tokenizer,
|
||||
self.max_length,
|
||||
transform_func=transform_func,
|
||||
nthreads=self.nthreads,
|
||||
prefetch_slice=self.prefetch_slice,
|
||||
do_compact=cfg.get("do_compact", False), # dataset level do_compact
|
||||
)
|
||||
self.task_dict[task.task_name] = task
|
||||
self.tasks.append(task)
|
||||
self.names.append(task.task_name)
|
||||
self.remain += 1
|
||||
self.weights = None
|
||||
self.update_weights()
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
missing_keys = []
|
||||
for name, task in self.task_dict.items():
|
||||
if name in state_dict:
|
||||
task.load_state_dict(state_dict[name])
|
||||
else:
|
||||
missing_keys.append(name)
|
||||
self.update_weights()
|
||||
return missing_keys
|
||||
|
||||
def save_state_dict(self, path):
|
||||
state_dict = {}
|
||||
for name, task in self.task_dict.items():
|
||||
_state_dict = task.state_dict()
|
||||
if isinstance(_state_dict["used"], BitSet):
|
||||
bitset = _state_dict["used"]
|
||||
_file_name = bitset.save(path)
|
||||
_state_dict["used"] = _file_name
|
||||
state_dict[name] = _state_dict
|
||||
else:
|
||||
state_dict[name] = task.state_dict()
|
||||
torch.save(state_dict, path)
|
||||
logger.info("Dataset state saved")
|
||||
|
||||
def update_states(self, task_ids, indice):
|
||||
is_dict = isinstance(indice, dict)
|
||||
uniq = torch.unique(task_ids)
|
||||
for idx in uniq:
|
||||
idx = idx.item()
|
||||
indexes = indice[idx] if is_dict else indice[task_ids == idx].tolist()
|
||||
self.tasks[idx].update_state(indexes)
|
||||
|
||||
def get_transform_func(self, module_name: str, transform_script_path):
|
||||
if transform_script_path is None:
|
||||
# allow null transform
|
||||
return lambda data, num_incontext, rand: data
|
||||
module_name = "fm9g_live.transforms.{}".format(module_name)
|
||||
if transform_script_path not in self.path2transform:
|
||||
loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
|
||||
spec = importlib.util.spec_from_loader(loader.name, loader)
|
||||
if spec is None:
|
||||
raise RuntimeError("Spec is none! {}".format(module_name))
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
self.path2transform[transform_script_path] = {
|
||||
"loader": loader,
|
||||
"module": mod,
|
||||
"last_mtime": 0,
|
||||
}
|
||||
transform_script_info = self.path2transform[transform_script_path]
|
||||
curr_mtime = float(transform_script_info["loader"].path_stats(transform_script_path)["mtime"])
|
||||
if curr_mtime > transform_script_info["last_mtime"]:
|
||||
transform_script_info["last_mtime"] = curr_mtime
|
||||
transform_script_info["loader"].exec_module(transform_script_info["module"])
|
||||
transform_func = getattr(transform_script_info["module"], "transform", None)
|
||||
if transform_func is None:
|
||||
raise NotImplementedError("Find no transform funcion in script '{}'".format(transform_script_path))
|
||||
return transform_func
|
||||
|
||||
def update_weights(self):
|
||||
task0 = self.tasks[0]
|
||||
if task0.abs_weight is not None: # 这一份config是指定绝对比例的
|
||||
weights = []
|
||||
for task in self.tasks:
|
||||
if task.exhausted:
|
||||
weights.append(0)
|
||||
else:
|
||||
if task.ave_tokens == -1:
|
||||
weights.append(task.abs_weight / self.max_length)
|
||||
else:
|
||||
weights.append(task.abs_weight / task.ave_tokens)
|
||||
weights = np.array(weights)
|
||||
else:
|
||||
weights = np.array([0 if task.exhausted else task.weight for task in self.tasks])
|
||||
if self.weight_by_size:
|
||||
sizes = np.array([task.size() for task in self.tasks], dtype=np.float32)
|
||||
weights *= sizes
|
||||
self.weights = weights / weights.sum()
|
||||
|
||||
def __iter__(self):
|
||||
for task in self.tasks:
|
||||
task.iterate()
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
step = 1
|
||||
while True:
|
||||
if self.remain == 0:
|
||||
print("Rank {}, All task exhaust !!!!".format(bmt.rank()))
|
||||
raise StopIteration
|
||||
if self.vdc_sampling:
|
||||
idx = next(self.vdc_gen)(self.weights)
|
||||
else:
|
||||
idx = np.random.choice(len(self.weights), p=self.weights)
|
||||
|
||||
data = next(self.tasks[idx])
|
||||
if step % self.update_weights_frequency == 0:
|
||||
self.update_weights()
|
||||
if data is None:
|
||||
if self.tasks[idx].allow_repeat:
|
||||
# _runtime_ave = self.tasks[idx].ave_tokens
|
||||
print("Rank {}, dataset {} exhaust, repeat...".format(bmt.rank(), self.tasks[idx].dataset_name))
|
||||
# self.tasks[idx] = SegmentedDataset(
|
||||
# self.tasks[idx].cfg, self.tokenizer, self.max_length, transform_func=self.tasks[idx].transform_func, nthreads=self.nthreads, prefetch_slice=self.prefetch_slice
|
||||
# )
|
||||
# self.tasks[idx].ave_tokens_update(_runtime_ave)
|
||||
self.tasks[idx].reset()
|
||||
else:
|
||||
print("Rank {}, dataset {} exhaust, not repeat.".format(bmt.rank(), self.tasks[idx].dataset_name))
|
||||
self.tasks[idx].exhaust = True
|
||||
self.remain -= 1
|
||||
continue
|
||||
step += 1
|
||||
return dict(
|
||||
task_id=idx,
|
||||
input=data[0],
|
||||
target=data[1],
|
||||
index=data[2],
|
||||
is_long=self.tasks[idx].cfg.get("is_long", False),
|
||||
)
|
||||
|
||||
|
||||
class UnpadBatchedMixedDataset(torch.utils.data.IterableDataset):
|
||||
def __init__(self, mixed_dataset, batch_size, max_length, pose_prob=0.0, pose_scaling_factor=1.0, compact=False):
|
||||
self.max_total_length = batch_size * max_length
|
||||
self.batch_size = 1
|
||||
# setting compact=True concats segments orignated from the same input
|
||||
# into a long sequence. the relative order of segments should be preserved
|
||||
# in mixed_dataset, e.g.,
|
||||
# - ok: task1_seg1, task2_seg1, task1_seg2, task1_seg3
|
||||
# - not_ok: task1_seg1, task1_seg3, task2_seg1, task1_seg2
|
||||
self.compact = compact
|
||||
|
||||
self.total_length = 0
|
||||
self.task2seqs = defaultdict(list)
|
||||
self.mixed_dataset = mixed_dataset
|
||||
self._max_length = max_length
|
||||
self._pose_prob = pose_prob
|
||||
self._pose_scaling_factor = pose_scaling_factor
|
||||
if self._pose_prob > 0.0:
|
||||
self._scaled_max_length = int(self.max_total_length * self._pose_scaling_factor)
|
||||
else:
|
||||
self._scaled_max_length = max_length
|
||||
|
||||
def put(self, sample):
|
||||
self.total_length += len(sample["target"])
|
||||
task_id = sample["task_id"]
|
||||
if self.compact and self.task2seqs[task_id]:
|
||||
last = self.task2seqs[task_id][-1]
|
||||
if last["target"][-1] != self.mixed_dataset.eos_token_id:
|
||||
# concatenate sequantial segments for longer context modeling: why not?
|
||||
last["input"].extend(sample["input"])
|
||||
last["target"].extend(sample["target"])
|
||||
return
|
||||
self.task2seqs[task_id].append(sample)
|
||||
|
||||
def _pose_preprocess(
|
||||
self,
|
||||
input_ids: NDArray[np.int32],
|
||||
) -> NDArray[np.int32]:
|
||||
"""[PoSE](https://arxiv.org/abs/2309.10400v2)
|
||||
GitHub implementation: https://github.com/dwzhu-pku/PoSE/blob/master/src/train_pose.py#L156
|
||||
"""
|
||||
len_chunk = min(len(input_ids), self._max_length)
|
||||
len_input = len(input_ids)
|
||||
# Chunk input randomly to fit max_length if needed
|
||||
lt1 = 0
|
||||
rt1 = random.randint(0, (len_chunk + 1) // 2) # Fist chunk only contains 1/2 tokens at most
|
||||
rt2 = random.randint(lt1 + len_chunk, len_input) # Second chunk can randomly shift when not filled max_length
|
||||
lt2 = rt2 - (len_chunk - (rt1 - lt1)) # assure all tokens are used
|
||||
chunked_input_ids = np.concatenate([input_ids[lt1:rt1], input_ids[lt2:rt2]], axis=-1)
|
||||
# Generate PoSE position ids
|
||||
position_ids = np.arange(len(chunked_input_ids), dtype=np.int32)
|
||||
len_position_ids = len(position_ids)
|
||||
lt = 0
|
||||
rt = random.randint(lt, self._scaled_max_length - len_position_ids)
|
||||
position_ids[: rt1 - lt1] += lt
|
||||
position_ids[rt1 - lt1 :] += rt
|
||||
return position_ids
|
||||
|
||||
def pop(self):
|
||||
indexes = defaultdict(list)
|
||||
lengths = []
|
||||
|
||||
inputs = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
|
||||
targets = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=IGNORE_TGT)
|
||||
task_ids = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=-1)
|
||||
position_ids = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
|
||||
|
||||
span_begin = 0
|
||||
for samples in self.task2seqs.values():
|
||||
while samples:
|
||||
sample = samples.pop()
|
||||
span_end = span_begin + len(sample["input"])
|
||||
inputs[0, span_begin:span_end] = torch.tensor(sample["input"], dtype=torch.int32)
|
||||
targets[0, span_begin:span_end] = torch.tensor(sample["target"], dtype=torch.int32)
|
||||
task_ids[0, span_begin:span_end] = torch.tensor(sample["task_id"], dtype=torch.int32)
|
||||
if not sample["is_long"] and self._pose_prob > 0.0 and random.uniform(0, 1) < self._pose_prob:
|
||||
_span_position_ids = self._pose_preprocess(sample["input"])
|
||||
else:
|
||||
_span_position_ids = np.arange(len(sample["input"]), dtype=np.int32)
|
||||
position_ids[0, span_begin:span_end] = torch.from_numpy(_span_position_ids)
|
||||
# position_ids[0, span_begin:span_end] = torch.arange(len(sample["input"]), dtype=torch.int32)
|
||||
lengths.append(len(sample["target"]))
|
||||
indexes[int(sample["task_id"])].append(sample["index"])
|
||||
self.total_length -= len(sample["target"])
|
||||
span_begin = span_end
|
||||
|
||||
cu_seqlens = torch.cat(
|
||||
[torch.tensor([0] + lengths).cumsum(dim=-1), torch.tensor([self.max_total_length], dtype=torch.int32)],
|
||||
dim=0,
|
||||
).int()
|
||||
batch = {
|
||||
"inputs": inputs,
|
||||
"targets": targets,
|
||||
"task_ids": task_ids,
|
||||
"indexes": indexes,
|
||||
# adhere to flash attention interface
|
||||
"cu_seqlens": cu_seqlens,
|
||||
"max_seqlen": int(torch.max(cu_seqlens[1:] - cu_seqlens[:-1])),
|
||||
"lengths": torch.tensor(sum(lengths)).int(),
|
||||
"task_names": self.mixed_dataset.names,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
return batch
|
||||
|
||||
def will_be_full(self, sample):
|
||||
return self.total_length + len(sample["target"]) > self.max_total_length
|
||||
|
||||
def __iter__(self):
|
||||
for sample in self.mixed_dataset:
|
||||
if self.will_be_full(sample):
|
||||
yield self.pop()
|
||||
self.put(sample)
|
||||
|
||||
|
||||
class CudaPrefetcher(Iterable):
|
||||
"""
|
||||
Wrap around a batch iterator for asynchornously copying data to gpu to shield memcpy latency.
|
||||
"""
|
||||
|
||||
def __init__(self, loader, tp_size=1, tp_rank=0):
|
||||
self.loader = iter(loader)
|
||||
self.tp_size = tp_size
|
||||
self.tp_rank = tp_rank
|
||||
self.stream = torch.cuda.Stream()
|
||||
self.preload()
|
||||
|
||||
def preload(self):
|
||||
try:
|
||||
if self.tp_size > 1:
|
||||
if self.tp_rank == 0:
|
||||
data = next(self.loader)
|
||||
print("Rank {}, Preload data done.".format(bmt.rank()))
|
||||
d = {}
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "wb") as fb:
|
||||
for key in data.keys():
|
||||
if isinstance(data[key], torch.Tensor):
|
||||
np_cur_data = data[key].cpu().numpy()
|
||||
bs = np_cur_data.tobytes()
|
||||
fb.write(bs)
|
||||
d[key] = ["TORCH", str(np_cur_data.dtype), len(bs)] + list(np_cur_data.shape)
|
||||
elif isinstance(data[key], np.ndarray):
|
||||
bs = data[key].tobytes()
|
||||
fb.write(bs)
|
||||
d[key] = ["NUMPY", str(data[key].dtype), len(bs)] + list(data[key].shape)
|
||||
else:
|
||||
d[key] = data[key]
|
||||
try:
|
||||
_ = json.dumps(d)
|
||||
except TypeError:
|
||||
print(d)
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "w") as f:
|
||||
json.dump(d, f)
|
||||
bmt.synchronize()
|
||||
if self.tp_rank != 0:
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "r") as f:
|
||||
data = json.load(f)
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "rb") as fb:
|
||||
bs = fb.read()
|
||||
offset = 0
|
||||
for key in data.keys():
|
||||
if isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "NUMPY":
|
||||
nw_offset = offset + data[key][2]
|
||||
data[key] = np.frombuffer(bs[offset:nw_offset], dtype=data[key][1]).reshape(
|
||||
data[key][3:]
|
||||
)
|
||||
offset = nw_offset
|
||||
elif isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "TORCH":
|
||||
nw_offset = offset + data[key][2]
|
||||
data[key] = torch.from_numpy(
|
||||
np.frombuffer(bs[offset:nw_offset], dtype=data[key][1])
|
||||
.reshape(data[key][3:])
|
||||
.copy()
|
||||
)
|
||||
offset = nw_offset
|
||||
self.data = data
|
||||
else:
|
||||
self.data = next(self.loader)
|
||||
except StopIteration:
|
||||
self.data = None
|
||||
return
|
||||
with torch.cuda.stream(self.stream):
|
||||
for key in self.data.keys():
|
||||
if isinstance(self.data[key], torch.Tensor):
|
||||
self.data[key] = self.data[key].cuda(non_blocking=True)
|
||||
|
||||
def __next__(self):
|
||||
torch.cuda.current_stream().wait_stream(self.stream)
|
||||
for key in self.data.keys():
|
||||
if isinstance(self.data[key], torch.Tensor):
|
||||
self.data[key].record_stream(torch.cuda.current_stream())
|
||||
data = copy.deepcopy(self.data)
|
||||
self.preload()
|
||||
return data
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
|
@ -0,0 +1,828 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
||||
#
|
||||
# @author: ouzebin <ouzebin@zhihu.com>
|
||||
# @date: 2023/09/27
|
||||
|
||||
|
||||
import copy
|
||||
import ctypes
|
||||
import functools
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from collections import OrderedDict
|
||||
from multiprocessing import Lock
|
||||
from multiprocessing import Process
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import Iterable
|
||||
from typing import Iterator
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Set
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import bmtrain as bmt
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from fm9g.dataset import PrefetchDecodeDataset
|
||||
from fm9g.utils.bitset import BitSet
|
||||
from fm9g.utils.vdc_sampling import van_der_corput
|
||||
from fm9g.utils.vdc_sampling import van_der_corput_sampling_gen
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
IGNORE_TGT = -100
|
||||
|
||||
|
||||
def load_dataset_cfgs(cfg_path, cfg_json_str=None):
|
||||
if cfg_json_str is not None:
|
||||
cfgs = json.loads(cfg_json_str)
|
||||
else:
|
||||
with open(cfg_path, "r", encoding="utf-8") as fin:
|
||||
cfgs = json.load(fin)
|
||||
transform_basedir = os.path.dirname(os.path.abspath(cfg_path))
|
||||
|
||||
path_dict = None
|
||||
platform_config_path = os.getenv("PLATFORM_CONFIG_PATH")
|
||||
try:
|
||||
with open(platform_config_path, "r") as f:
|
||||
platform_cfg = json.load(f)
|
||||
path_dict = platform_cfg["dataset_map"]
|
||||
if bmt.rank() == 0:
|
||||
logger.info(f"Loaded jeeves platform config from '{platform_config_path}', update dataset paths...")
|
||||
except Exception as e:
|
||||
if bmt.rank() == 0:
|
||||
logger.info(f"Failing to load jeeves platform config '{platform_config_path}', error message:\n{str(e)}")
|
||||
|
||||
task_name2dataset_name = dict()
|
||||
for idx, cfg in enumerate(cfgs):
|
||||
assert "dataset_name" in cfg and isinstance(cfg["dataset_name"], str)
|
||||
assert "task_name" in cfg and isinstance(cfg["task_name"], str)
|
||||
# to be delibrately annoying :)
|
||||
if cfg["task_name"] in task_name2dataset_name:
|
||||
raise ValueError(
|
||||
f"task_name '{cfg['task_name']}' in dataset '{cfg['dataset_name']}'"
|
||||
f"has already been used in '{task_name2dataset_name[cfg['task_name']]}'."
|
||||
)
|
||||
task_name2dataset_name[cfg["task_name"]] = cfg["dataset_name"]
|
||||
|
||||
assert "path" in cfg and isinstance(cfg["path"], str)
|
||||
# if path_dict is not None:
|
||||
# cfg["path"] = os.path.join(path_dict[cfg["dataset_name"]], cfg["path"])
|
||||
|
||||
# dealing with optional configs
|
||||
if "weight" in cfg:
|
||||
assert isinstance(cfg["weight"], (float, int))
|
||||
else:
|
||||
cfg["weight"] = 1.0
|
||||
|
||||
if "oversize_rule" in cfg:
|
||||
assert cfg["oversize_rule"] in ("drop", "head", "segment")
|
||||
else:
|
||||
cfg["oversize_rule"] = "segment"
|
||||
|
||||
if "transforms" in cfg:
|
||||
assert isinstance(cfg["transforms"], str)
|
||||
# dealing with relative path
|
||||
if not cfg["transforms"].startswith("/"):
|
||||
cfg["transforms"] = os.path.join(transform_basedir, cfg["transforms"])
|
||||
if not cfg["transforms"]:
|
||||
cfg["transforms"] = None
|
||||
else:
|
||||
cfg["transforms"] = None
|
||||
|
||||
if "incontext_weight" in cfg:
|
||||
assert isinstance(cfg["incontext_weight"], (list, tuple))
|
||||
else:
|
||||
cfg["incontext_weight"] = [1.0]
|
||||
cfg["id"] = idx
|
||||
# dataset and iterator will be built
|
||||
return cfgs
|
||||
|
||||
|
||||
def data2ids(data, tokenizer, max_length):
|
||||
text = "\n".join(
|
||||
[
|
||||
data.get("title", "").strip(),
|
||||
data.get("question", "").strip(),
|
||||
data.get("answer", "").strip(),
|
||||
data.get("abstract", "").strip(),
|
||||
data.get("text", "").strip(),
|
||||
data.get("code", "").strip(),
|
||||
]
|
||||
).strip()
|
||||
|
||||
if not text:
|
||||
logger.warning(f"Warning: skip invalid sample without valid fields: {data}")
|
||||
yield from ()
|
||||
return
|
||||
# suppress the annoying warning from tokenizer
|
||||
ids = (
|
||||
[tokenizer.bos_id]
|
||||
+ tokenizer.encode(text, max_length=int(1e12), truncation=True)
|
||||
+ [tokenizer.eos_id]
|
||||
)
|
||||
src_ids = ids[0:-1]
|
||||
tgt_ids = ids[0:-1] # do not shift because it'll be shifted during loss calculation.
|
||||
|
||||
if len(src_ids) > max_length:
|
||||
for st in range(0, len(src_ids), max_length):
|
||||
yield src_ids[st : st + max_length], tgt_ids[st : st + max_length]
|
||||
else:
|
||||
yield src_ids, tgt_ids
|
||||
|
||||
|
||||
def cricket_data2ids(data, tokenizer, max_length: int, oversize_rule="segment", do_compact=False):
|
||||
assert oversize_rule in ("drop", "head", "segment")
|
||||
if data is None:
|
||||
yield from ()
|
||||
return
|
||||
if "output" not in data or not data["output"]:
|
||||
yield from ()
|
||||
return
|
||||
if "input" not in data or data["input"] is None:
|
||||
data["input"] = ""
|
||||
|
||||
src_ids = [tokenizer.bos_id]
|
||||
tgt_ids = []
|
||||
has_input = False
|
||||
is_segment_reenter = False
|
||||
|
||||
# Use incremental tokenization to avoid waiting for a long document
|
||||
MAX_CHUNK_LENGTH = max_length * 10
|
||||
for part in ("input", "output"):
|
||||
l, r = 0, min(MAX_CHUNK_LENGTH, len(data[part]))
|
||||
while l < len(data[part]):
|
||||
try:
|
||||
current_slice = data[part][l:r]
|
||||
if not current_slice:
|
||||
break
|
||||
#token_ids = tokenizer.encode(current_slice, add_special_tokens=False)
|
||||
token_ids = tokenizer.encode(current_slice)
|
||||
except:
|
||||
#print("Error in data[part][l:r] {}".format(data))
|
||||
yield from ()
|
||||
return
|
||||
|
||||
if part == "input":
|
||||
if len(token_ids) > 0:
|
||||
has_input = True
|
||||
if len(token_ids) >= max_length - 2: # input len must < max_length
|
||||
yield from ()
|
||||
return
|
||||
src_ids.extend(token_ids)
|
||||
tgt_ids.extend([IGNORE_TGT] * len(token_ids))
|
||||
l = r
|
||||
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
|
||||
else:
|
||||
if len(token_ids) + len(tgt_ids) >= max_length:
|
||||
if oversize_rule == "drop":
|
||||
yield from ()
|
||||
return
|
||||
elif oversize_rule == "head":
|
||||
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||
src_ids.extend(selected_token_ids[:-1])
|
||||
tgt_ids.extend(selected_token_ids)
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||
return
|
||||
elif oversize_rule == "segment":
|
||||
instruction_rest_space = max_length - 1 - len(token_ids)
|
||||
if has_input: # is instruction data
|
||||
if (
|
||||
do_compact
|
||||
and len(src_ids) >= 128 # avoid too short instruction info lost
|
||||
and instruction_rest_space / len(src_ids) > 0.8
|
||||
): # can be squeezed into max length
|
||||
inputs_len = len(src_ids)
|
||||
keep_len = instruction_rest_space // 2
|
||||
src_ids = src_ids[:keep_len] + src_ids[inputs_len - keep_len :]
|
||||
tgt_ids = [IGNORE_TGT] * (len(src_ids) - 1)
|
||||
src_ids.extend(token_ids)
|
||||
tgt_ids.extend(token_ids)
|
||||
tgt_ids.append(tokenizer.eos_id)
|
||||
assert len(src_ids) < max_length, f"len src_ids: {len(src_ids)}"
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids, tgt_ids
|
||||
else: # else use head rule
|
||||
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||
src_ids.extend(selected_token_ids[:-1])
|
||||
tgt_ids.extend(selected_token_ids)
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||
return
|
||||
else: # normal segment
|
||||
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||
src_ids.extend(selected_token_ids)
|
||||
tgt_ids.extend(selected_token_ids)
|
||||
assert len(src_ids) == max_length + 1, f"len src_ids: {len(src_ids)}"
|
||||
assert len(tgt_ids) == max_length, f"len tgt_ids: {len(tgt_ids)}"
|
||||
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||
src_ids = src_ids[max_length:]
|
||||
tgt_ids = tgt_ids[max_length:]
|
||||
# sliding input str window
|
||||
consumed_str = tokenizer.decode(selected_token_ids)
|
||||
l += len(consumed_str)
|
||||
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
|
||||
is_segment_reenter = True
|
||||
else:
|
||||
if (is_segment_reenter and len(token_ids) > 8) or (
|
||||
not is_segment_reenter and len(token_ids) > 0
|
||||
): # is segmented LM data
|
||||
src_ids.extend(token_ids)
|
||||
tgt_ids.extend(token_ids)
|
||||
tgt_ids.append(tokenizer.eos_id)
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids, tgt_ids
|
||||
else:
|
||||
yield from ()
|
||||
return
|
||||
|
||||
|
||||
class SegmentedDataset(torch.utils.data.IterableDataset):
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
tokenizer,
|
||||
max_length=1024,
|
||||
transform_func=None,
|
||||
nthreads=1,
|
||||
prefetch_slice=3,
|
||||
slice_size=500,
|
||||
do_compact=False,
|
||||
):
|
||||
super(SegmentedDataset, self).__init__()
|
||||
self.segment = functools.partial(
|
||||
cricket_data2ids, tokenizer=tokenizer, max_length=max_length, do_compact=do_compact
|
||||
)
|
||||
self.cfg = cfg
|
||||
self.max_length = max_length
|
||||
self.nthreads = nthreads
|
||||
self.transform_func = transform_func
|
||||
self.prefetch_slice = prefetch_slice
|
||||
self.slice_size = slice_size
|
||||
self.abs_weight = cfg.get("abs_weight", None)
|
||||
self.task_name = cfg["task_name"]
|
||||
self.dataset_name = cfg["dataset_name"]
|
||||
self.oversize_rule = cfg["oversize_rule"]
|
||||
self.dataset = PrefetchDecodeDataset(path=cfg["path"], allow_repeat=cfg.get("allow_repeat", True))
|
||||
self.exhausted = False
|
||||
self.iterator = None
|
||||
|
||||
self.counter = 0
|
||||
self.allow_repeat = cfg.get("allow_repeat", True)
|
||||
self.used = BitSet()
|
||||
self.init_ave_tokens()
|
||||
|
||||
def init_ave_tokens(
|
||||
self,
|
||||
):
|
||||
try:
|
||||
shm = SharedMemory(name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
|
||||
except FileNotFoundError:
|
||||
bmt.print_rank(
|
||||
"Create Shared Memory {}".format(f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
|
||||
)
|
||||
shm = SharedMemory(
|
||||
create=True,
|
||||
size=ctypes.sizeof(ctypes.c_float),
|
||||
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}',
|
||||
)
|
||||
|
||||
# 使用共享内存
|
||||
shared_value = ctypes.c_float.from_buffer(shm.buf)
|
||||
_ave_tokens = self.cfg.get(
|
||||
"avg_tokens", self.cfg.get("ave_tokens", self.cfg.get("ave_tokens_per_line", -1))
|
||||
)
|
||||
|
||||
if _ave_tokens > self.max_length:
|
||||
_ave_tokens = self.max_length
|
||||
bmt.print_rank(
|
||||
"Warning: avg_tokens {} is larger than max_length {}, set to max_length".format(
|
||||
_ave_tokens, self.max_length
|
||||
)
|
||||
)
|
||||
|
||||
shared_value.value = _ave_tokens
|
||||
# 不再需要 shared_value 时,删除引用
|
||||
del shared_value
|
||||
|
||||
# 现在可以安全地关闭共享内存
|
||||
shm.close()
|
||||
bmt.print_rank("Init ave_tokens for task {}: {}".format(self.task_name, self.ave_tokens))
|
||||
|
||||
@property
|
||||
def ave_tokens(
|
||||
self,
|
||||
):
|
||||
existing_shm = SharedMemory(
|
||||
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}'
|
||||
) # -1 # default length
|
||||
shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
|
||||
tmp = shared_value.value
|
||||
del shared_value
|
||||
existing_shm.close()
|
||||
return tmp
|
||||
|
||||
def ave_tokens_update(self, length):
|
||||
existing_shm = SharedMemory(
|
||||
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}'
|
||||
) # -1 # default length
|
||||
shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
|
||||
if shared_value.value < 0:
|
||||
shared_value.value = float(length)
|
||||
else:
|
||||
shared_value.value = 0.98 * shared_value.value + 0.02 * length
|
||||
del shared_value
|
||||
existing_shm.close()
|
||||
|
||||
def size(self):
|
||||
return self.dataset.size()
|
||||
|
||||
def __iter__(self):
|
||||
self.iterate()
|
||||
return self
|
||||
|
||||
def reset(self):
|
||||
self.exhausted = False
|
||||
if self.iterator is not None:
|
||||
self.iterator.close()
|
||||
self.iterator = None
|
||||
self.used = BitSet()
|
||||
print("Rank {}, Reset dataset:{} done.".format(bmt.rank(), self.dataset_name))
|
||||
|
||||
def transform(self, data: dict) -> dict:
|
||||
weight = np.array(self.cfg["incontext_weight"], dtype=np.float32)
|
||||
weight = weight / weight.sum()
|
||||
num_incontext = np.random.choice(weight.shape[0], p=weight)
|
||||
return self.transform_func(data, num_incontext, random.Random())
|
||||
|
||||
def segment_iterate(self, sample_iter):
|
||||
for index, data in self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used):
|
||||
for src_ids, tgt_ids in self.segment(self.transform(data)):
|
||||
self.ave_tokens_update(len(src_ids)) # 0 for input ids
|
||||
yield src_ids, tgt_ids, index
|
||||
|
||||
def iterate(self):
|
||||
# make the dataset itself an iterator
|
||||
sample_iter = self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used)
|
||||
self.iterator = self.segment_iterate(sample_iter)
|
||||
|
||||
def __next__(self):
|
||||
# advance the task iterator
|
||||
if self.iterator is None:
|
||||
self.iterate()
|
||||
try:
|
||||
return next(self.iterator)
|
||||
except StopIteration:
|
||||
self.exhausted = True
|
||||
return None
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
if state_dict.get("exhausted", False):
|
||||
self.exhausted = True
|
||||
self.used = BitSet()
|
||||
else:
|
||||
used = state_dict.get("used", BitSet())
|
||||
if len(used) == len(self.dataset):
|
||||
self.exhausted = True
|
||||
self.used = BitSet()
|
||||
else:
|
||||
self.exhausted = False
|
||||
self.used = used
|
||||
self.ave_tokens_update(state_dict.get("ave_tokens", -1))
|
||||
|
||||
def state_dict(self):
|
||||
if len(self.used) == len(self.dataset):
|
||||
return dict(exhausted=True, used=BitSet(), ave_tokens=self.ave_tokens)
|
||||
else:
|
||||
return dict(exhausted=False, used=self.used, ave_tokens=self.ave_tokens)
|
||||
|
||||
def update_state(self, indice):
|
||||
self.used.update(indice)
|
||||
|
||||
|
||||
class MixedIndexedDataset(torch.utils.data.IterableDataset):
|
||||
def __init__(
|
||||
self,
|
||||
cfg_path: str,
|
||||
cfg_json_str,
|
||||
tokenizer,
|
||||
max_length,
|
||||
weight_by_size: bool = True,
|
||||
nthreads=5,
|
||||
prefetch_slice=100,
|
||||
parallel_loading=False,
|
||||
vdc_sampling=False,
|
||||
update_weights_frequency=1,
|
||||
seed=42,
|
||||
):
|
||||
super(MixedIndexedDataset, self).__init__()
|
||||
self.set_seed(seed + bmt.rank())
|
||||
self.weight_by_size = weight_by_size
|
||||
self.tokenizer = tokenizer
|
||||
self.eos_token_id = self.tokenizer.eos_id
|
||||
self.bos_token_id = self.tokenizer.bos_id
|
||||
self.path2transform = dict()
|
||||
self.task_dict = OrderedDict()
|
||||
self.nthreads = nthreads
|
||||
self.prefetch_slice = prefetch_slice
|
||||
# useful for indexing
|
||||
self.tasks = []
|
||||
self.names = []
|
||||
# ending of iteration
|
||||
self.remain = 0
|
||||
self.max_length = max_length
|
||||
self.vdc_sampling = vdc_sampling
|
||||
if self.vdc_sampling:
|
||||
self._vdc_values = [van_der_corput(i) for i in range(10**6)] # 精度提高 10^{-6}
|
||||
self.vdc_gen = van_der_corput_sampling_gen(self._vdc_values)
|
||||
|
||||
self.update_weights_frequency = update_weights_frequency
|
||||
|
||||
self.path2transform = dict()
|
||||
|
||||
cfgs = load_dataset_cfgs(cfg_path, cfg_json_str)
|
||||
_sum_weight = sum([cfg["abs_weight"] for cfg in cfgs])
|
||||
_weights = {cfg["task_name"]: cfg["abs_weight"] / _sum_weight for cfg in cfgs}
|
||||
bmt.print_rank("Absolute Weight of DataSet {}".format(_weights))
|
||||
|
||||
if parallel_loading:
|
||||
self.parallel_load(cfgs, max_workers=None)
|
||||
else:
|
||||
self.sequential_load(cfgs)
|
||||
|
||||
self.weights = None
|
||||
self.update_weights()
|
||||
|
||||
def set_seed(self, seed):
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
def load_task(self, cfg):
|
||||
logger.info(f"Loading {cfg['path']}")
|
||||
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
|
||||
task = SegmentedDataset(
|
||||
cfg,
|
||||
self.tokenizer,
|
||||
self.max_length,
|
||||
transform_func=transform_func,
|
||||
nthreads=self.nthreads,
|
||||
prefetch_slice=self.prefetch_slice,
|
||||
do_compact=cfg.get("do_compact", False), # dataset level do_compact
|
||||
)
|
||||
return task
|
||||
|
||||
def sequential_load(self, cfgs):
|
||||
self.cfgs = cfgs
|
||||
for cfg in cfgs:
|
||||
# python3.7 and later preserves insertion order to dictionary
|
||||
logger.info(f"Loading {cfg['path']}")
|
||||
|
||||
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
|
||||
task = SegmentedDataset(
|
||||
cfg,
|
||||
self.tokenizer,
|
||||
self.max_length,
|
||||
transform_func=transform_func,
|
||||
nthreads=self.nthreads,
|
||||
prefetch_slice=self.prefetch_slice,
|
||||
do_compact=cfg.get("do_compact", False), # dataset level do_compact
|
||||
)
|
||||
self.task_dict[task.task_name] = task
|
||||
self.tasks.append(task)
|
||||
self.names.append(task.task_name)
|
||||
self.remain += 1
|
||||
self.weights = None
|
||||
self.update_weights()
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
missing_keys = []
|
||||
for name, task in self.task_dict.items():
|
||||
if name in state_dict:
|
||||
task.load_state_dict(state_dict[name])
|
||||
else:
|
||||
missing_keys.append(name)
|
||||
self.update_weights()
|
||||
return missing_keys
|
||||
|
||||
def save_state_dict(self, path):
|
||||
state_dict = {}
|
||||
for name, task in self.task_dict.items():
|
||||
_state_dict = task.state_dict()
|
||||
if isinstance(_state_dict["used"], BitSet):
|
||||
bitset = _state_dict["used"]
|
||||
_file_name = bitset.save(path)
|
||||
_state_dict["used"] = _file_name
|
||||
state_dict[name] = _state_dict
|
||||
else:
|
||||
state_dict[name] = task.state_dict()
|
||||
torch.save(state_dict, path)
|
||||
logger.info("Dataset state saved")
|
||||
|
||||
def update_states(self, task_ids, indice):
|
||||
is_dict = isinstance(indice, dict)
|
||||
uniq = torch.unique(task_ids)
|
||||
for idx in uniq:
|
||||
idx = idx.item()
|
||||
indexes = indice[idx] if is_dict else indice[task_ids == idx].tolist()
|
||||
self.tasks[idx].update_state(indexes)
|
||||
|
||||
def get_transform_func(self, module_name: str, transform_script_path):
|
||||
if transform_script_path is None:
|
||||
# allow null transform
|
||||
return lambda data, num_incontext, rand: data
|
||||
module_name = "fm9g_live.transforms.{}".format(module_name)
|
||||
if transform_script_path not in self.path2transform:
|
||||
loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
|
||||
spec = importlib.util.spec_from_loader(loader.name, loader)
|
||||
if spec is None:
|
||||
raise RuntimeError("Spec is none! {}".format(module_name))
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
self.path2transform[transform_script_path] = {
|
||||
"loader": loader,
|
||||
"module": mod,
|
||||
"last_mtime": 0,
|
||||
}
|
||||
transform_script_info = self.path2transform[transform_script_path]
|
||||
curr_mtime = float(transform_script_info["loader"].path_stats(transform_script_path)["mtime"])
|
||||
if curr_mtime > transform_script_info["last_mtime"]:
|
||||
transform_script_info["last_mtime"] = curr_mtime
|
||||
transform_script_info["loader"].exec_module(transform_script_info["module"])
|
||||
transform_func = getattr(transform_script_info["module"], "transform", None)
|
||||
if transform_func is None:
|
||||
raise NotImplementedError("Find no transform funcion in script '{}'".format(transform_script_path))
|
||||
return transform_func
|
||||
|
||||
def update_weights(self):
|
||||
task0 = self.tasks[0]
|
||||
if task0.abs_weight is not None: # 这一份config是指定绝对比例的
|
||||
weights = []
|
||||
for task in self.tasks:
|
||||
if task.exhausted:
|
||||
weights.append(0)
|
||||
else:
|
||||
if task.ave_tokens == -1:
|
||||
weights.append(task.abs_weight / self.max_length)
|
||||
else:
|
||||
weights.append(task.abs_weight / task.ave_tokens)
|
||||
weights = np.array(weights)
|
||||
else:
|
||||
weights = np.array([0 if task.exhausted else task.weight for task in self.tasks])
|
||||
if self.weight_by_size:
|
||||
sizes = np.array([task.size() for task in self.tasks], dtype=np.float32)
|
||||
weights *= sizes
|
||||
self.weights = weights / weights.sum()
|
||||
|
||||
def __iter__(self):
|
||||
for task in self.tasks:
|
||||
task.iterate()
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
step = 1
|
||||
while True:
|
||||
if self.remain == 0:
|
||||
print("Rank {}, All task exhaust !!!!".format(bmt.rank()))
|
||||
raise StopIteration
|
||||
if self.vdc_sampling:
|
||||
idx = next(self.vdc_gen)(self.weights)
|
||||
else:
|
||||
idx = np.random.choice(len(self.weights), p=self.weights)
|
||||
|
||||
data = next(self.tasks[idx])
|
||||
if step % self.update_weights_frequency == 0:
|
||||
self.update_weights()
|
||||
if data is None:
|
||||
if self.tasks[idx].allow_repeat:
|
||||
# _runtime_ave = self.tasks[idx].ave_tokens
|
||||
print("Rank {}, dataset {} exhaust, repeat...".format(bmt.rank(), self.tasks[idx].dataset_name))
|
||||
# self.tasks[idx] = SegmentedDataset(
|
||||
# self.tasks[idx].cfg, self.tokenizer, self.max_length, transform_func=self.tasks[idx].transform_func, nthreads=self.nthreads, prefetch_slice=self.prefetch_slice
|
||||
# )
|
||||
# self.tasks[idx].ave_tokens_update(_runtime_ave)
|
||||
self.tasks[idx].reset()
|
||||
else:
|
||||
print("Rank {}, dataset {} exhaust, not repeat.".format(bmt.rank(), self.tasks[idx].dataset_name))
|
||||
self.tasks[idx].exhaust = True
|
||||
self.remain -= 1
|
||||
continue
|
||||
step += 1
|
||||
|
||||
return dict(
|
||||
task_id=idx,
|
||||
input=data[0],
|
||||
target=data[1],
|
||||
index=data[2],
|
||||
is_long=self.tasks[idx].cfg.get("is_long", False),
|
||||
)
|
||||
|
||||
|
||||
class UnpadBatchedMixedDataset(torch.utils.data.IterableDataset):
|
||||
def __init__(self, mixed_dataset, batch_size, max_length, pose_prob=0.0, pose_scaling_factor=1.0, compact=False):
|
||||
self.max_total_length = batch_size * max_length
|
||||
self.batch_size = 1
|
||||
# setting compact=True concats segments orignated from the same input
|
||||
# into a long sequence. the relative order of segments should be preserved
|
||||
# in mixed_dataset, e.g.,
|
||||
# - ok: task1_seg1, task2_seg1, task1_seg2, task1_seg3
|
||||
# - not_ok: task1_seg1, task1_seg3, task2_seg1, task1_seg2
|
||||
self.compact = compact
|
||||
|
||||
self.total_length = 0
|
||||
self.task2seqs = defaultdict(list)
|
||||
self.mixed_dataset = mixed_dataset
|
||||
self._max_length = max_length
|
||||
self._pose_prob = pose_prob
|
||||
self._pose_scaling_factor = pose_scaling_factor
|
||||
if self._pose_prob > 0.0:
|
||||
self._scaled_max_length = int(self.max_total_length * self._pose_scaling_factor)
|
||||
else:
|
||||
self._scaled_max_length = max_length
|
||||
|
||||
def put(self, sample):
|
||||
self.total_length += len(sample["target"])
|
||||
task_id = sample["task_id"]
|
||||
if self.compact and self.task2seqs[task_id]:
|
||||
last = self.task2seqs[task_id][-1]
|
||||
if last["target"][-1] != self.mixed_dataset.eos_token_id:
|
||||
# concatenate sequantial segments for longer context modeling: why not?
|
||||
last["input"].extend(sample["input"])
|
||||
last["target"].extend(sample["target"])
|
||||
return
|
||||
self.task2seqs[task_id].append(sample)
|
||||
|
||||
def _pose_preprocess(
|
||||
self,
|
||||
input_ids: NDArray[np.int32],
|
||||
) -> NDArray[np.int32]:
|
||||
"""[PoSE](https://arxiv.org/abs/2309.10400v2)
|
||||
GitHub implementation: https://github.com/dwzhu-pku/PoSE/blob/master/src/train_pose.py#L156
|
||||
"""
|
||||
len_chunk = min(len(input_ids), self._max_length)
|
||||
len_input = len(input_ids)
|
||||
# Chunk input randomly to fit max_length if needed
|
||||
lt1 = 0
|
||||
rt1 = random.randint(0, (len_chunk + 1) // 2) # Fist chunk only contains 1/2 tokens at most
|
||||
rt2 = random.randint(lt1 + len_chunk, len_input) # Second chunk can randomly shift when not filled max_length
|
||||
lt2 = rt2 - (len_chunk - (rt1 - lt1)) # assure all tokens are used
|
||||
chunked_input_ids = np.concatenate([input_ids[lt1:rt1], input_ids[lt2:rt2]], axis=-1)
|
||||
# Generate PoSE position ids
|
||||
position_ids = np.arange(len(chunked_input_ids), dtype=np.int32)
|
||||
len_position_ids = len(position_ids)
|
||||
lt = 0
|
||||
rt = random.randint(lt, self._scaled_max_length - len_position_ids)
|
||||
position_ids[: rt1 - lt1] += lt
|
||||
position_ids[rt1 - lt1 :] += rt
|
||||
return position_ids
|
||||
|
||||
def pop(self):
|
||||
indexes = defaultdict(list)
|
||||
lengths = []
|
||||
|
||||
inputs = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
|
||||
targets = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=IGNORE_TGT)
|
||||
task_ids = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=-1)
|
||||
position_ids = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
|
||||
|
||||
span_begin = 0
|
||||
for samples in self.task2seqs.values():
|
||||
while samples:
|
||||
sample = samples.pop()
|
||||
span_end = span_begin + len(sample["input"])
|
||||
inputs[0, span_begin:span_end] = torch.tensor(sample["input"], dtype=torch.int32)
|
||||
targets[0, span_begin:span_end] = torch.tensor(sample["target"], dtype=torch.int32)
|
||||
task_ids[0, span_begin:span_end] = torch.tensor(sample["task_id"], dtype=torch.int32)
|
||||
if not sample["is_long"] and self._pose_prob > 0.0 and random.uniform(0, 1) < self._pose_prob:
|
||||
_span_position_ids = self._pose_preprocess(sample["input"])
|
||||
else:
|
||||
_span_position_ids = np.arange(len(sample["input"]), dtype=np.int32)
|
||||
position_ids[0, span_begin:span_end] = torch.from_numpy(_span_position_ids)
|
||||
# position_ids[0, span_begin:span_end] = torch.arange(len(sample["input"]), dtype=torch.int32)
|
||||
lengths.append(len(sample["target"]))
|
||||
indexes[int(sample["task_id"])].append(sample["index"])
|
||||
self.total_length -= len(sample["target"])
|
||||
span_begin = span_end
|
||||
|
||||
cu_seqlens = torch.cat(
|
||||
[torch.tensor([0] + lengths).cumsum(dim=-1), torch.tensor([self.max_total_length], dtype=torch.int32)],
|
||||
dim=0,
|
||||
).int()
|
||||
batch = {
|
||||
"inputs": inputs,
|
||||
"targets": targets,
|
||||
"task_ids": task_ids,
|
||||
"indexes": indexes,
|
||||
# adhere to flash attention interface
|
||||
"cu_seqlens": cu_seqlens,
|
||||
"max_seqlen": int(torch.max(cu_seqlens[1:] - cu_seqlens[:-1])),
|
||||
"lengths": torch.tensor(sum(lengths)).int(),
|
||||
"task_names": self.mixed_dataset.names,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
return batch
|
||||
|
||||
def will_be_full(self, sample):
|
||||
return self.total_length + len(sample["target"]) > self.max_total_length
|
||||
|
||||
def __iter__(self):
|
||||
for sample in self.mixed_dataset:
|
||||
if self.will_be_full(sample):
|
||||
yield self.pop()
|
||||
self.put(sample)
|
||||
|
||||
|
||||
class CudaPrefetcher(Iterable):
|
||||
"""
|
||||
Wrap around a batch iterator for asynchornously copying data to gpu to shield memcpy latency.
|
||||
"""
|
||||
|
||||
def __init__(self, loader, tp_size=1, tp_rank=0):
|
||||
self.loader = iter(loader)
|
||||
self.tp_size = tp_size
|
||||
self.tp_rank = tp_rank
|
||||
self.stream = torch.cuda.Stream()
|
||||
self.preload()
|
||||
|
||||
def preload(self):
|
||||
try:
|
||||
if self.tp_size > 1:
|
||||
if self.tp_rank == 0:
|
||||
data = next(self.loader)
|
||||
print("Rank {}, Preload data done.".format(bmt.rank()))
|
||||
d = {}
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "wb") as fb:
|
||||
for key in data.keys():
|
||||
if isinstance(data[key], torch.Tensor):
|
||||
np_cur_data = data[key].cpu().numpy()
|
||||
bs = np_cur_data.tobytes()
|
||||
fb.write(bs)
|
||||
d[key] = ["TORCH", str(np_cur_data.dtype), len(bs)] + list(np_cur_data.shape)
|
||||
elif isinstance(data[key], np.ndarray):
|
||||
bs = data[key].tobytes()
|
||||
fb.write(bs)
|
||||
d[key] = ["NUMPY", str(data[key].dtype), len(bs)] + list(data[key].shape)
|
||||
else:
|
||||
d[key] = data[key]
|
||||
try:
|
||||
_ = json.dumps(d)
|
||||
except TypeError:
|
||||
print(d)
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "w") as f:
|
||||
json.dump(d, f)
|
||||
bmt.synchronize()
|
||||
if self.tp_rank != 0:
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "r") as f:
|
||||
data = json.load(f)
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "rb") as fb:
|
||||
bs = fb.read()
|
||||
offset = 0
|
||||
for key in data.keys():
|
||||
if isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "NUMPY":
|
||||
nw_offset = offset + data[key][2]
|
||||
data[key] = np.frombuffer(bs[offset:nw_offset], dtype=data[key][1]).reshape(
|
||||
data[key][3:]
|
||||
)
|
||||
offset = nw_offset
|
||||
elif isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "TORCH":
|
||||
nw_offset = offset + data[key][2]
|
||||
data[key] = torch.from_numpy(
|
||||
np.frombuffer(bs[offset:nw_offset], dtype=data[key][1])
|
||||
.reshape(data[key][3:])
|
||||
.copy()
|
||||
)
|
||||
offset = nw_offset
|
||||
self.data = data
|
||||
else:
|
||||
self.data = next(self.loader)
|
||||
except StopIteration:
|
||||
self.data = None
|
||||
return
|
||||
with torch.cuda.stream(self.stream):
|
||||
for key in self.data.keys():
|
||||
if isinstance(self.data[key], torch.Tensor):
|
||||
self.data[key] = self.data[key].cuda(non_blocking=True)
|
||||
|
||||
def __next__(self):
|
||||
torch.cuda.current_stream().wait_stream(self.stream)
|
||||
for key in self.data.keys():
|
||||
if isinstance(self.data[key], torch.Tensor):
|
||||
self.data[key].record_stream(torch.cuda.current_stream())
|
||||
data = copy.deepcopy(self.data)
|
||||
self.preload()
|
||||
return data
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
|
@ -0,0 +1,827 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
||||
#
|
||||
# @author: ouzebin <ouzebin@zhihu.com>
|
||||
# @date: 2023/09/27
|
||||
|
||||
|
||||
import copy
|
||||
import ctypes
|
||||
import functools
|
||||
import importlib
|
||||
import importlib.util
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from collections import OrderedDict
|
||||
from multiprocessing import Lock
|
||||
from multiprocessing import Process
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import Iterable
|
||||
from typing import Iterator
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Set
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import bmtrain as bmt
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
from tenacity import retry
|
||||
from tenacity import stop_after_attempt
|
||||
from tenacity import wait_fixed
|
||||
from tenacity import wait_random
|
||||
|
||||
from fm9g.dataset import PrefetchDecodeDataset
|
||||
from fm9g.utils.bitset import BitSet
|
||||
from fm9g.utils.vdc_sampling import van_der_corput
|
||||
from fm9g.utils.vdc_sampling import van_der_corput_sampling_gen
|
||||
|
||||
from .flask_ps import app as flask_ps
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
IGNORE_TGT = -100
|
||||
|
||||
|
||||
def load_dataset_cfgs(cfg_path, cfg_json_str=None):
|
||||
if cfg_json_str is not None:
|
||||
cfgs = json.loads(cfg_json_str)
|
||||
else:
|
||||
with open(cfg_path, "r", encoding="utf-8") as fin:
|
||||
cfgs = json.load(fin)
|
||||
transform_basedir = os.path.dirname(os.path.abspath(cfg_path))
|
||||
|
||||
path_dict = None
|
||||
platform_config_path = os.getenv("PLATFORM_CONFIG_PATH")
|
||||
try:
|
||||
with open(platform_config_path, "r") as f:
|
||||
platform_cfg = json.load(f)
|
||||
path_dict = platform_cfg["dataset_map"]
|
||||
if bmt.rank() == 0:
|
||||
logger.info(f"Loaded jeeves platform config from '{platform_config_path}', update dataset paths...")
|
||||
except Exception as e:
|
||||
if bmt.rank() == 0:
|
||||
logger.info(f"Failing to load jeeves platform config '{platform_config_path}', error message:\n{str(e)}")
|
||||
|
||||
task_name2dataset_name = dict()
|
||||
for idx, cfg in enumerate(cfgs):
|
||||
assert "dataset_name" in cfg and isinstance(cfg["dataset_name"], str)
|
||||
assert "task_name" in cfg and isinstance(cfg["task_name"], str)
|
||||
# to be delibrately annoying :)
|
||||
if cfg["task_name"] in task_name2dataset_name:
|
||||
raise ValueError(
|
||||
f"task_name '{cfg['task_name']}' in dataset '{cfg['dataset_name']}'"
|
||||
f"has already been used in '{task_name2dataset_name[cfg['task_name']]}'."
|
||||
)
|
||||
task_name2dataset_name[cfg["task_name"]] = cfg["dataset_name"]
|
||||
|
||||
assert "path" in cfg and isinstance(cfg["path"], str)
|
||||
# if path_dict is not None:
|
||||
# cfg["path"] = os.path.join(path_dict[cfg["dataset_name"]], cfg["path"])
|
||||
|
||||
# dealing with optional configs
|
||||
if "weight" in cfg:
|
||||
assert isinstance(cfg["weight"], (float, int))
|
||||
else:
|
||||
cfg["weight"] = 1.0
|
||||
|
||||
if "oversize_rule" in cfg:
|
||||
assert cfg["oversize_rule"] in ("drop", "head", "segment")
|
||||
else:
|
||||
cfg["oversize_rule"] = "segment"
|
||||
|
||||
if "transforms" in cfg:
|
||||
assert isinstance(cfg["transforms"], str)
|
||||
# dealing with relative path
|
||||
if not cfg["transforms"].startswith("/"):
|
||||
cfg["transforms"] = os.path.join(transform_basedir, cfg["transforms"])
|
||||
if not cfg["transforms"]:
|
||||
cfg["transforms"] = None
|
||||
else:
|
||||
cfg["transforms"] = None
|
||||
|
||||
if "incontext_weight" in cfg:
|
||||
assert isinstance(cfg["incontext_weight"], (list, tuple))
|
||||
else:
|
||||
cfg["incontext_weight"] = [1.0]
|
||||
cfg["id"] = idx
|
||||
# dataset and iterator will be built
|
||||
return cfgs
|
||||
|
||||
|
||||
def data2ids(data, tokenizer, max_length):
|
||||
text = "\n".join(
|
||||
[
|
||||
data.get("title", "").strip(),
|
||||
data.get("question", "").strip(),
|
||||
data.get("answer", "").strip(),
|
||||
data.get("abstract", "").strip(),
|
||||
data.get("text", "").strip(),
|
||||
data.get("code", "").strip(),
|
||||
]
|
||||
).strip()
|
||||
|
||||
if not text:
|
||||
logger.warning(f"Warning: skip invalid sample without valid fields: {data}")
|
||||
yield from ()
|
||||
return
|
||||
# suppress the annoying warning from tokenizer
|
||||
ids = (
|
||||
[tokenizer.bos_token_id]
|
||||
+ tokenizer.encode(text, max_length=int(1e12), truncation=True)
|
||||
+ [tokenizer.eos_token_id]
|
||||
)
|
||||
src_ids = ids[0:-1]
|
||||
tgt_ids = ids[0:-1] # do not shift because it'll be shifted during loss calculation.
|
||||
|
||||
if len(src_ids) > max_length:
|
||||
for st in range(0, len(src_ids), max_length):
|
||||
yield src_ids[st : st + max_length], tgt_ids[st : st + max_length]
|
||||
else:
|
||||
yield src_ids, tgt_ids
|
||||
|
||||
|
||||
def cricket_data2ids(data, tokenizer, max_length: int, oversize_rule="segment", do_compact=False):
|
||||
assert oversize_rule in ("drop", "head", "segment")
|
||||
if data is None:
|
||||
yield from ()
|
||||
return
|
||||
if "output" not in data or not data["output"]:
|
||||
yield from ()
|
||||
return
|
||||
if "input" not in data:
|
||||
data["input"] = ""
|
||||
|
||||
src_ids = [tokenizer.bos_token_id]
|
||||
tgt_ids = []
|
||||
has_input = False
|
||||
is_segment_reenter = False
|
||||
|
||||
# Use incremental tokenization to avoid waiting for a long document
|
||||
MAX_CHUNK_LENGTH = max_length * 10
|
||||
for part in ("input", "output"):
|
||||
l, r = 0, min(MAX_CHUNK_LENGTH, len(data[part]))
|
||||
while l < len(data[part]):
|
||||
current_slice = data[part][l:r]
|
||||
if not current_slice:
|
||||
break
|
||||
token_ids = tokenizer.encode(current_slice, add_special_tokens=False)
|
||||
if part == "input":
|
||||
if len(token_ids) > 0:
|
||||
has_input = True
|
||||
if len(token_ids) >= max_length - 2: # input len must < max_length
|
||||
yield from ()
|
||||
return
|
||||
src_ids.extend(token_ids)
|
||||
tgt_ids.extend([IGNORE_TGT] * len(token_ids))
|
||||
l = r
|
||||
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
|
||||
else:
|
||||
if len(token_ids) + len(tgt_ids) >= max_length:
|
||||
if oversize_rule == "drop":
|
||||
yield from ()
|
||||
return
|
||||
elif oversize_rule == "head":
|
||||
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||
src_ids.extend(selected_token_ids[:-1])
|
||||
tgt_ids.extend(selected_token_ids)
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||
return
|
||||
elif oversize_rule == "segment":
|
||||
instruction_rest_space = max_length - 1 - len(token_ids)
|
||||
if has_input: # is instruction data
|
||||
if (
|
||||
do_compact
|
||||
and len(src_ids) >= 128 # avoid too short instruction info lost
|
||||
and instruction_rest_space / len(src_ids) > 0.8
|
||||
): # can be squeezed into max length
|
||||
inputs_len = len(src_ids)
|
||||
keep_len = instruction_rest_space // 2
|
||||
src_ids = src_ids[:keep_len] + src_ids[inputs_len - keep_len :]
|
||||
tgt_ids = [IGNORE_TGT] * (len(src_ids) - 1)
|
||||
src_ids.extend(token_ids)
|
||||
tgt_ids.extend(token_ids)
|
||||
tgt_ids.append(tokenizer.eos_token_id)
|
||||
assert len(src_ids) < max_length, f"len src_ids: {len(src_ids)}"
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids, tgt_ids
|
||||
else: # else use head rule
|
||||
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||
src_ids.extend(selected_token_ids[:-1])
|
||||
tgt_ids.extend(selected_token_ids)
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||
return
|
||||
else: # normal segment
|
||||
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||
src_ids.extend(selected_token_ids)
|
||||
tgt_ids.extend(selected_token_ids)
|
||||
assert len(src_ids) == max_length + 1, f"len src_ids: {len(src_ids)}"
|
||||
assert len(tgt_ids) == max_length, f"len tgt_ids: {len(tgt_ids)}"
|
||||
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||
src_ids = src_ids[max_length:]
|
||||
tgt_ids = tgt_ids[max_length:]
|
||||
# sliding input str window
|
||||
consumed_str = tokenizer.decode(selected_token_ids)
|
||||
l += len(consumed_str)
|
||||
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
|
||||
is_segment_reenter = True
|
||||
else:
|
||||
if (is_segment_reenter and len(token_ids) > 8) or (
|
||||
not is_segment_reenter and len(token_ids) > 0
|
||||
): # is segmented LM data
|
||||
src_ids.extend(token_ids)
|
||||
tgt_ids.extend(token_ids)
|
||||
tgt_ids.append(tokenizer.eos_token_id)
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids, tgt_ids
|
||||
else:
|
||||
yield from ()
|
||||
return
|
||||
|
||||
|
||||
class SegmentedDataset(torch.utils.data.IterableDataset):
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
tokenizer,
|
||||
max_length=1024,
|
||||
transform_func=None,
|
||||
nthreads=1,
|
||||
prefetch_slice=3,
|
||||
slice_size=500,
|
||||
do_compact=False,
|
||||
is_local=False,
|
||||
):
|
||||
def get_full_qualified_name(func):
|
||||
module_name = func.__module__
|
||||
qual_name = func.__qualname__
|
||||
return f"{module_name}.{qual_name}"
|
||||
|
||||
super(SegmentedDataset, self).__init__()
|
||||
self.segment = functools.partial(
|
||||
cricket_data2ids, tokenizer=tokenizer, max_length=max_length, do_compact=do_compact
|
||||
)
|
||||
self.cfg = cfg
|
||||
self.max_length = max_length
|
||||
self.nthreads = nthreads
|
||||
self.transform_func = transform_func
|
||||
self.prefetch_slice = prefetch_slice
|
||||
self.slice_size = slice_size
|
||||
self.abs_weight = cfg.get("abs_weight", None)
|
||||
self.task_name = cfg["task_name"]
|
||||
self.dataset_name = cfg["dataset_name"]
|
||||
self.oversize_rule = cfg["oversize_rule"]
|
||||
self.dataset = PrefetchDecodeDataset(path=cfg["path"], allow_repeat=cfg.get("allow_repeat", False))
|
||||
self.exhausted = False
|
||||
self.iterator = None
|
||||
|
||||
self.counter = 0
|
||||
self.allow_repeat = cfg.get("allow_repeat", True)
|
||||
self.used = set()
|
||||
self.is_local = is_local
|
||||
self.port_offset = random.randint(1, 8000) + 1000
|
||||
if is_local or bmt.rank() == 0:
|
||||
addr = os.environ["MASTER_ADDR"]
|
||||
port = int(os.environ["MASTER_PORT"]) + self.port_offset
|
||||
_avg_tokens = cfg.get("ave_tokens_per_line", -1)
|
||||
_avg_tokens = cfg.get("avg_tokens", _avg_tokens)
|
||||
requests.post(f"http://{addr}:{port}/avg_tokens/{self.task_name}?action=set&length={_avg_tokens}")
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_random(5, 20))
|
||||
def set_avg_tokens(self, avg_tokens):
|
||||
addr = os.environ["MASTER_ADDR"]
|
||||
port = int(os.environ["MASTER_PORT"]) + self.port_offset
|
||||
url = f"http://{addr}:{port}/avg_tokens/{self.task_name}?action=set&length={avg_tokens}"
|
||||
response = requests.post(url)
|
||||
if response.status_code != 200:
|
||||
self.reset_port_offset()
|
||||
print(f"Failed to set avg_tokens for task {self.task_name}, request url: {url}")
|
||||
raise RuntimeError(f"Failed to set avg_tokens for task {self.task_name}, request url: {url}")
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_random(5, 20))
|
||||
def update_avg_tokens_by_ema(self, length):
|
||||
addr = os.environ["MASTER_ADDR"]
|
||||
port = int(os.environ["MASTER_PORT"]) + self.port_offset
|
||||
url = f"http://{addr}:{port}/avg_tokens/{self.task_name}?action=update&length={length}"
|
||||
response = requests.post(url)
|
||||
if response.status_code != 200:
|
||||
self.reset_port_offset()
|
||||
print(f"Failed to update avg_tokens for task {self.task_name}, request url: {url}")
|
||||
raise RuntimeError(f"Failed to update avg_tokens for task {self.task_name}, request url: {url}")
|
||||
|
||||
@property
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_random(5, 20))
|
||||
def avg_tokens(self):
|
||||
addr = os.environ["MASTER_ADDR"]
|
||||
port = int(os.environ["MASTER_PORT"]) + self.port_offset
|
||||
url = f"http://{addr}:{port}/avg_tokens/{self.task_name}"
|
||||
response = requests.get(url)
|
||||
if response.status_code != 200:
|
||||
self.reset_port_offset()
|
||||
print(f"Failed to get avg_tokens for task {self.task_name}, request url: {url}")
|
||||
raise RuntimeError(f"Failed to get avg_tokens for task {self.task_name}, request url: {url}")
|
||||
data = response.json()
|
||||
return data["avg_tokens"]
|
||||
|
||||
def reset_port_offset(self):
|
||||
self.port_offset = random.randint(1, 8000) + 1000
|
||||
|
||||
def size(self):
|
||||
return self.dataset.size()
|
||||
|
||||
def __iter__(self):
|
||||
self.iterate()
|
||||
return self
|
||||
|
||||
def reset(self):
|
||||
self.exhausted = False
|
||||
if self.iterator is not None:
|
||||
self.iterator.close()
|
||||
self.iterator = None
|
||||
self.used = BitSet()
|
||||
print("Rank {}, Reset dataset:{} done.".format(bmt.rank(), self.dataset_name))
|
||||
|
||||
def transform(self, data: dict) -> dict:
|
||||
weight = np.array(self.cfg["incontext_weight"], dtype=np.float32)
|
||||
weight = weight / weight.sum()
|
||||
num_incontext = np.random.choice(weight.shape[0], p=weight)
|
||||
return self.transform_func(data, num_incontext, random.Random())
|
||||
|
||||
def segment_iterate(self, sample_iter):
|
||||
for index, data in self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used):
|
||||
for src_ids, tgt_ids in self.segment(self.transform(data)):
|
||||
self.update_avg_tokens_by_ema(len(src_ids)) # 0 for input ids
|
||||
yield src_ids, tgt_ids, index
|
||||
|
||||
def iterate(self):
|
||||
# make the dataset itself an iterator
|
||||
sample_iter = self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used)
|
||||
self.iterator = self.segment_iterate(sample_iter)
|
||||
|
||||
def __next__(self):
|
||||
# advance the task iterator
|
||||
if self.iterator is None:
|
||||
self.iterate()
|
||||
try:
|
||||
return next(self.iterator)
|
||||
except StopIteration:
|
||||
self.exhausted = True
|
||||
return None
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
if state_dict.get("exhausted", False):
|
||||
self.exhausted = True
|
||||
self.used = BitSet()
|
||||
else:
|
||||
used = state_dict.get("used", BitSet())
|
||||
if len(used) == len(self.dataset):
|
||||
self.exhausted = True
|
||||
self.used = BitSet()
|
||||
else:
|
||||
self.exhausted = False
|
||||
self.used = used
|
||||
_avg_tokens = state_dict.get("ave_tokens", -1)
|
||||
_avg_tokens = state_dict.get("avg_tokens", _avg_tokens)
|
||||
if self.avg_tokens == -1 or self.avg_tokens < _avg_tokens:
|
||||
self.set_avg_tokens(_avg_tokens)
|
||||
|
||||
def state_dict(self):
|
||||
if len(self.used) == len(self.dataset):
|
||||
return dict(exhausted=True, used=BitSet(), avg_tokens=self.avg_tokens)
|
||||
else:
|
||||
return dict(exhausted=False, used=self.used, avg_tokens=self.avg_tokens)
|
||||
|
||||
def update_state(self, indice):
|
||||
self.used.update(indice)
|
||||
|
||||
|
||||
class MixedIndexedDataset(torch.utils.data.IterableDataset):
|
||||
def __init__(
|
||||
self,
|
||||
cfg_path: str,
|
||||
cfg_json_str,
|
||||
tokenizer,
|
||||
max_length,
|
||||
weight_by_size: bool = True,
|
||||
nthreads=5,
|
||||
prefetch_slice=100,
|
||||
parallel_loading=False,
|
||||
vdc_sampling=False,
|
||||
update_weights_frequency=1,
|
||||
seed=42,
|
||||
):
|
||||
if bmt.rank() == 0:
|
||||
port = int(os.environ["MASTER_PORT"]) + 2188
|
||||
self.flask_ps_proc = mp.Process(target=flask_ps.run, kwargs={"host": "0.0.0.0", "port": port})
|
||||
self.flask_ps_proc.start()
|
||||
super(MixedIndexedDataset, self).__init__()
|
||||
self.set_seed(seed + bmt.rank())
|
||||
self.weight_by_size = weight_by_size
|
||||
self.tokenizer = tokenizer
|
||||
self.eos_token_id = self.tokenizer.eos_token_id
|
||||
self.bos_token_id = self.tokenizer.bos_token_id
|
||||
self.path2transform = dict()
|
||||
self.task_dict = OrderedDict()
|
||||
self.nthreads = nthreads
|
||||
self.prefetch_slice = prefetch_slice
|
||||
# useful for indexing
|
||||
self.tasks = []
|
||||
self.names = []
|
||||
# ending of iteration
|
||||
self.remain = 0
|
||||
self.max_length = max_length
|
||||
self.vdc_sampling = vdc_sampling
|
||||
if self.vdc_sampling:
|
||||
self._vdc_values = [van_der_corput(i) for i in range(100000)]
|
||||
self.vdc_gen = van_der_corput_sampling_gen(self._vdc_values)
|
||||
|
||||
self.update_weights_frequency = update_weights_frequency
|
||||
|
||||
self.path2transform = dict()
|
||||
|
||||
cfgs = load_dataset_cfgs(cfg_path, cfg_json_str)
|
||||
_sum_weight = sum([cfg["abs_weight"] for cfg in cfgs])
|
||||
_weights = {cfg["task_name"]: cfg["abs_weight"] / _sum_weight for cfg in cfgs}
|
||||
bmt.print_rank("Absolute Weight of DataSet {}".format(_weights))
|
||||
|
||||
if parallel_loading:
|
||||
self.parallel_load(cfgs, max_workers=None)
|
||||
else:
|
||||
self.sequential_load(cfgs)
|
||||
|
||||
self.weights = None
|
||||
self.update_weights()
|
||||
|
||||
def set_seed(self, seed):
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
def load_task(self, cfg):
|
||||
logger.info(f"Loading {cfg['path']}")
|
||||
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
|
||||
task = SegmentedDataset(
|
||||
cfg,
|
||||
self.tokenizer,
|
||||
self.max_length,
|
||||
transform_func=transform_func,
|
||||
nthreads=self.nthreads,
|
||||
prefetch_slice=self.prefetch_slice,
|
||||
do_compact=cfg.get("do_compact", False), # dataset level do_compact
|
||||
)
|
||||
return task
|
||||
|
||||
def sequential_load(self, cfgs):
|
||||
self.cfgs = cfgs
|
||||
for cfg in cfgs:
|
||||
# python3.7 and later preserves insertion order to dictionary
|
||||
logger.info(f"Loading {cfg['path']}")
|
||||
|
||||
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
|
||||
task = SegmentedDataset(
|
||||
cfg,
|
||||
self.tokenizer,
|
||||
self.max_length,
|
||||
transform_func=transform_func,
|
||||
nthreads=self.nthreads,
|
||||
prefetch_slice=self.prefetch_slice,
|
||||
do_compact=cfg.get("do_compact", False), # dataset level do_compact
|
||||
)
|
||||
self.task_dict[task.task_name] = task
|
||||
self.tasks.append(task)
|
||||
self.names.append(task.task_name)
|
||||
self.remain += 1
|
||||
self.weights = None
|
||||
self.update_weights()
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
missing_keys = []
|
||||
for name, task in self.task_dict.items():
|
||||
if name in state_dict:
|
||||
task.load_state_dict(state_dict[name])
|
||||
else:
|
||||
missing_keys.append(name)
|
||||
self.update_weights()
|
||||
return missing_keys
|
||||
|
||||
def save_state_dict(self, path):
|
||||
state_dict = {}
|
||||
for name, task in self.task_dict.items():
|
||||
_state_dict = task.state_dict()
|
||||
if isinstance(_state_dict["used"], BitSet):
|
||||
bitset = _state_dict["used"]
|
||||
_file_name = bitset.save(path)
|
||||
_state_dict["used"] = _file_name
|
||||
state_dict[name] = _state_dict
|
||||
else:
|
||||
state_dict[name] = task.state_dict()
|
||||
torch.save(state_dict, path)
|
||||
logger.info("Dataset state saved")
|
||||
|
||||
def update_states(self, task_ids, indice):
|
||||
is_dict = isinstance(indice, dict)
|
||||
uniq = torch.unique(task_ids)
|
||||
for idx in uniq:
|
||||
idx = idx.item()
|
||||
indexes = indice[idx] if is_dict else indice[task_ids == idx].tolist()
|
||||
self.tasks[idx].update_state(indexes)
|
||||
|
||||
def get_transform_func(self, module_name: str, transform_script_path):
|
||||
if transform_script_path is None:
|
||||
# allow null transform
|
||||
return lambda data, num_incontext, rand: data
|
||||
if "/" in module_name:
|
||||
module_name = "fm9g_live.transforms.{}".format(module_name.split("/")[-1])
|
||||
else:
|
||||
module_name = "fm9g_live.transforms.{}".format(module_name)
|
||||
if transform_script_path not in self.path2transform:
|
||||
# loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
|
||||
# spec = importlib.util.spec_from_loader(loader.name, loader)
|
||||
spec = importlib.util.spec_from_file_location(module_name, transform_script_path)
|
||||
if spec is None:
|
||||
raise RuntimeError("Spec is none! {}".format(module_name))
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
self.path2transform[transform_script_path] = {
|
||||
"module": mod,
|
||||
"last_mtime": 0,
|
||||
}
|
||||
transform_script_info = self.path2transform[transform_script_path]
|
||||
curr_mtime = float(os.path.getmtime(transform_script_path))
|
||||
if curr_mtime > transform_script_info["last_mtime"]:
|
||||
transform_script_info["last_mtime"] = curr_mtime
|
||||
# load module
|
||||
spec.loader.exec_module(transform_script_info["module"])
|
||||
transform_func = getattr(transform_script_info["module"], "transform", None)
|
||||
if transform_func is None:
|
||||
raise NotImplementedError("Find no transform funcion in script '{}'".format(transform_script_path))
|
||||
return transform_func
|
||||
|
||||
def update_weights(self):
|
||||
task0 = self.tasks[0]
|
||||
if task0.abs_weight is not None: # 这一份config是指定绝对比例的
|
||||
weights = []
|
||||
for task in self.tasks:
|
||||
if task.exhausted:
|
||||
weights.append(0)
|
||||
else:
|
||||
if task.avg_tokens == -1:
|
||||
weights.append(task.abs_weight / self.max_length)
|
||||
else:
|
||||
weights.append(task.abs_weight / task.avg_tokens)
|
||||
weights = np.array(weights)
|
||||
else:
|
||||
weights = np.array([0 if task.exhausted else task.weight for task in self.tasks])
|
||||
if self.weight_by_size:
|
||||
sizes = np.array([task.size() for task in self.tasks], dtype=np.float32)
|
||||
weights *= sizes
|
||||
self.weights = weights / weights.sum()
|
||||
|
||||
def __iter__(self):
|
||||
for task in self.tasks:
|
||||
task.iterate()
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
step = 1
|
||||
while True:
|
||||
if self.remain == 0:
|
||||
print("Rank {}, All task exhaust !!!!".format(bmt.rank()))
|
||||
raise StopIteration
|
||||
if self.vdc_sampling:
|
||||
idx = next(self.vdc_gen)(self.weights)
|
||||
else:
|
||||
idx = np.random.choice(len(self.weights), p=self.weights)
|
||||
|
||||
data = next(self.tasks[idx])
|
||||
if step % self.update_weights_frequency == 0:
|
||||
self.update_weights()
|
||||
if data is None:
|
||||
if self.tasks[idx].allow_repeat:
|
||||
# _runtime_ave = self.tasks[idx].avg_tokens
|
||||
print("Rank {}, dataset {} exhaust, repeat...".format(bmt.rank(), self.tasks[idx].dataset_name))
|
||||
# self.tasks[idx] = SegmentedDataset(
|
||||
# self.tasks[idx].cfg, self.tokenizer, self.max_length, transform_func=self.tasks[idx].transform_func, nthreads=self.nthreads, prefetch_slice=self.prefetch_slice
|
||||
# )
|
||||
# self.tasks[idx].avg_tokens_update(_runtime_ave)
|
||||
self.tasks[idx].reset()
|
||||
else:
|
||||
print("Rank {}, dataset {} exhaust, not repeat.".format(bmt.rank(), self.tasks[idx].dataset_name))
|
||||
self.tasks[idx].exhaust = True
|
||||
self.remain -= 1
|
||||
continue
|
||||
step += 1
|
||||
return dict(
|
||||
task_id=idx,
|
||||
input=data[0],
|
||||
target=data[1],
|
||||
index=data[2],
|
||||
is_long=self.tasks[idx].cfg.get("is_long", False),
|
||||
)
|
||||
|
||||
|
||||
class UnpadBatchedMixedDataset(torch.utils.data.IterableDataset):
|
||||
def __init__(self, mixed_dataset, batch_size, max_length, pose_prob=0.0, pose_scaling_factor=1.0, compact=False):
|
||||
self.max_total_length = batch_size * max_length
|
||||
self.batch_size = 1
|
||||
# setting compact=True concats segments orignated from the same input
|
||||
# into a long sequence. the relative order of segments should be preserved
|
||||
# in mixed_dataset, e.g.,
|
||||
# - ok: task1_seg1, task2_seg1, task1_seg2, task1_seg3
|
||||
# - not_ok: task1_seg1, task1_seg3, task2_seg1, task1_seg2
|
||||
self.compact = compact
|
||||
|
||||
self.total_length = 0
|
||||
self.task2seqs = defaultdict(list)
|
||||
self.mixed_dataset = mixed_dataset
|
||||
self._max_length = max_length
|
||||
self._pose_prob = pose_prob
|
||||
self._pose_scaling_factor = pose_scaling_factor
|
||||
if self._pose_prob > 0.0:
|
||||
self._scaled_max_length = int(self.max_total_length * self._pose_scaling_factor)
|
||||
else:
|
||||
self._scaled_max_length = max_length
|
||||
|
||||
def put(self, sample):
|
||||
self.total_length += len(sample["target"])
|
||||
task_id = sample["task_id"]
|
||||
if self.compact and self.task2seqs[task_id]:
|
||||
last = self.task2seqs[task_id][-1]
|
||||
if last["target"][-1] != self.mixed_dataset.eos_token_id:
|
||||
# concatenate sequantial segments for longer context modeling: why not?
|
||||
last["input"].extend(sample["input"])
|
||||
last["target"].extend(sample["target"])
|
||||
return
|
||||
self.task2seqs[task_id].append(sample)
|
||||
|
||||
def _pose_preprocess(
|
||||
self,
|
||||
input_ids: NDArray[np.int32],
|
||||
) -> NDArray[np.int32]:
|
||||
"""[PoSE](https://arxiv.org/abs/2309.10400v2)
|
||||
GitHub implementation: https://github.com/dwzhu-pku/PoSE/blob/master/src/train_pose.py#L156
|
||||
"""
|
||||
len_chunk = min(len(input_ids), self._max_length)
|
||||
len_input = len(input_ids)
|
||||
# Chunk input randomly to fit max_length if needed
|
||||
lt1 = 0
|
||||
rt1 = random.randint(0, (len_chunk + 1) // 2) # Fist chunk only contains 1/2 tokens at most
|
||||
rt2 = random.randint(lt1 + len_chunk, len_input) # Second chunk can randomly shift when not filled max_length
|
||||
lt2 = rt2 - (len_chunk - (rt1 - lt1)) # assure all tokens are used
|
||||
chunked_input_ids = np.concatenate([input_ids[lt1:rt1], input_ids[lt2:rt2]], axis=-1)
|
||||
# Generate PoSE position ids
|
||||
position_ids = np.arange(len(chunked_input_ids), dtype=np.int32)
|
||||
len_position_ids = len(position_ids)
|
||||
lt = 0
|
||||
rt = random.randint(lt, self._scaled_max_length - len_position_ids)
|
||||
position_ids[: rt1 - lt1] += lt
|
||||
position_ids[rt1 - lt1 :] += rt
|
||||
return position_ids
|
||||
|
||||
def pop(self):
|
||||
indexes = defaultdict(list)
|
||||
lengths = []
|
||||
|
||||
inputs = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
|
||||
targets = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=IGNORE_TGT)
|
||||
task_ids = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=-1)
|
||||
position_ids = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
|
||||
|
||||
span_begin = 0
|
||||
for samples in self.task2seqs.values():
|
||||
while samples:
|
||||
sample = samples.pop()
|
||||
span_end = span_begin + len(sample["input"])
|
||||
inputs[0, span_begin:span_end] = torch.tensor(sample["input"], dtype=torch.int32)
|
||||
targets[0, span_begin:span_end] = torch.tensor(sample["target"], dtype=torch.int32)
|
||||
task_ids[0, span_begin:span_end] = torch.tensor(sample["task_id"], dtype=torch.int32)
|
||||
if not sample["is_long"] and self._pose_prob > 0.0 and random.uniform(0, 1) < self._pose_prob:
|
||||
_span_position_ids = self._pose_preprocess(sample["input"])
|
||||
else:
|
||||
_span_position_ids = np.arange(len(sample["input"]), dtype=np.int32)
|
||||
position_ids[0, span_begin:span_end] = torch.from_numpy(_span_position_ids)
|
||||
# position_ids[0, span_begin:span_end] = torch.arange(len(sample["input"]), dtype=torch.int32)
|
||||
lengths.append(len(sample["target"]))
|
||||
indexes[int(sample["task_id"])].append(sample["index"])
|
||||
self.total_length -= len(sample["target"])
|
||||
span_begin = span_end
|
||||
|
||||
cu_seqlens = torch.cat(
|
||||
[torch.tensor([0] + lengths).cumsum(dim=-1), torch.tensor([self.max_total_length], dtype=torch.int32)],
|
||||
dim=0,
|
||||
).int()
|
||||
batch = {
|
||||
"inputs": inputs,
|
||||
"targets": targets,
|
||||
"task_ids": task_ids,
|
||||
"indexes": indexes,
|
||||
# adhere to flash attention interface
|
||||
"cu_seqlens": cu_seqlens,
|
||||
"max_seqlen": int(torch.max(cu_seqlens[1:] - cu_seqlens[:-1])),
|
||||
"lengths": torch.tensor(sum(lengths)).int(),
|
||||
"task_names": self.mixed_dataset.names,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
return batch
|
||||
|
||||
def will_be_full(self, sample):
|
||||
return self.total_length + len(sample["target"]) > self.max_total_length
|
||||
|
||||
def __iter__(self):
|
||||
for sample in self.mixed_dataset:
|
||||
if self.will_be_full(sample):
|
||||
yield self.pop()
|
||||
self.put(sample)
|
||||
|
||||
|
||||
class CudaPrefetcher(Iterable):
|
||||
"""
|
||||
Wrap around a batch iterator for asynchornously copying data to gpu to shield memcpy latency.
|
||||
"""
|
||||
|
||||
def __init__(self, loader, tp_size=1, tp_rank=0):
|
||||
self.loader = iter(loader)
|
||||
self.tp_size = tp_size
|
||||
self.tp_rank = tp_rank
|
||||
self.stream = torch.cuda.Stream()
|
||||
self.preload()
|
||||
|
||||
def preload(self):
|
||||
try:
|
||||
if self.tp_size > 1:
|
||||
if self.tp_rank == 0:
|
||||
data = next(self.loader)
|
||||
print("Rank {}, Preload data done.".format(bmt.rank()))
|
||||
d = {}
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "wb") as fb:
|
||||
for key in data.keys():
|
||||
if isinstance(data[key], torch.Tensor):
|
||||
np_cur_data = data[key].cpu().numpy()
|
||||
bs = np_cur_data.tobytes()
|
||||
fb.write(bs)
|
||||
d[key] = ["TORCH", str(np_cur_data.dtype), len(bs)] + list(np_cur_data.shape)
|
||||
elif isinstance(data[key], np.ndarray):
|
||||
bs = data[key].tobytes()
|
||||
fb.write(bs)
|
||||
d[key] = ["NUMPY", str(data[key].dtype), len(bs)] + list(data[key].shape)
|
||||
else:
|
||||
d[key] = data[key]
|
||||
try:
|
||||
_ = json.dumps(d)
|
||||
except TypeError:
|
||||
print(d)
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "w") as f:
|
||||
json.dump(d, f)
|
||||
bmt.synchronize()
|
||||
if self.tp_rank != 0:
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "r") as f:
|
||||
data = json.load(f)
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "rb") as fb:
|
||||
bs = fb.read()
|
||||
offset = 0
|
||||
for key in data.keys():
|
||||
if isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "NUMPY":
|
||||
nw_offset = offset + data[key][2]
|
||||
data[key] = np.frombuffer(bs[offset:nw_offset], dtype=data[key][1]).reshape(
|
||||
data[key][3:]
|
||||
)
|
||||
offset = nw_offset
|
||||
elif isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "TORCH":
|
||||
nw_offset = offset + data[key][2]
|
||||
data[key] = torch.from_numpy(
|
||||
np.frombuffer(bs[offset:nw_offset], dtype=data[key][1])
|
||||
.reshape(data[key][3:])
|
||||
.copy()
|
||||
)
|
||||
offset = nw_offset
|
||||
self.data = data
|
||||
else:
|
||||
self.data = next(self.loader)
|
||||
except StopIteration:
|
||||
self.data = None
|
||||
return
|
||||
with torch.cuda.stream(self.stream):
|
||||
for key in self.data.keys():
|
||||
if isinstance(self.data[key], torch.Tensor):
|
||||
self.data[key] = self.data[key].cuda(non_blocking=True)
|
||||
|
||||
def __next__(self):
|
||||
torch.cuda.current_stream().wait_stream(self.stream)
|
||||
data = copy.deepcopy(self.data)
|
||||
self.preload()
|
||||
return data
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
|
@ -12,3 +12,4 @@ from .position_embedding import RotaryEmbedding
|
|||
from .position_embedding import RotaryEmbeddingESM
|
||||
from .position_embedding import SegmentPositionEmbedding
|
||||
from .transformer import Encoder
|
||||
#from _attention_pp_sp import OpAttnPipeSP
|
|
@ -0,0 +1,79 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
import bmtrain as bmt
|
||||
|
||||
def _linear_backward(grad_output, x, weight, bias):
|
||||
grad_x = grad_weight = grad_bias = None
|
||||
if x.requires_grad:
|
||||
grad_x = grad_output.matmul(weight)
|
||||
if weight.requires_grad:
|
||||
grad_weight = grad_output.reshape(-1,
|
||||
grad_output.shape[-1]).t().matmul(x.reshape(-1, x.shape[-1]))
|
||||
if bias is not None and bias.requires_grad:
|
||||
grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)
|
||||
return grad_x, grad_weight, grad_bias
|
||||
|
||||
class OpAttnPipeSP(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, q_w, k_w, v_w, q_b, w_b, v_b, x, cache_kv, cache_kv_inp, cu_seqlens_q, cu_seqlens_k, max_seqlen):
|
||||
ctx.save_for_backward(x, q_w, k_w, v_w, q_b, w_b, v_b)
|
||||
if cache_kv.numel() = 0:
|
||||
q = F.linear(x, q_w, q_b)
|
||||
k = F.linear(x, k_w, w_b)
|
||||
v = F.linear(x, v_w, v_b)
|
||||
else:
|
||||
q = F.linear(x, q_w, q_b)
|
||||
k = F.linear(x, k_w, w_b)
|
||||
v = F.linear(x, v_w, v_b)
|
||||
k = torch.cat([cache_kv[0], k], dim=1)
|
||||
v = torch.cat([cache_kv[1], v], dim=1)
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
|
||||
q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen, max_seqlen, 0, causal=True, window_size=(-1,-1), alibi_slopes=None, deterministic=False, return_attn_probs=False
|
||||
)
|
||||
ctx.save_for_backward(
|
||||
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
|
||||
)
|
||||
ctx.max_seqlen_q = max_seqlen
|
||||
ctx.max_seqlen_k = max_seqlen
|
||||
|
||||
|
||||
|
||||
return F.linear(x, weight, bias)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
|
||||
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
|
||||
_flash_attn_varlen_backward(
|
||||
dout,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
softmax_lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
ctx.max_seqlen_q,
|
||||
ctx.max_seqlen_k,
|
||||
ctx.dropout_p,
|
||||
ctx.softmax_scale,
|
||||
False,
|
||||
(-1,-1),
|
||||
None,
|
||||
False,
|
||||
rng_state=rng_state,
|
||||
)
|
||||
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
|
||||
dk = dk[..., : dout.shape[-1]]
|
||||
dv = dv[..., : dout.shape[-1]]
|
||||
|
||||
d_xq, d_wq, d_bq = _linear_backward(dq, x, q_w, q_b)
|
||||
d_xq, d_wq, d_bq = _linear_backward(dq, x, q_w, q_b)
|
||||
d_xk, d_wk, d_bk = _linear_backward(dk, x, k_w, k_b)
|
||||
|
||||
|
||||
|
||||
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None
|
|
@ -13,13 +13,13 @@ from einops import rearrange
|
|||
from .linear import ColumnParallelLinear
|
||||
from .linear import Linear
|
||||
from .position_embedding import apply_chatglm_rotary_pos_emb
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import _flash_attn_varlen_backward
|
||||
from flash_attn.flash_attn_interface import _flash_attn_varlen_forward
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||
except:
|
||||
flash_attn_varlen_func = None
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||
#try:
|
||||
# from flash_attn.flash_attn_interface import _flash_attn_varlen_backward
|
||||
# from flash_attn.flash_attn_interface import _flash_attn_varlen_forward
|
||||
# from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||
#except:
|
||||
# flash_attn_varlen_func = None
|
||||
|
||||
try:
|
||||
from flash_attn.bert_padding import pad_input
|
||||
|
@ -54,6 +54,8 @@ class OpFlash(torch.autograd.Function):
|
|||
dropout_p,
|
||||
ctx.softmax_scale,
|
||||
causal=causal,
|
||||
window_size=(-1, -1),
|
||||
alibi_slopes=None,
|
||||
return_softmax=False,
|
||||
)
|
||||
if record:
|
||||
|
@ -85,6 +87,9 @@ class OpFlash(torch.autograd.Function):
|
|||
ctx.dropout_p,
|
||||
ctx.softmax_scale,
|
||||
ctx.causal,
|
||||
(-1,-1),
|
||||
None,
|
||||
False,
|
||||
rng_state=rng_state,
|
||||
)
|
||||
return None, None, dq, dk, dv, None, None, None, None
|
||||
|
@ -294,12 +299,28 @@ class Attention(bmt.DistributedModule):
|
|||
h_q, h_k = position_bias(
|
||||
h_q, h_k, -3, cu_seqlens=cu_seqlens, max_length=max_seqlen, position_ids=position_ids
|
||||
)
|
||||
# score = flash_attn_varlen_func(
|
||||
# h_q, h_k, h_v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, self.dropout_p, causal=True
|
||||
# )
|
||||
score = OpFlash.apply(
|
||||
self, not torch.is_grad_enabled(), h_q, h_k, h_v, cu_seqlens, max_seqlen, self.dropout_p, True
|
||||
print(type(h_q), type(cu_seqlens), type(max_seqlen), type(self.dropout_p))
|
||||
print("h_q: ", h_q)
|
||||
print("cu_seqlens: ", cu_seqlens)
|
||||
print("max_seqlen: ", max_seqlen)
|
||||
score = flash_attn_varlen_func(
|
||||
h_q,
|
||||
h_k,
|
||||
h_v,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
max_seqlen,
|
||||
self.dropout_p,
|
||||
causal=True,
|
||||
deterministic=True,
|
||||
)
|
||||
print(type(h_q), type(cu_seqlens), type(max_seqlen), type(self.dropout_p))
|
||||
# Rongqiao change
|
||||
#score = OpFlash.apply(
|
||||
# self, not torch.is_grad_enabled(), h_q, h_k, h_v, cu_seqlens, max_seqlen, self.dropout_p, True
|
||||
#)
|
||||
|
||||
|
||||
score = score.view(batch_size, len_q, -1)
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
from typing import Optional
|
||||
|
||||
|
||||
class EMAValue(object):
|
||||
def __init__(self, init_value: Optional[float] = None, decay_factor: float = 0.999) -> None:
|
||||
super().__init__()
|
||||
self._value = init_value
|
||||
self._decay_factor = decay_factor
|
||||
|
||||
@property
|
||||
def value(self) -> Optional[float]:
|
||||
return self._value
|
||||
|
||||
def update(self, value: float) -> None:
|
||||
if self._value is None:
|
||||
self._value = value
|
||||
else:
|
||||
self._value = self._decay_factor * self._value + (1 - self._decay_factor) * value
|
||||
|
||||
def update_with_return(self, value: float) -> Optional[float]:
|
||||
self.update(value)
|
||||
return self._value
|
|
@ -0,0 +1 @@
|
|||
from .fm9g import FM9GTokenizer
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue