Compare commits
No commits in common. "FM_9G" and "master" have entirely different histories.
|
@ -0,0 +1,16 @@
|
||||||
|
{
|
||||||
|
"vocab_size": 119696,
|
||||||
|
"dropout_p": 0.0,
|
||||||
|
"eps": 1e-05,
|
||||||
|
"half": true,
|
||||||
|
"use_flash_attn": true,
|
||||||
|
"flash_attn_mask_shape": "2d",
|
||||||
|
"dim_model": 4096,
|
||||||
|
"dim_ff": 12288,
|
||||||
|
"dim_head": 128,
|
||||||
|
"num_heads": 32,
|
||||||
|
"num_kv_heads": 32,
|
||||||
|
"num_layers": 48,
|
||||||
|
"activate_fn": "silu",
|
||||||
|
"scale": false
|
||||||
|
}
|
|
@ -119687,8 +119687,8 @@
|
||||||
"𠳐"
|
"𠳐"
|
||||||
"𥻗"
|
"𥻗"
|
||||||
"𬉼"
|
"𬉼"
|
||||||
"<|im_start|>"
|
"<pad_0>"
|
||||||
"<|im_end|>"
|
"<pad_1>"
|
||||||
"<pad_2>"
|
"<pad_2>"
|
||||||
"<pad_3>"
|
"<pad_3>"
|
||||||
"<pad_4>"
|
"<pad_4>"
|
|
@ -0,0 +1,14 @@
|
||||||
|
{
|
||||||
|
"vocab_size": 119696,
|
||||||
|
"dropout_p": 0.0,
|
||||||
|
"eps": 1e-05,
|
||||||
|
"half": true,
|
||||||
|
"dim_model": 4096,
|
||||||
|
"dim_ff": 11008,
|
||||||
|
"dim_head": 128,
|
||||||
|
"num_heads": 32,
|
||||||
|
"num_kv_heads": 32,
|
||||||
|
"num_layers": 32,
|
||||||
|
"activate_fn": "silu",
|
||||||
|
"scale": false
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,12 @@
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"dataset_name": "wikipedia",
|
||||||
|
"task_name": "wikipedia",
|
||||||
|
"weight": 1.0,
|
||||||
|
"path": "path/to/data",
|
||||||
|
"incontext_weight": [
|
||||||
|
1.0
|
||||||
|
],
|
||||||
|
"transforms": "/home/USERNAME/cpm9g/apps/cpm9g/config/datasets/wikipedia/script_cpmc.py"
|
||||||
|
}
|
||||||
|
]
|
|
@ -0,0 +1,9 @@
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
def rand(n: int, r: random.Random):
    """Return a uniform random integer in [0, n) drawn from generator *r*."""
    return int(n * r.random())
|
||||||
|
|
||||||
|
|
||||||
|
def transform(data, num_sample: int, r: random.Random):
    """Map a raw wikipedia record to a training sample.

    The article text becomes the target output and the prompt is left empty;
    ``num_sample`` and ``r`` are part of the transform interface but unused here.
    """
    sample = {"input": ""}
    sample["output"] = data["text"]
    return sample
|
|
@ -0,0 +1,485 @@
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Any
|
||||||
|
from typing import Dict
|
||||||
|
from typing import List
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import bmtrain as bmt
|
||||||
|
import torch
|
||||||
|
|
||||||
|
sys.path.insert(0, "/home/wangshuo1/code/9G-Train")
|
||||||
|
from cpm.arguments import get_args
|
||||||
|
from cpm.cpm9g.models import CPM9G
|
||||||
|
from cpm.cpm9g.models import CPM9GConfig
|
||||||
|
from cpm.cpm9g.tokenizers import CPM9GTokenizer
|
||||||
|
from cpm.cpm9g.training_tasks import MixedDataset
|
||||||
|
from cpm.utils import allgather_objects
|
||||||
|
from cpm.utils import exporter
|
||||||
|
from cpm.utils import logger
|
||||||
|
from cpm.utils import LogManager
|
||||||
|
|
||||||
|
|
||||||
|
def get_tokenizer(args):
    """Load the CPM9G tokenizer from the vocab path supplied on the command line."""
    return CPM9GTokenizer(path=args.vocab)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(args):
    """Build the CPM9G model from the JSON config and either load a
    checkpoint (args.load) or freshly initialize parameters."""
    config = CPM9GConfig.from_json_file(args.model_config)
    # NOTE(review): this looks inverted — it enables tp exactly when
    # args.tp != 1. Confirm the intended tensor-parallel flag semantics.
    config.tp = 1 if args.tp != 1 else 0
    if args.flash == "none":
        config.use_flash_attn = False
    else:
        config.use_flash_attn = True
        # "1d" selects the compact mask; anything else falls through to "2d",
        # under which a triton or cuda flash-attention implementation is chosen.
        if args.flash == "1d":
            config.flash_attn_mask_shape = "1d"
        else:
            config.flash_attn_mask_shape = "2d"
            if args.flash == "triton":
                config.flash_impl = "triton"
            elif args.flash == "cuda":
                config.flash_impl = "cuda"
    model = CPM9G(config)
    if args.load is not None:
        bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
        bmt.load(model, args.load)
    else:
        bmt.print_rank("args.load is None, start to initialize parameters")
        bmt.init_parameters(model)
    return model
|
||||||
|
|
||||||
|
|
||||||
|
def get_optimizer(args, model):
    """Build an Adam optimizer (CPU-offloaded when --offload is set) and,
    when resuming with --load-grad, restore this rank's optimizer state
    from args.save.

    State is only restored when every rank's ``.opt`` shard for
    args.start_step is present, so a partially written save is ignored.
    """
    for name, para in model.named_parameters():
        # if not ('input_embedding' in name or 'lm_head' in name):
        #     para.requires_grad_(False)
        bmt.print_rank(name, para.requires_grad)
    if args.offload:
        optimizer = bmt.optim.AdamOffloadOptimizer(
            model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay
        )
    else:
        optimizer = bmt.optim.AdamOptimizer(model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay)
    if args.load is not None and args.load_grad:
        start = time.time()
        # Count this step's per-rank optimizer shards once (the original scanned
        # the directory twice with the same expression). Dots are escaped so the
        # pattern matches literal ".rank-" / ".opt" rather than any character.
        pattern = re.compile(r"-{}\.rank-\d+\.opt".format(args.start_step))
        num_opt_files = sum(1 for i in os.listdir(args.save) if pattern.search(i))
        print(num_opt_files)
        if num_opt_files == bmt.world_size():
            file_name = os.path.join(
                args.save,
                args.save_name + "-{}.rank-{}.opt".format(args.start_step, bmt.rank()),
            )
            print(file_name)
            if os.path.exists(file_name):
                print("start to load grad ckpt {}".format(file_name))
                states = torch.load(file_name)
                optimizer.load_state_dict(states)
                logger.info("load grad in {:.2f}s".format(time.time() - start))
    return optimizer
|
||||||
|
|
||||||
|
|
||||||
|
class Cosine(bmt.lr_scheduler.WarmupLRScheduler):
    r"""Warmup followed by a floored cosine decay.

    During warmup the learning rate rises linearly from 0 to ``start_lr``.
    Afterwards it follows
    :math:`\text{lr}=\text{start_lr}\times\left(0.1 + 0.45\left(1+\cos\left(\pi\cdot\dfrac{\text{num_iter}-\text{warmup_iter}}{\text{end_iter}-\text{warmup_iter}}\right)\right)\right)`,
    i.e. it decays from ``start_lr`` down to ``0.1 * start_lr`` and is clamped
    so it never drops below that floor even past ``end_iter``.
    """

    def get_lr_warmup(self, num_iter) -> float:
        # Linear ramp: 0 -> start_lr over warmup_iter steps.
        return self.start_lr * num_iter / self.warmup_iter

    def get_lr_decay(self, num_iter) -> float:
        # Fraction of the decay phase completed; max(1, ...) guards against
        # division by zero when end_iter == warmup_iter.
        progress = (num_iter - self.warmup_iter) / max(1, (self.end_iter - self.warmup_iter))
        return max(self.start_lr * 0.1, self.start_lr * (0.1 + 0.45 * (1.0 + math.cos(progress * math.pi))))
|
||||||
|
|
||||||
|
|
||||||
|
def get_learning_rate_scheduler(args, optimizer):
    """Create the warmup+cosine LR schedule, decaying over the full training
    run when --lr-decay-iters is not given."""
    if args.lr_decay_iters is None:
        args.lr_decay_iters = args.train_iters
    return Cosine(
        optimizer,
        start_lr=args.lr,
        warmup_iter=args.warmup_iters,
        end_iter=args.lr_decay_iters,
        num_iter=args.start_step,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def setup_model_and_optimizer(args):
    """Construct tokenizer, model, optimizer and LR scheduler, logging how
    long each stage takes."""
    t0 = time.time()
    model = get_model(args)
    logger.info("load model in {:.2f}s".format(time.time() - t0))

    t0 = time.time()
    tokenizer = get_tokenizer(args)
    bmt.synchronize()
    logger.info("load tokenizer in {:.2f}s".format(time.time() - t0))

    t0 = time.time()
    optimizer = get_optimizer(args, model)
    lr_scheduler = get_learning_rate_scheduler(args, optimizer)
    bmt.synchronize()
    logger.info("load lr_scheduler in {:.2f}s".format(time.time() - t0))

    return tokenizer, model, optimizer, lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
def initialize():
    """Parse pretraining arguments, initialize bmtrain (ZeRO level 3),
    prepare the save directory, and recover the start step when resuming.

    Returns the parsed args namespace.
    """
    args = get_args(pretrain=True)
    bmt.init_distributed(seed=args.seed, zero_level=3)
    if args.save is not None:
        os.makedirs(args.save, exist_ok=True)
    if args.load is not None:
        if args.start_step == 0:
            # Recover the step counter from the checkpoint filename,
            # e.g. ".../cpm9g_checkpoint-1000.pt" -> 1000. Raw string and an
            # escaped dot fix the original pattern; a missing match no longer
            # crashes with TypeError, it just leaves start_step at 0.
            match = re.search(r"(\d+)\.pt", args.load)
            if match is not None:
                args.start_step = int(match[1])
    return args
|
||||||
|
|
||||||
|
|
||||||
|
def see_memory(detail=False):
    """Report CUDA memory usage and reset the peak counters.

    With ``detail=True`` returns torch's full memory summary string; otherwise
    a ``(reserved_GB, max_reserved_GB)`` tuple rounded to two decimals.
    """
    gib = 1024 * 1024 * 1024
    if detail:
        report = torch.cuda.memory_summary()
    else:
        report = (
            round(torch.cuda.memory_reserved() / gib, 2),
            round(torch.cuda.max_memory_reserved() / gib, 2),
        )
    torch.cuda.reset_peak_memory_stats()
    return report
|
||||||
|
|
||||||
|
|
||||||
|
def add_mem_time(info, mem_usage, tim_usage):
    """Record a synchronized GPU-memory and wall-clock snapshot under *info*
    in the two bookkeeping dicts, returning them for convenient reassignment."""
    torch.cuda.synchronize()
    bmt.synchronize()
    mem_usage[info] = see_memory()
    tim_usage[info] = time.time()
    return mem_usage, tim_usage
|
||||||
|
|
||||||
|
|
||||||
|
class LossSpikeDetector:
    """Watches per-task losses across iterations and appends a report —
    including the most recent batch — to a log file whenever a task's loss
    jumps to more than 3x its previous value."""

    def __init__(self, log_path: str) -> None:
        self._last_loss: Dict[str, float] = {}  # last observed loss per task
        self._last_data: List[Any] = [None]  # rolling window of the two latest batches
        self._log_path = log_path

    def update_data(self, data: Any):
        """Remember the latest batch so it can be dumped if a spike occurs."""
        self._last_data.append(data)
        if len(self._last_data) > 2:
            self._last_data = self._last_data[-2:]

    def update_loss(self, iteration: int, loss_map: Dict[str, float]):
        """Record this iteration's losses and log any task whose loss more
        than tripled since its previous value."""
        loss_spike_result = []
        for task, loss in loss_map.items():
            if task in self._last_loss:
                if loss > self._last_loss[task] * 3:
                    # loss spike!
                    loss_spike_result.append(
                        {
                            "prev": self._last_loss[task],
                            "curr": loss,
                            "task": task,
                        }
                    )
            self._last_loss[task] = float(loss)
        if len(loss_spike_result) > 0:
            self._write_log(iteration, self._last_data[-1], loss_spike_result)

    def _write_log(self, iteration: int, data: Any, result: List[Dict[str, Any]]):
        """Append the spike report and the triggering batch to the log file.

        Retries a few times on failure (e.g. transient filesystem errors)
        instead of looping forever as the original did, then gives up —
        logging must never hang training.
        """
        for _ in range(3):
            try:
                with open(self._log_path, "a", encoding="utf-8") as fp:
                    fp.write("=" * 20)
                    fp.write("\nloss spike at {}\n".format(iteration))
                    fp.write("{}\n".format(json.dumps(result, indent=4, ensure_ascii=False)))
                    fp.write("data: \n")
                    for d in data:
                        fp.write("{}\n".format(json.dumps(d, indent=4, ensure_ascii=False)))
                    fp.write("\n\n")
                break
            except Exception as e:
                # Fixed: the original passed the path as a second print argument
                # instead of formatting it, and never exited the retry loop.
                print("cannot output log to the file {}: {}".format(self._log_path, e))
|
||||||
|
|
||||||
|
def pretrain(
    args,
    tokenizer: CPM9GTokenizer,
    model: CPM9G,
    optimizer,
    lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
):
    """Main pretraining loop.

    Streams batches from MixedDataset, runs forward/backward/optimizer steps
    under bmtrain, tracks per-task losses across ranks, watches for loss
    spikes, and periodically logs, inspects, and checkpoints the model.
    """
    average_time = bmt.utils.AverageRecorder()
    loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100)
    optim_manager = bmt.optim.OptimManager(
        loss_scale=None if args.bf16 else args.loss_scale,  # bf16 needs no loss scaling
        loss_scale_steps=args.loss_scale_steps,
        loss_scale_factor=2,
        max_loss_scale=args.max_loss_scale,
        min_loss_scale=args.min_loss_scale,
    )
    optim_manager.add_optimizer(optimizer, lr_scheduler)

    start_step = args.start_step
    lsd = LossSpikeDetector("./log/debug/spile.%d.log" % bmt.rank())

    if args.tensorboard is not None and bmt.rank() == 0:
        import distutils.version  # noqa: F401

        from tensorboardX import SummaryWriter

        if not os.path.exists(args.tensorboard):
            os.makedirs(args.tensorboard)
        writer = SummaryWriter(log_dir=args.tensorboard)

    if args.log_dir is not None and bmt.rank() == 0:
        log_mgr = LogManager(args.log_dir)

    global_token_pass = 0.0
    global_world_size = bmt.world_size()
    # unpad=True only on the cuda flash-attention path, which packs sequences.
    dataloader = MixedDataset(args.dataset, args.batch_size, args.max_length, tokenizer, unpad=(args.flash == "cuda"))

    if args.load is not None:
        # Resume dataset iteration state saved next to the model checkpoint.
        dataset_states_path = args.load.replace(".pt", ".data")
        if os.path.exists(dataset_states_path):
            start = time.time()
            bmt.print_rank("start to load data ckpt")
            dataset_states = torch.load(dataset_states_path)
            logger.info("load data ckpt in {:.2f}s".format(time.time() - start))

            start = time.time()
            missing = dataloader.load_state_dict(dataset_states)
            logger.info("load state dict in {:.2f}s".format(time.time() - start))
            if len(missing) > 0:
                bmt.print_rank("Missing keys when loading dataset states: ", missing)
        else:
            bmt.print_rank("cannot find data ckpt {}".format(dataset_states_path))

    dataloader.start()
    bmt.print_rank("finish dataset start")
    try:
        total = 0
        # NOTE(review): `hash` shadows the builtin; it counts samples seen per task.
        hash = {}
        for iteration, data in enumerate(dataloader):
            # Global step number continues from the resumed checkpoint.
            iteration = iteration + start_step + 1
            input_ids = torch.from_numpy(data["inputs"]).cuda().to(torch.int32)
            input_length = torch.from_numpy(data["length"]).cuda().to(torch.int32)
            targets = torch.from_numpy(data["target"]).cuda().to(torch.int32)
            task_ids = torch.from_numpy(data["task_ids"]).cuda().to(torch.int32)
            task_names = data["task_names"]
            lsd.update_data(data["raw_data"])
            if args.flash == "cuda":
                # Packed (unpadded) batch layout for the cuda flash-attn kernel.
                cu_seqlens = torch.from_numpy(data["cu_seqlens"]).cuda().to(torch.int32)
                max_seqlen = data["max_seqlen"]
                position_ids = torch.from_numpy(data["position_ids"]).cuda().to(torch.int32)
            else:
                input_ids = torch.from_numpy(data["inputs"]).cuda().to(torch.int32)
                input_context = torch.zeros_like(input_ids).cuda().bool()
                input_span = torch.from_numpy(data["spans"]).cuda().to(torch.int32)

            # ===========
            optim_manager.zero_grad()
            # torch.cuda.empty_cache()
            mem_usage = {}
            tim_usage = {}
            mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)

            # bmt.print_rank(torch.cuda.max_memory_allocated())
            # ===========
            if args.flash == "cuda":
                logits, _ = model(
                    input_ids,
                    cu_seqlens=cu_seqlens,
                    max_seqlen=max_seqlen,
                    position_ids=position_ids,
                )
            else:
                logits, _ = model(
                    input_ids,
                    input_length,
                    input_context,
                    input_span,
                )
            mem_usage, tim_usage = add_mem_time("forward_1", mem_usage, tim_usage)
            loss = loss_func(logits.view(-1, logits.size(-1)), targets.view(-1))
            global_loss = bmt.sum_loss(loss).item()
            mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)

            # bmt.print_rank(torch.cuda.max_memory_allocated())
            # ===========
            optim_manager.backward(loss)
            mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)

            # bmt.print_rank(torch.cuda.max_memory_allocated())
            # ===========
            grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
            optim_manager.step()
            mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
            # bmt.print_rank(torch.cuda.max_memory_allocated())

            # ==========
            iter_time = tim_usage["optim"] - tim_usage["init"]
            average_time.record(iter_time)

            with torch.no_grad():
                # Per-task loss: broadcast targets once per task and mask tokens
                # of other tasks with the loss ignore_index (-100).
                task_num = len(task_names)
                targets_tmp = targets.expand(task_num, -1, -1)
                task = torch.arange(task_num, dtype=torch.int32, device="cuda")[:, None, None]
                targets_tmp = torch.where(
                    task_ids == task,
                    targets_tmp,
                    torch.scalar_tensor(-100, dtype=torch.int32, device="cuda"),
                )

                task_loss_map: Dict[str, float] = {}
                task_loss_tot: Dict[str, float] = {}
                for i in range(task_num):
                    task_loss_map[task_names[i]] = loss_func(
                        logits.view(-1, logits.size(-1)), targets_tmp[i, :].view(-1)
                    ).item()
                    # Number of non-masked target tokens for this task on this rank.
                    task_loss_tot[task_names[i]] = (targets_tmp[i, :].view(-1) >= 0).sum().float().item()
                gatherd_task_loss_map: List[Dict[str, float]] = allgather_objects(task_loss_map)
                gatherd_task_loss_tot: List[Dict[str, float]] = allgather_objects(task_loss_tot)

                global_task_loss_map: Dict[str, Union[List[float], float]] = {}
                global_task_loss_tot: Dict[str, Union[List[float], float]] = {}

                for idx, local_task_loss_map in enumerate(gatherd_task_loss_map):
                    for task_name, task_loss in local_task_loss_map.items():
                        if task_name not in global_task_loss_map:
                            global_task_loss_map[task_name] = []
                        global_task_loss_map[task_name].append(task_loss)
                    for task_name, task_tot in gatherd_task_loss_tot[idx].items():
                        if task_name not in global_task_loss_tot:
                            global_task_loss_tot[task_name] = []
                        global_task_loss_tot[task_name].append(task_tot)

                # Token-weighted average of each task's loss across all ranks.
                task_loss_map = {}
                for task_name in sorted(list(global_task_loss_map.keys())):
                    avg_loss = 0.0
                    sum_token = sum(global_task_loss_tot[task_name])
                    for loss, token in zip(global_task_loss_map[task_name], global_task_loss_tot[task_name]):
                        avg_loss += loss * token / sum_token
                    task_loss_map[task_name] = avg_loss

            # Fraction of the max sequence length actually used, averaged over ranks.
            local_total_rate = torch.Tensor([input_length.float().mean() / args.max_length]).cuda()
            local_total_rate = bmt.sum_loss(local_total_rate).item()
            global_token_pass += global_world_size * local_total_rate * args.max_length * args.batch_size
            avg_time = average_time.value
            lsd.update_loss(iteration, task_loss_map)

            # Count how many samples of each task this rank has processed.
            for task_id in data["task_ids"]:
                for task in task_id:
                    if task != -1:
                        if not data["task_names"][task] in hash:
                            hash[data["task_names"][task]] = 0
                        hash[data["task_names"][task]] += 1.0
                        total += 1.0

            gathered_hash = allgather_objects(hash)
            sum_total = sum(allgather_objects(total))

            final_hash = defaultdict(int)
            for local_hash in gathered_hash:
                for task, num in local_hash.items():
                    final_hash[task] += num

            # for i in final_hash:
            #     bmt.print_rank(i, final_hash[i] / sum_total)
            # bmt.print_rank("=========================================")

            train_info = {
                "time": tim_usage["init"],
                "iteration": iteration,
                "loss": global_loss,
                "lr": lr_scheduler.current_lr,
                "lr_scale": int(optim_manager.loss_scale),
                "time_usage": tim_usage,
                "mem_usage": mem_usage,
                "avg_time": avg_time,
                "token_max": local_total_rate,
                "token_pass": global_token_pass,
                "throughout": args.max_length * args.batch_size * local_total_rate / avg_time,
                "grad_norm": grad_norm.item(),
                "mask_max": ((targets >= 0).sum(-1).float().mean() / args.max_length).item(),
                "num_gpus": global_world_size,
                "task_loss": task_loss_map,
            }
            # bmt.print_rank(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
            bmt.print_rank(
                (
                    "| Iter: {:6d} | loss: {:.4f} | lr: {:.4e}, scale: {:10.4f} | avg_time: {:.4f}; cur_time:{:.4f}={:.4f}+{:.4f} |"
                    + " token/max: {:.4f} | mask/max: {:.4f} | grad_norm: {:.4f} | mem: {:.2f} |"
                ).format(
                    iteration,
                    global_loss,
                    lr_scheduler.current_lr,
                    int(optim_manager.loss_scale),
                    iter_time,
                    tim_usage["optim"] - tim_usage["init"],
                    tim_usage["backward"] - tim_usage["init"],
                    tim_usage["optim"] - tim_usage["backward"],
                    input_length.float().mean() / args.max_length / (args.batch_size if args.flash == "cuda" else 1),
                    (targets >= 0).sum(-1).float().mean()
                    / args.max_length
                    / (args.batch_size if args.flash == "cuda" else 1),
                    grad_norm,
                    max(mem_usage["forward"][1], mem_usage["backward"][1]),
                )
            )

            bmt.print_rank(
                "| "
                + " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
                + " |"
            )

            if iteration % args.inspect_iters == 0:
                model_inspect = bmt.inspect.inspect_model(model, "*")
                bmt.print_rank(bmt.inspect.format_summary(model_inspect))
                train_info["model_inspect"] = model_inspect

            if args.log_dir is not None and bmt.rank() == 0:
                log_mgr.write(**train_info)
            if args.tensorboard is not None and bmt.rank() == 0:
                writer.add_scalar("Loss/train", global_loss, iteration)
                writer.add_scalar("Optimizer/lr", lr_scheduler.current_lr, iteration)
                writer.add_scalar("Optimizer/scale", optim_manager.loss_scale, iteration)
                writer.add_scalar("Optimizer/grad_norm", grad_norm.item(), iteration)
                for task_name, loss in task_loss_map.items():
                    writer.add_scalar("Loss/train/{}".format(task_name), loss, iteration)

            # -------- save file. If need to backup by Klara platform, use export.xx_save --------
            if args.save is not None and iteration % args.save_iters == 0:
                exporter.export(model, dataloader, optimizer, iteration, args, final_save=False)

            if iteration >= args.train_iters:
                break

    except Exception as e:
        print(f"train loop err: {e}")
        raise e
    finally:
        # Always stop the dataset workers and write a final checkpoint.
        dataloader.close()
        exporter.export(model, dataloader, optimizer, -1, args, final_save=False)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Entry point: parse arguments, build all components, run pretraining."""
    args = initialize()
    tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
    bmt.print_rank("finish loading")
    pretrain(args, tokenizer, model, optimizer, lr_scheduler)


if __name__ == "__main__":
    main()
|
|
@ -0,0 +1,60 @@
|
||||||
|
#! /bin/bash
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8
#SBATCH --cpus-per-task=8
#SBATCH --mem=512GB

# use 8 GPU for example, pretrain may need 32 GPU
export MASTER_ADDR=`hostname`
export MASTER_PORT=12345

mkdir -p /home/${USERNAME}/logs/debug
mkdir -p /home/${USERNAME}/logs/tensorboard/cpm9g/

cd apps/cpm9g
CONFIG_NAME="config/11b"
# --------------- runtime arguments ---------------
OPTS=""
OPTS+=" --model-config ${CONFIG_NAME}/config.json"
OPTS+=" --vocab ${CONFIG_NAME}/vocab.txt"
OPTS+=" --batch-size 4"
OPTS+=" --train-iters 400000"
OPTS+=" --save-iters 250"
OPTS+=" --save-name cpm9g_checkpoint"
OPTS+=" --max-length 4096"
OPTS+=" --lr 1.5e-5"
OPTS+=" --inspect-iters 100"
OPTS+=" --warmup-iters 2000"
OPTS+=" --lr-decay-style noam"
OPTS+=" --weight-decay 0.1"
OPTS+=" --clip-grad 1.0"
OPTS+=" --loss-scale 1048576"
OPTS+=" --loss-scale-steps 32"
OPTS+=" --offload"
OPTS+=" --flash cuda"
# OPTS+=" --load-grad"

# --------------- output paths ---------------
## checkpoint
OPTS+=" --save /home/${USERNAME}/checkpoints/cpm9g/"
OPTS+=" --save-model /home/${USERNAME}/models/cpm9g/"

## logs; /local/logs is equivalent to /data/logs (symlink)
OPTS+=" --log-dir /home/${USERNAME}/logs/train/"
OPTS+=" --tensorboard /home/${USERNAME}/tensorboard/cpm9g/"`date +"%Y%m%d%H%M%S"`

# --------------- input paths ---------------
OPTS+=" --dataset config/datasets.json"
OPTS+=" --load ${CHECKPOINT}"
OPTS+=" --start-step 1"

# --------------- pass-through arguments ---------------
OPTS+=" $@"

# --------------- final command ---------------

CMD="torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} pretrain_cpm9g.py ${OPTS}"
echo "${CMD}"

$CMD
|
|
@ -0,0 +1,59 @@
|
||||||
|
#! /bin/bash
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8
#SBATCH --cpus-per-task=8

# use 8 GPU for example, pretrain may need 32 GPU
export MASTER_ADDR=`hostname`
export MASTER_PORT=12345

mkdir -p /home/${USERNAME}/logs/debug
mkdir -p /home/${USERNAME}/logs/tensorboard/cpm9g/

cd apps/cpm9g
CONFIG_NAME="config/7b"
# --------------- runtime arguments ---------------
OPTS=""
OPTS+=" --model-config ${CONFIG_NAME}/config.json"
OPTS+=" --vocab ${CONFIG_NAME}/vocab.txt"
OPTS+=" --batch-size 4"
OPTS+=" --train-iters 400000"
OPTS+=" --save-iters 250"
OPTS+=" --save-name cpm9g_checkpoint"
OPTS+=" --max-length 4096"
OPTS+=" --lr 1.5e-5"
OPTS+=" --inspect-iters 100"
OPTS+=" --warmup-iters 2000"
OPTS+=" --lr-decay-style noam"
OPTS+=" --weight-decay 0.1"
OPTS+=" --clip-grad 1.0"
OPTS+=" --loss-scale 1048576"
OPTS+=" --loss-scale-steps 32"
OPTS+=" --offload"
OPTS+=" --flash cuda"
# OPTS+=" --load-grad"

# --------------- output paths ---------------
## checkpoint
OPTS+=" --save /home/${USERNAME}/checkpoints/cpm9g/"
OPTS+=" --save-model /home/${USERNAME}/models/cpm9g/"

## logs; /local/logs is equivalent to /data/logs (symlink)
OPTS+=" --log-dir /home/${USERNAME}/logs/train/"
OPTS+=" --tensorboard /home/${USERNAME}/tensorboard/cpm9g/"`date +"%Y%m%d%H%M%S"`

# --------------- input paths ---------------
OPTS+=" --dataset config/datasets.json"
OPTS+=" --load ${CHECKPOINT}"
OPTS+=" --start-step 1"

# --------------- pass-through arguments ---------------
OPTS+=" $@"

# --------------- final command ---------------

CMD="torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} pretrain_cpm9g.py ${OPTS}"
echo "${CMD}"

$CMD
|
|
@ -0,0 +1,484 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 The OpenBMB team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
from typing import Dict
|
||||||
|
from typing import List
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import bmtrain as bmt
|
||||||
|
import torch
|
||||||
|
|
||||||
|
sys.path.insert(0, "/home/wangshuo1/code/9G-Train")
|
||||||
|
from cpm.arguments import get_args
|
||||||
|
from cpm.cpm9g.models import CPM9G
|
||||||
|
from cpm.cpm9g.models import CPM9GConfig
|
||||||
|
from cpm.cpm9g.tokenizers import CPM9GTokenizer
|
||||||
|
from cpm.cpm9g.training_tasks import FinetuneDataset
|
||||||
|
from cpm.utils import allgather_objects
|
||||||
|
from cpm.utils import logger
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
def get_tokenizer(args):
    """Load the CPM9G tokenizer from the vocab path supplied on the command line."""
    return CPM9GTokenizer(path=args.vocab)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(args):
    """Build the CPM9G model for finetuning.

    When a checkpoint is given, parameters are first initialized and the
    checkpoint is then loaded non-strictly on top (so a partially matching
    checkpoint still works); a parameter summary is printed for inspection.
    """
    config = CPM9GConfig.from_json_file(args.model_config)
    if args.flash == "none":
        config.use_flash_attn = False
    else:
        config.use_flash_attn = True
        # "1d" selects the compact mask; anything else falls through to "2d",
        # under which a triton or cuda flash-attention implementation is chosen.
        if args.flash == "1d":
            config.flash_attn_mask_shape = "1d"
        else:
            config.flash_attn_mask_shape = "2d"
            if args.flash == "triton":
                config.flash_impl = "triton"
            elif args.flash == "cuda":
                config.flash_impl = "cuda"
    model = CPM9G(config)
    if args.load is not None:
        bmt.init_parameters(model)
        bmt.synchronize()
        bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
        bmt.load(model, args.load, strict=False)

        model_inspect = bmt.inspect.inspect_model(model, "*")
        bmt.print_rank(bmt.inspect.format_summary(model_inspect))
    else:
        bmt.print_rank("args.load is None, start to initialize parameters")
        bmt.init_parameters(model)
    return model
|
||||||
|
|
||||||
|
|
||||||
|
def get_optimizer(args, model):
    """Build an Adam optimizer (CPU-offloaded when --offload is set) and,
    when resuming with --load-grad, restore this rank's optimizer state.

    Optimizer state files are rotated modulo ``save_iters * 5``, hence the
    modulo on ``start_step`` when locating them. State is only restored when
    every rank's ``.opt`` shard is present, so a partial save is ignored.
    """
    if args.offload:
        optimizer = bmt.optim.AdamOffloadOptimizer(
            model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay
        )
    else:
        optimizer = bmt.optim.AdamOptimizer(model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay)
    if args.load is not None and args.load_grad:
        start = time.time()
        # Count this step's per-rank shards once (the original evaluated the
        # same directory-scanning expression twice).
        tag = "-{}.rank".format(args.start_step % (args.save_iters * 5))
        num_opt_files = sum(1 for i in os.listdir(args.save) if ".opt" in i and tag in i)
        print(num_opt_files)
        if num_opt_files == bmt.world_size():
            file_name = os.path.join(
                args.save,
                args.save_name + "-{}.rank-{}.opt".format(args.start_step % (args.save_iters * 5), bmt.rank()),
            )
            print(file_name)
            if os.path.exists(file_name):
                print("start to load grad ckpt {}".format(file_name))
                states = torch.load(file_name)
                optimizer.load_state_dict(states)
                logger.info("load grad in {:.2f}s".format(time.time() - start))
    return optimizer
|
||||||
|
|
||||||
|
|
||||||
|
class Cosine(bmt.lr_scheduler.WarmupLRScheduler):
    r"""Warmup followed by a floored cosine decay.

    During warmup the learning rate rises linearly from 0 to ``start_lr``.
    Afterwards it follows
    :math:`\text{lr}=\text{start_lr}\times\left(0.1 + 0.45\left(1+\cos\left(\pi\cdot\dfrac{\text{num_iter}-\text{warmup_iter}}{\text{end_iter}-\text{warmup_iter}}\right)\right)\right)`,
    i.e. it decays from ``start_lr`` down to ``0.1 * start_lr`` and is clamped
    so it never drops below that floor even past ``end_iter``.
    """

    def get_lr_warmup(self, num_iter) -> float:
        # Linear ramp: 0 -> start_lr over warmup_iter steps.
        return self.start_lr * num_iter / self.warmup_iter

    def get_lr_decay(self, num_iter) -> float:
        # Fraction of the decay phase completed; max(1, ...) guards against
        # division by zero when end_iter == warmup_iter.
        progress = (num_iter - self.warmup_iter) / max(1, (self.end_iter - self.warmup_iter))
        return max(self.start_lr * 0.1, self.start_lr * (0.1 + 0.45 * (1.0 + math.cos(progress * math.pi))))
|
||||||
|
|
||||||
|
|
||||||
|
def get_learning_rate_scheduler(args, optimizer):
    """Create the warmup + cosine LR scheduler attached to ``optimizer``."""
    # Default the decay horizon to the full training run when unspecified.
    if args.lr_decay_iters is None:
        args.lr_decay_iters = args.train_iters
    # lr_scheduler = bmt.lr_scheduler.Noam(
    scheduler = Cosine(
        optimizer,
        start_lr=args.lr,
        warmup_iter=args.warmup_iters,
        end_iter=args.lr_decay_iters,
        num_iter=args.start_step,
    )
    return scheduler
|
||||||
|
|
||||||
|
|
||||||
|
def setup_model_and_optimizer(args):
    """Construct the tokenizer, model, optimizer and LR scheduler, logging
    how long each stage takes to load."""
    t0 = time.time()
    model = get_model(args)
    logger.info("load model in {:.2f}s".format(time.time() - t0))

    t0 = time.time()
    tokenizer = get_tokenizer(args)
    bmt.synchronize()
    logger.info("load tokenizer in {:.2f}s".format(time.time() - t0))

    t0 = time.time()
    optimizer = get_optimizer(args, model)
    lr_scheduler = get_learning_rate_scheduler(args, optimizer)
    bmt.synchronize()
    logger.info("load lr_scheduler in {:.2f}s".format(time.time() - t0))

    return tokenizer, model, optimizer, lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
def initialize():
    """Parse fine-tuning arguments and set up the bmtrain distributed context."""
    args = get_args(finetune=True)

    # hack: older bmtrain versions do not accept the ``checkpointing`` kwarg,
    # so probe the signature before passing it.
    init_params = inspect.signature(bmt.init_distributed).parameters
    if "checkpointing" in init_params:
        bmt.init_distributed(checkpointing=False, seed=args.seed)
    else:
        bmt.init_distributed(seed=args.seed)

    if args.save is not None:
        os.makedirs(args.save, exist_ok=True)
    # if args.load is not None:
    #     if args.start_step == 0:
    #         args.start_step = (int)(re.search("(\d+).pt", args.load)[1])
    return args
|
||||||
|
|
||||||
|
|
||||||
|
def see_memory(detail=False):
    """Report current CUDA memory usage and reset the peak-allocation counter.

    Returns the full ``torch.cuda.memory_summary`` string when ``detail`` is
    true, otherwise an ``(allocated_GiB, peak_allocated_GiB)`` tuple rounded
    to two decimals.
    """
    gib = 1024 * 1024 * 1024
    if detail:
        report = torch.cuda.memory_summary()
    else:
        report = (
            round(torch.cuda.memory_allocated() / gib, 2),
            round(torch.cuda.max_memory_allocated() / gib, 2),
        )
    # Reset so the next call reports the peak of the upcoming interval only.
    torch.cuda.reset_peak_memory_stats()
    return report
|
||||||
|
|
||||||
|
|
||||||
|
def add_mem_time(info, mem_usage, tim_usage):
    """Record GPU memory usage and a wall-clock timestamp under key ``info``.

    Mutates and returns both dicts so call sites can chain the result.
    """
    # Ensure pending kernels are finished so timing/memory are accurate.
    torch.cuda.synchronize()
    mem_usage[info] = see_memory()
    tim_usage[info] = time.time()
    return mem_usage, tim_usage
|
||||||
|
|
||||||
|
|
||||||
|
class LossSpikeDetector:
    """Track per-task losses across steps and flag sudden spikes (>3x jump)."""

    def __init__(self, log_path: str) -> None:
        self._last_loss: Dict[str, float] = {}  # task name -> most recent loss
        self._last_data: List[Any] = [None]     # sliding window of last two batches
        self._log_path = log_path

    def update_data(self, data: Any):
        """Remember the latest batch, keeping only the two most recent."""
        self._last_data.append(data)
        if len(self._last_data) > 2:
            del self._last_data[: len(self._last_data) - 2]

    def update_loss(self, iteration: int, loss_map: Dict[str, float]):
        """Compare each task's loss with its previous value and log any spike."""
        spikes = []
        for task, loss in loss_map.items():
            previous = self._last_loss.get(task)
            if previous is not None and loss > previous * 3:
                # loss spike!
                spikes.append(
                    {
                        "prev": previous,
                        "curr": loss,
                        "task": task,
                    }
                )
            self._last_loss[task] = float(loss)
        if spikes:
            self._write_log(iteration, self._last_data[-1], spikes)

    def _write_log(self, iteration: int, data: Any, result: List[Dict[str, Any]]):
        # Logging is currently disabled: the early return skips the body below.
        return
        with open(self._log_path, "a", encoding="utf-8") as fp:
            fp.write("=" * 20)
            fp.write("\nloss spike at {}\n".format(iteration))
            fp.write("{}\n".format(json.dumps(result, indent=4, ensure_ascii=False)))
            fp.write("data: \n")
            for d in data:
                fp.write("{}\n".format(json.dumps(d, indent=4, ensure_ascii=False)))
            fp.write("\n\n")
|
||||||
|
|
||||||
|
|
||||||
|
def finetune(
    args,
    bin_file: str,
    tokenizer: CPM9GTokenizer,
    model: CPM9G,
    optimizer,
    lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
):
    """Run the supervised fine-tuning loop.

    For each epoch, iterates a fresh ``FinetuneDataset`` over ``bin_file``,
    runs forward/backward with loss-scaled mixed precision, steps the
    optimizer every ``args.gradient_accumulation_steps`` iterations, logs
    per-task losses (all-gathered across ranks), and saves a checkpoint at
    the end of every epoch.
    """
    average_time = bmt.utils.AverageRecorder()
    # Targets padded with -100 are excluded from the loss.
    loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100)
    optim_manager = bmt.optim.OptimManager(loss_scale=args.loss_scale, loss_scale_steps=args.loss_scale_steps, loss_scale_factor=2, max_loss_scale=args.max_loss_scale, min_loss_scale=args.min_loss_scale,)
    optim_manager.add_optimizer(optimizer, lr_scheduler)

    # TensorBoard writer only exists on rank 0 (guarded below by the same
    # condition before every use).
    if args.tensorboard is not None and bmt.rank() == 0:
        import distutils.version  # noqa: F401
        from tensorboardX import SummaryWriter
        if not os.path.exists(args.tensorboard):
            os.makedirs(args.tensorboard)
        writer = SummaryWriter(log_dir=args.tensorboard)

    global_token_pass = 0.0
    global_world_size = bmt.world_size()

    for epoch in range(args.epoch):
        epoch = epoch + 1  # switch to 1-based epoch numbering for logs/paths
        last_data = None
        dataloader = FinetuneDataset(
            bin_file, args.batch_size, args.max_length, tokenizer, unpad=(args.flash == "cuda"), task_name="task", drop_last=True
        )
        optim_manager.zero_grad()
        for iteration, data in enumerate(dataloader):
            iteration = iteration + 1  # 1-based iteration numbering
            skip_this_batch = False
            if data is None:
                # The dataset yields None when it cannot fill a batch; replay
                # the previous batch but zero its loss (see skip_this_batch).
                if last_data is None:
                    raise RuntimeError(
                        "Dataset is too small, please use a smaller batch size or sequence length!"
                    )
                data = last_data  # use last data
                skip_this_batch = True
            else:
                last_data = data

            assert data["inputs"].shape[0] == args.batch_size
            input_ids = torch.from_numpy(data["inputs"]).cuda().to(torch.int32)
            input_length = torch.from_numpy(data["length"]).cuda().to(torch.int32)
            targets = torch.from_numpy(data["target"]).cuda().long()
            # bmt.print_rank(input_ids[0].tolist())
            # bmt.print_rank(targets[0].tolist())
            # bmt.print_rank(data["spans"].tolist())
            # bmt.print_rank(tokenizer.decode(input_ids[0]))
            # bmt.print_rank(tokenizer.path)
            # bmt.synchronize()
            # exit()
            task_ids = torch.from_numpy(data["task_ids"]).cuda().to(torch.int32)
            task_names = data["task_names"]
            if args.flash == "cuda":
                # Unpadded (varlen) flash-attention path: sequences are packed
                # and described by cumulative lengths.
                cu_seqlens = torch.from_numpy(data["cu_seqlens"]).cuda().to(torch.int32)
                max_seqlen = data["max_seqlen"]
                position_ids = torch.from_numpy(data["position_ids"]).cuda().to(torch.int32)
            else:
                input_context = torch.zeros_like(input_ids).cuda().bool()
                input_span = torch.from_numpy(data["spans"]).cuda().to(torch.int32)

            # ===========
            # optim_manager.zero_grad()
            # torch.cuda.empty_cache()
            mem_usage = {}
            tim_usage = {}
            mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)

            # =========== forward
            if args.flash == "cuda":
                logits, _ = model(
                    input_ids,
                    cu_seqlens=cu_seqlens,
                    max_seqlen=max_seqlen,
                    position_ids=position_ids,
                )
            else:
                logits, _ = model(
                    input_ids,
                    input_length,
                    input_context,
                    input_span,
                )
            mem_usage, tim_usage = add_mem_time("forward_1", mem_usage, tim_usage)
            loss = loss_func(logits.view(-1, logits.size(-1)), targets.view(-1))
            if skip_this_batch:
                # Replayed batch: keep graph/shape but contribute no gradient.
                loss = loss * 0
            mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)

            # =========== backward
            optim_manager.backward(loss)
            mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)

            # =========== optimizer step (only on accumulation boundaries)
            if iteration % args.gradient_accumulation_steps == 0:
                grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
                optim_manager.step()
                mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
                optim_manager.zero_grad()
            else:
                grad_norm = None  # no step this iteration
                mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)

            # ==========
            iter_time = tim_usage["optim"] - tim_usage["init"]
            average_time.record(iter_time)

            with torch.no_grad():
                # Recompute the loss per task by masking targets so that only
                # tokens belonging to each task id contribute.
                task_num = len(task_names)
                targets_tmp = targets.expand(task_num, -1, -1)
                task = torch.arange(task_num, dtype=torch.long, device="cuda")[:, None, None]
                targets_tmp = torch.where(
                    task_ids == task,
                    targets_tmp,
                    torch.scalar_tensor(-100, dtype=torch.long, device="cuda"),
                )

                task_loss_map: Dict[str, float] = {}
                if not skip_this_batch:
                    for i in range(task_num):
                        task_loss = loss_func(logits.view(-1, logits.size(-1)), targets_tmp[i, :].view(-1))
                        task_loss_map[task_names[i]] = task_loss.item()
                # Gather every rank's task->loss map and average per task.
                gatherd_task_loss_map: List[Dict[str, float]] = allgather_objects(task_loss_map)

                global_task_loss_map: Dict[str, Union[List[float], float]] = {}
                for local_task_loss_map in gatherd_task_loss_map:
                    for task_name, task_loss in local_task_loss_map.items():
                        if task_name not in global_task_loss_map:
                            global_task_loss_map[task_name] = []
                        global_task_loss_map[task_name].append(task_loss)

                task_loss_map = {}
                for task_name in sorted(list(global_task_loss_map.keys())):
                    avg_loss = sum(global_task_loss_map[task_name]) / len(global_task_loss_map[task_name])
                    task_loss_map[task_name] = avg_loss

            # Fraction of the max sequence length actually used, averaged
            # across ranks; feeds token-throughput accounting.
            local_total_rate = torch.Tensor([input_length.float().mean() / args.max_length]).cuda()
            local_total_rate = bmt.sum_loss(local_total_rate).item()
            global_token_pass += global_world_size * local_total_rate * args.max_length * args.batch_size
            avg_time = average_time.value

            train_info = {
                "time": tim_usage["init"],
                "epoch": epoch,
                "iteration": iteration,
                "loss": task_loss_map[args.task_name],
                "lr": lr_scheduler.current_lr,
                "lr_scale": int(optim_manager.loss_scale),
                "time_usage": tim_usage,
                "mem_usage": mem_usage,
                "avg_time": avg_time,
                "token_max": local_total_rate,
                "token_pass": global_token_pass,
                "throughout": args.max_length * args.batch_size * local_total_rate / avg_time,
                "grad_norm": 0 if grad_norm == None else grad_norm.item(),
                "mask_max": ((targets >= 0).sum(-1).float().mean() / args.max_length).item(),
                "num_gpus": global_world_size,
                "task_loss": task_loss_map,
            }
            # bmt.print_rank(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
            bmt.print_rank(
                (
                    "| Epoch: {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f} | lr: {:.6e} | scale: {:10.0f} | time: {:.1f} |"
                    + " token/max: {:.3f} | mask/max: {:.3f} | grad_norm: {:.3f}"
                ).format(
                    epoch,
                    iteration,
                    args.train_iters,
                    task_loss_map[args.task_name],
                    lr_scheduler.current_lr,
                    int(optim_manager.loss_scale),
                    avg_time,
                    input_length.float().mean() / args.max_length,
                    (targets >= 0).sum(-1).float().mean() / args.max_length,
                    0 if grad_norm == None else grad_norm,
                )
            )

            bmt.print_rank(
                "| "
                + " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
                + " |"
            )
            if iteration % args.inspect_iters == 0:
                # Periodic parameter-statistics dump for debugging.
                model_inspect = bmt.inspect.inspect_model(model, "*")
                bmt.print_rank(bmt.inspect.format_summary(model_inspect))
                train_info["model_inspect"] = model_inspect

            # save_folder_name = f"{args.save}{epoch}{iteration}"
            # model_fname = os.path.join(save_folder_name, f"{args.save_name}-iter-{iteration}.pt")
            # os.makedirs(os.path.dirname(model_fname), exist_ok=True)
            # if bmt.rank() == 0:
            #     shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
            #     shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))

            if args.tensorboard is not None and bmt.rank() == 0:
                writer.add_scalar(f"Loss/train/{epoch}", task_loss_map[args.task_name], iteration)
                writer.add_scalar(f"Optimizer/lr/{epoch}", lr_scheduler.current_lr, iteration)
                writer.add_scalar(f"Optimizer/scale/{epoch}", optim_manager.loss_scale, iteration)
                writer.add_scalar(f"Optimizer/grad_norm/{epoch}", 0 if grad_norm == None else grad_norm.item(), iteration)

            # if iteration % 10 == 0:
            #     save_folder_name = f"{args.save}{epoch}"
            #     model_fname = os.path.join(save_folder_name, f"{args.save_name}-epoch-{epoch}.pt")
            #     os.makedirs(os.path.dirname(model_fname), exist_ok=True)
            #     bmt.save(model, model_fname)
            #     if bmt.rank() == 0:
            #         shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
            #         shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))

        # End of epoch: checkpoint the model plus its config and vocab.
        # NOTE(review): original indentation was lost in this paste; placement
        # at epoch level matches the per-epoch file naming — confirm upstream.
        save_folder_name = f"{args.save}-epoch-{epoch}"
        model_fname = os.path.join(save_folder_name, f"{args.save_name}-epoch-{epoch}.pt")
        os.makedirs(os.path.dirname(model_fname), exist_ok=True)
        bmt.save(model, model_fname)
        if bmt.rank() == 0:
            shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
            shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Entry point: size the schedule from the dataset, build everything,
    then fine-tune."""
    args = initialize()
    # To Be Specified
    bin_file = args.dataset
    bmt.print_rank(f"dataset: {bin_file}")

    tokenizer = get_tokenizer(args)
    dataloader = FinetuneDataset(
        bin_file, args.batch_size, args.max_length, tokenizer, unpad=(args.flash == "cuda"), task_name="task", drop_last=True
    )
    bmt.print_rank(f"#batch: {len(dataloader)}")
    total_steps = len(dataloader) * args.epoch

    # Derive the iteration budget and a 2%-of-total warmup floor from the
    # actual dataset size before building the scheduler.
    args.train_iters = int(total_steps)
    args.warmup_iters = max(int(total_steps * 0.02), args.warmup_iters)
    bmt.print_rank(json.dumps(vars(args), indent=2, sort_keys=True))
    tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)

    bmt.print_rank("finish loading")
    finetune(args, bin_file, tokenizer, model, optimizer, lr_scheduler)
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point when launched via torchrun/python.
if __name__ == "__main__":
    main()
|
|
@ -0,0 +1,55 @@
|
||||||
|
#! /bin/bash
# Slurm launch script for CPM9G-11B supervised fine-tuning.
#SBATCH --partition=gpu3-1
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8
#SBATCH --cpus-per-task=8

# Rendezvous endpoint for torchrun's c10d backend.
export MASTER_ADDR=g3002
export MASTER_PORT=12345

CPM_PATH="/data/groups/QY_LLM_Core/projects/202311-release/Models/11B-Chat/9G-Train"
CONFIG_NAME="${CPM_PATH}/apps/cpm9g/config/11b"
EXP_PATH=./exp
mkdir -p $EXP_PATH
MODEL_NAME="cpm9g-11b-sft"

OPTS=""
OPTS+=" --model-config ${CONFIG_NAME}/config.json"
OPTS+=" --vocab ${CONFIG_NAME}/vocab.txt"

OPTS+=" --train-iters 10000"
OPTS+=" --inspect-iters 200"
OPTS+=" --warmup-iters 500"

OPTS+=" --lr-decay-style cosine"
OPTS+=" --weight-decay 0.1"
OPTS+=" --clip-grad 1.0"
OPTS+=" --loss-scale 1048576"
OPTS+=" --max-loss-scale 33554432"
OPTS+=" --min-loss-scale 1"
OPTS+=" --loss-scale-steps 32"

OPTS+=" --offload"
OPTS+=" --batch-size 4"
OPTS+=" --max-length 4096"
OPTS+=" --lr 2e-5"
OPTS+=" --start-step 0"
OPTS+=" --epoch 8"
OPTS+=" --load /data/groups/QY_LLM_Core/models/20231010/11b-base/11b.pt"
OPTS+=" --dataset /data/groups/QY_LLM_Core/datasets/sft/20231025/merge_qy_sft_bin"
# TODO: on the QiYuan machines these /data paths must be changed to paths under /home
OPTS+=" --save ${EXP_PATH}/checkpoints"
OPTS+=" --save-name ${MODEL_NAME}"
# OPTS+=" --tensorboard /data/logs/tensorboard/${MODEL_NAME}/${CUR_DATE}/"
# OPTS+=" --flash triton"
# OPTS+=" --flash cuda"
# OPTS+=" --load-grad"

# Forward any extra command-line options to the trainer.
OPTS+=" $@"

# NOTE(review): SBATCH requests a single node (--nodes=1) but torchrun was
# previously launched with --nnodes=4, which would hang at rendezvous waiting
# for nodes that are never allocated; aligned to 1 here.
CMD="torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} ${CPM_PATH}/apps/cpm9g/sft_cpm9g.py ${OPTS}"

echo "${CMD}"
$CMD
|
|
@ -0,0 +1,515 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 The OpenBMB team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
from typing import Dict
|
||||||
|
from typing import List
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import bmtrain as bmt
|
||||||
|
import torch
|
||||||
|
|
||||||
|
sys.path.insert(0, "/home/wangshuo1/code/9G-Train")
|
||||||
|
from cpm.arguments import get_args
|
||||||
|
from cpm.cpm9g.models import CPM9G
|
||||||
|
from cpm.cpm9g.models import CPM9GConfig
|
||||||
|
from cpm.cpm9g.tokenizers import CPM9GTokenizer
|
||||||
|
from cpm.cpm9g.training_tasks import FinetuneDataset
|
||||||
|
from cpm.utils import allgather_objects
|
||||||
|
from cpm.utils import logger
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
import opendelta as od
|
||||||
|
from opendelta import LoraModel, AdapterModel, CompacterModel, LowRankAdapterModel, BitFitModel, ParallelAdapterModel
|
||||||
|
from opendelta.utils.inspect import inspect_optimizer_statistics
|
||||||
|
from bigmodelvis import Visualization
|
||||||
|
|
||||||
|
|
||||||
|
def get_tokenizer(args):
    """Load the CPM9G tokenizer from the vocabulary file given in ``args.vocab``."""
    return CPM9GTokenizer(path=args.vocab)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(args):
    """Build the CPM9G model, optionally load a checkpoint, and optionally
    wrap it with a parameter-efficient delta (only LoRA is handled here).

    Flash-attention behaviour is derived from ``args.flash``: "none" disables
    it; "1d"/"2d" select the mask shape; "triton"/"cuda" select the kernel
    implementation.
    """
    config = CPM9GConfig.from_json_file(args.model_config)
    if args.flash == "none":
        config.use_flash_attn = False
    else:
        config.use_flash_attn = True
        if args.flash == "1d":
            config.flash_attn_mask_shape = "1d"
        else:
            config.flash_attn_mask_shape = "2d"
            if args.flash == "triton":
                config.flash_impl = "triton"
            elif args.flash == "cuda":
                config.flash_impl = "cuda"
    model = CPM9G(config)
    if args.load is not None:
        bmt.init_parameters(model)
        bmt.synchronize()
        bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
        # strict=False: checkpoint may lack delta parameters added below.
        bmt.load(model, args.load, strict=False)

        model_inspect = bmt.inspect.inspect_model(model, "*")
        bmt.print_rank(bmt.inspect.format_summary(model_inspect))
    else:
        bmt.print_rank("args.load is None, start to initialize parameters")
        bmt.init_parameters(model)

    # `is not None` replaces the previous `!= None` identity-vs-equality mixup.
    if args.delta_type is not None:
        # Visualization is already imported at module level; the redundant
        # local re-import was removed.
        if bmt.rank() == 0:
            Visualization(model).structure_graph()
            print("finetuned layers: ")
            print(args.lora_layer)

        if args.delta_type == "lora":
            delta_model = LoraModel(backbone_model=model, modified_modules=args.lora_layer, backend='bmt', lora_r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout)

            if bmt.rank() == 0:
                print("Before freeze: ")
                delta_model.log()

            # Freeze everything except the delta parameters so only LoRA
            # weights receive gradients.
            delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)

            if bmt.rank() == 0:
                print("After freeze: ")
                delta_model.log()

    return model
|
||||||
|
|
||||||
|
|
||||||
|
def get_optimizer(args, model):
    """Build the Adam optimizer for ``model`` and optionally restore its state.

    Uses the CPU-offloaded Adam implementation when ``args.offload`` is set,
    otherwise the regular on-device one.  When both ``args.load`` and
    ``args.load_grad`` are given, looks in ``args.save`` for per-rank ``.opt``
    optimizer-state files matching the current start step and restores this
    rank's file, but only if every rank has one.
    """
    if args.offload:
        optimizer = bmt.optim.AdamOffloadOptimizer(
            model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay
        )
    else:
        optimizer = bmt.optim.AdamOptimizer(model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay)
    if args.load is not None and args.load_grad:
        start = time.time()
        # Checkpoints appear to be kept in a rotating window of
        # ``save_iters * 5`` steps, hence the modulo when reconstructing the
        # file-name tag.  TODO confirm against the save path.
        tag = args.start_step % (args.save_iters * 5)
        # Count per-rank optimizer state files for this tag.  The original
        # computed this identical sum twice (once for the print, once for the
        # condition); compute it once and reuse it.
        n_opt_files = sum(
            1
            for i in os.listdir(args.save)
            if i.find(".opt") != -1 and i.find("-{}.rank".format(tag)) != -1
        )
        print(n_opt_files)
        # Only restore when every rank has its own state file; a partial set
        # would desynchronize optimizer state across ranks.
        if n_opt_files == bmt.world_size():
            file_name = os.path.join(
                args.save,
                args.save_name + "-{}.rank-{}.opt".format(tag, bmt.rank()),
            )
            print(file_name)
            if os.path.exists(file_name):
                print("start to load grad ckpt {}".format(file_name))
                states = torch.load(file_name)
                optimizer.load_state_dict(states)
                logger.info("load grad in {:.2f}s".format(time.time() - start))
    return optimizer
|
||||||
|
|
||||||
|
|
||||||
|
class Cosine(bmt.lr_scheduler.WarmupLRScheduler):
    r"""Warmup + cosine-decay learning-rate schedule with a 10% floor.

    During warmup the LR grows linearly from 0 to ``start_lr``.  Afterwards it
    follows a cosine curve scaled into ``[0.1 * start_lr, start_lr]``:

    .. math::
        \text{lr} = \text{start\_lr} \times \max\left(0.1,\;
        0.1 + 0.45 \cdot \left(1 + \cos\left(\pi \cdot
        \dfrac{\text{num\_iter}-\text{warmup\_iter}}
              {\text{end\_iter}-\text{warmup\_iter}}\right)\right)\right)

    NOTE(review): the previous docstring described a plain
    ``(1 + cos) / 2`` decay to zero, which does not match the implementation;
    the formula above documents what the code actually computes.
    """

    def get_lr_warmup(self, num_iter) -> float:
        # Linear ramp from 0 up to start_lr over the warmup period.
        return self.start_lr * num_iter / self.warmup_iter

    def get_lr_decay(self, num_iter) -> float:
        # Fraction of the decay period completed; max(1, ...) guards against
        # division by zero when end_iter == warmup_iter.
        progress = (num_iter - self.warmup_iter) / max(1, (self.end_iter - self.warmup_iter))
        return max(self.start_lr * 0.1, self.start_lr * (0.1 + 0.45 * (1.0 + math.cos(progress * math.pi))))
|
||||||
|
|
||||||
|
|
||||||
|
def get_learning_rate_scheduler(args, optimizer):
    """Create the warmup + cosine LR scheduler attached to ``optimizer``."""
    # Default the decay horizon to the full training run when unspecified.
    if args.lr_decay_iters is None:
        args.lr_decay_iters = args.train_iters
    # lr_scheduler = bmt.lr_scheduler.Noam(
    scheduler = Cosine(
        optimizer,
        start_lr=args.lr,
        warmup_iter=args.warmup_iters,
        end_iter=args.lr_decay_iters,
        num_iter=args.start_step,
    )
    return scheduler
|
||||||
|
|
||||||
|
|
||||||
|
def setup_model_and_optimizer(args):
    """Construct the tokenizer, model, optimizer and LR scheduler, logging
    how long each stage takes to load."""
    start = time.time()
    model = get_model(args)
    logger.info("load model in {:.2f}s".format(time.time() - start))

    start = time.time()
    tokenizer = get_tokenizer(args)
    bmt.synchronize()
    logger.info("load tokenizer in {:.2f}s".format(time.time() - start))

    start = time.time()
    optimizer = get_optimizer(args, model)

    # `is not None` replaces the previous `!= None` identity-vs-equality mixup.
    if args.delta_type is not None:
        # Report which parameters the optimizer actually trains (delta only).
        inspect_optimizer_statistics(optimizer)

    lr_scheduler = get_learning_rate_scheduler(args, optimizer)
    bmt.synchronize()
    logger.info("load lr_scheduler in {:.2f}s".format(time.time() - start))

    return tokenizer, model, optimizer, lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
def initialize():
    """Parse fine-tuning arguments and set up the bmtrain distributed context."""
    args = get_args(finetune=True)

    # hack: older bmtrain versions do not accept the ``checkpointing`` kwarg,
    # so probe the signature before passing it.
    init_params = inspect.signature(bmt.init_distributed).parameters
    if "checkpointing" in init_params:
        bmt.init_distributed(checkpointing=False, seed=args.seed)
    else:
        bmt.init_distributed(seed=args.seed)

    if args.save is not None:
        os.makedirs(args.save, exist_ok=True)
    # if args.load is not None:
    #     if args.start_step == 0:
    #         args.start_step = (int)(re.search("(\d+).pt", args.load)[1])
    return args
|
||||||
|
|
||||||
|
|
||||||
|
def see_memory(detail=False):
    """Report current CUDA memory usage and reset the peak-allocation counter.

    Returns the full ``torch.cuda.memory_summary`` string when ``detail`` is
    true, otherwise an ``(allocated_GiB, peak_allocated_GiB)`` tuple rounded
    to two decimals.
    """
    gib = 1024 * 1024 * 1024
    if detail:
        report = torch.cuda.memory_summary()
    else:
        report = (
            round(torch.cuda.memory_allocated() / gib, 2),
            round(torch.cuda.max_memory_allocated() / gib, 2),
        )
    # Reset so the next call reports the peak of the upcoming interval only.
    torch.cuda.reset_peak_memory_stats()
    return report
|
||||||
|
|
||||||
|
|
||||||
|
def add_mem_time(info, mem_usage, tim_usage):
    """Record GPU memory usage and a wall-clock timestamp under key ``info``.

    Mutates and returns both dicts so call sites can chain the result.
    """
    # Ensure pending kernels are finished so timing/memory are accurate.
    torch.cuda.synchronize()
    mem_usage[info] = see_memory()
    tim_usage[info] = time.time()
    return mem_usage, tim_usage
|
||||||
|
|
||||||
|
|
||||||
|
class LossSpikeDetector:
    """Track per-task losses across steps and flag sudden spikes (>3x jump)."""

    def __init__(self, log_path: str) -> None:
        self._last_loss: Dict[str, float] = {}  # task name -> most recent loss
        self._last_data: List[Any] = [None]     # sliding window of last two batches
        self._log_path = log_path

    def update_data(self, data: Any):
        """Remember the latest batch, keeping only the two most recent."""
        self._last_data.append(data)
        if len(self._last_data) > 2:
            del self._last_data[: len(self._last_data) - 2]

    def update_loss(self, iteration: int, loss_map: Dict[str, float]):
        """Compare each task's loss with its previous value and log any spike."""
        spikes = []
        for task, loss in loss_map.items():
            previous = self._last_loss.get(task)
            if previous is not None and loss > previous * 3:
                # loss spike!
                spikes.append(
                    {
                        "prev": previous,
                        "curr": loss,
                        "task": task,
                    }
                )
            self._last_loss[task] = float(loss)
        if spikes:
            self._write_log(iteration, self._last_data[-1], spikes)

    def _write_log(self, iteration: int, data: Any, result: List[Dict[str, Any]]):
        # Logging is currently disabled: the early return skips the body below.
        return
        with open(self._log_path, "a", encoding="utf-8") as fp:
            fp.write("=" * 20)
            fp.write("\nloss spike at {}\n".format(iteration))
            fp.write("{}\n".format(json.dumps(result, indent=4, ensure_ascii=False)))
            fp.write("data: \n")
            for d in data:
                fp.write("{}\n".format(json.dumps(d, indent=4, ensure_ascii=False)))
            fp.write("\n\n")
|
||||||
|
|
||||||
|
|
||||||
|
def finetune(
|
||||||
|
args,
|
||||||
|
bin_file: str,
|
||||||
|
tokenizer: CPM9GTokenizer,
|
||||||
|
model: CPM9G,
|
||||||
|
optimizer,
|
||||||
|
lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
|
||||||
|
):
|
||||||
|
average_time = bmt.utils.AverageRecorder()
|
||||||
|
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100)
|
||||||
|
optim_manager = bmt.optim.OptimManager(loss_scale=args.loss_scale, loss_scale_steps=args.loss_scale_steps, loss_scale_factor=2, max_loss_scale=args.max_loss_scale, min_loss_scale=args.min_loss_scale,)
|
||||||
|
optim_manager.add_optimizer(optimizer, lr_scheduler)
|
||||||
|
|
||||||
|
if args.tensorboard is not None and bmt.rank() == 0:
|
||||||
|
import distutils.version # noqa: F401
|
||||||
|
from tensorboardX import SummaryWriter
|
||||||
|
if not os.path.exists(args.tensorboard):
|
||||||
|
os.makedirs(args.tensorboard)
|
||||||
|
writer = SummaryWriter(log_dir=args.tensorboard)
|
||||||
|
|
||||||
|
global_token_pass = 0.0
|
||||||
|
global_world_size = bmt.world_size()
|
||||||
|
|
||||||
|
for epoch in range(args.epoch):
|
||||||
|
epoch = epoch + 1
|
||||||
|
last_data = None
|
||||||
|
dataloader = FinetuneDataset(
|
||||||
|
bin_file, args.batch_size, args.max_length, tokenizer, unpad=(args.flash == "cuda"), task_name="task", drop_last=True
|
||||||
|
)
|
||||||
|
for iteration, data in enumerate(dataloader):
|
||||||
|
iteration = iteration + 1
|
||||||
|
skip_this_batch = False
|
||||||
|
if data is None:
|
||||||
|
if last_data is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Dataset is too small, please use a smaller batch size or sequence length!"
|
||||||
|
)
|
||||||
|
data = last_data # use last data
|
||||||
|
skip_this_batch = True
|
||||||
|
else:
|
||||||
|
last_data = data
|
||||||
|
|
||||||
|
assert data["inputs"].shape[0] == args.batch_size
|
||||||
|
input_ids = torch.from_numpy(data["inputs"]).cuda().to(torch.int32)
|
||||||
|
input_length = torch.from_numpy(data["length"]).cuda().to(torch.int32)
|
||||||
|
targets = torch.from_numpy(data["target"]).cuda().long()
|
||||||
|
# bmt.print_rank(input_ids[0].tolist())
|
||||||
|
# bmt.print_rank(targets[0].tolist())
|
||||||
|
# bmt.print_rank(data["spans"].tolist())
|
||||||
|
# bmt.print_rank(tokenizer.decode(input_ids[0]))
|
||||||
|
# bmt.print_rank(tokenizer.path)
|
||||||
|
# bmt.synchronize()
|
||||||
|
# exit()
|
||||||
|
task_ids = torch.from_numpy(data["task_ids"]).cuda().to(torch.int32)
|
||||||
|
task_names = data["task_names"]
|
||||||
|
if args.flash == "cuda":
|
||||||
|
cu_seqlens = torch.from_numpy(data["cu_seqlens"]).cuda().to(torch.int32)
|
||||||
|
max_seqlen = data["max_seqlen"]
|
||||||
|
position_ids = torch.from_numpy(data["position_ids"]).cuda().to(torch.int32)
|
||||||
|
else:
|
||||||
|
input_context = torch.zeros_like(input_ids).cuda().bool()
|
||||||
|
input_span = torch.from_numpy(data["spans"]).cuda().to(torch.int32)
|
||||||
|
|
||||||
|
# ===========
|
||||||
|
optim_manager.zero_grad()
|
||||||
|
# torch.cuda.empty_cache()
|
||||||
|
mem_usage = {}
|
||||||
|
tim_usage = {}
|
||||||
|
mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)
|
||||||
|
|
||||||
|
# ===========
|
||||||
|
if args.flash == "cuda":
|
||||||
|
logits, _ = model(
|
||||||
|
input_ids,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
position_ids=position_ids,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits, _ = model(
|
||||||
|
input_ids,
|
||||||
|
input_length,
|
||||||
|
input_context,
|
||||||
|
input_span,
|
||||||
|
)
|
||||||
|
mem_usage, tim_usage = add_mem_time("forward_1", mem_usage, tim_usage)
|
||||||
|
loss = loss_func(logits.view(-1, logits.size(-1)), targets.view(-1))
|
||||||
|
if skip_this_batch:
|
||||||
|
loss = loss * 0
|
||||||
|
mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)
|
||||||
|
|
||||||
|
# ===========
|
||||||
|
optim_manager.backward(loss)
|
||||||
|
mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)
|
||||||
|
|
||||||
|
# ===========
|
||||||
|
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
|
||||||
|
optim_manager.step()
|
||||||
|
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
|
||||||
|
|
||||||
|
# ==========
|
||||||
|
iter_time = tim_usage["optim"] - tim_usage["init"]
|
||||||
|
average_time.record(iter_time)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
task_num = len(task_names)
|
||||||
|
targets_tmp = targets.expand(task_num, -1, -1)
|
||||||
|
task = torch.arange(task_num, dtype=torch.long, device="cuda")[:, None, None]
|
||||||
|
targets_tmp = torch.where(
|
||||||
|
task_ids == task,
|
||||||
|
targets_tmp,
|
||||||
|
torch.scalar_tensor(-100, dtype=torch.long, device="cuda"),
|
||||||
|
)
|
||||||
|
|
||||||
|
task_loss_map: Dict[str, float] = {}
|
||||||
|
if not skip_this_batch:
|
||||||
|
for i in range(task_num):
|
||||||
|
task_loss = loss_func(logits.view(-1, logits.size(-1)), targets_tmp[i, :].view(-1))
|
||||||
|
task_loss_map[task_names[i]] = task_loss.item()
|
||||||
|
gatherd_task_loss_map: List[Dict[str, float]] = allgather_objects(task_loss_map)
|
||||||
|
|
||||||
|
global_task_loss_map: Dict[str, Union[List[float], float]] = {}
|
||||||
|
for local_task_loss_map in gatherd_task_loss_map:
|
||||||
|
for task_name, task_loss in local_task_loss_map.items():
|
||||||
|
if task_name not in global_task_loss_map:
|
||||||
|
global_task_loss_map[task_name] = []
|
||||||
|
global_task_loss_map[task_name].append(task_loss)
|
||||||
|
|
||||||
|
task_loss_map = {}
|
||||||
|
for task_name in sorted(list(global_task_loss_map.keys())):
|
||||||
|
avg_loss = sum(global_task_loss_map[task_name]) / len(global_task_loss_map[task_name])
|
||||||
|
task_loss_map[task_name] = avg_loss
|
||||||
|
|
||||||
|
local_total_rate = torch.Tensor([input_length.float().mean() / args.max_length]).cuda()
|
||||||
|
local_total_rate = bmt.sum_loss(local_total_rate).item()
|
||||||
|
global_token_pass += global_world_size * local_total_rate * args.max_length * args.batch_size
|
||||||
|
avg_time = average_time.value
|
||||||
|
|
||||||
|
train_info = {
|
||||||
|
"time": tim_usage["init"],
|
||||||
|
"epoch": epoch,
|
||||||
|
"iteration": iteration,
|
||||||
|
"loss": task_loss_map[args.task_name],
|
||||||
|
"lr": lr_scheduler.current_lr,
|
||||||
|
"lr_scale": int(optim_manager.loss_scale),
|
||||||
|
"time_usage": tim_usage,
|
||||||
|
"mem_usage": mem_usage,
|
||||||
|
"avg_time": avg_time,
|
||||||
|
"token_max": local_total_rate,
|
||||||
|
"token_pass": global_token_pass,
|
||||||
|
"throughout": args.max_length * args.batch_size * local_total_rate / avg_time,
|
||||||
|
"grad_norm": grad_norm.item(),
|
||||||
|
"mask_max": ((targets >= 0).sum(-1).float().mean() / args.max_length).item(),
|
||||||
|
"num_gpus": global_world_size,
|
||||||
|
"task_loss": task_loss_map,
|
||||||
|
}
|
||||||
|
# bmt.print_rank(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
|
||||||
|
bmt.print_rank(
|
||||||
|
(
|
||||||
|
"| Epoch: {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f} | lr: {:.6e} | scale: {:10.0f} | time: {:.1f} |"
|
||||||
|
+ " token/max: {:.3f} | mask/max: {:.3f} | grad_norm: {:.3f}"
|
||||||
|
).format(
|
||||||
|
epoch,
|
||||||
|
iteration,
|
||||||
|
args.train_iters,
|
||||||
|
task_loss_map[args.task_name],
|
||||||
|
lr_scheduler.current_lr,
|
||||||
|
int(optim_manager.loss_scale),
|
||||||
|
avg_time,
|
||||||
|
input_length.float().mean() / args.max_length,
|
||||||
|
(targets >= 0).sum(-1).float().mean() / args.max_length,
|
||||||
|
grad_norm,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
bmt.print_rank(
|
||||||
|
"| "
|
||||||
|
+ " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
|
||||||
|
+ " |"
|
||||||
|
)
|
||||||
|
if iteration % args.inspect_iters == 0:
|
||||||
|
model_inspect = bmt.inspect.inspect_model(model, "*")
|
||||||
|
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
||||||
|
train_info["model_inspect"] = model_inspect
|
||||||
|
|
||||||
|
# save_folder_name = f"{args.save}{epoch}{iteration}"
|
||||||
|
# model_fname = os.path.join(save_folder_name, f"{args.save_name}-iter-{iteration}.pt")
|
||||||
|
# os.makedirs(os.path.dirname(model_fname), exist_ok=True)
|
||||||
|
# if bmt.rank() == 0:
|
||||||
|
# shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
|
||||||
|
# shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))
|
||||||
|
|
||||||
|
if args.tensorboard is not None and bmt.rank() == 0:
|
||||||
|
writer.add_scalar(f"Loss/train/{epoch}", task_loss_map[args.task_name], iteration)
|
||||||
|
writer.add_scalar(f"Optimizer/lr/{epoch}", lr_scheduler.current_lr, iteration)
|
||||||
|
writer.add_scalar(f"Optimizer/scale/{epoch}", optim_manager.loss_scale, iteration)
|
||||||
|
writer.add_scalar(f"Optimizer/grad_norm/{epoch}", grad_norm.item(), iteration)
|
||||||
|
|
||||||
|
# if iteration % 10 == 0:
|
||||||
|
# save_folder_name = f"{args.save}{epoch}"
|
||||||
|
# model_fname = os.path.join(save_folder_name, f"{args.save_name}-epoch-{epoch}.pt")
|
||||||
|
# os.makedirs(os.path.dirname(model_fname), exist_ok=True)
|
||||||
|
# bmt.save(model, model_fname)
|
||||||
|
# if bmt.rank() == 0:
|
||||||
|
# shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
|
||||||
|
# shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))
|
||||||
|
|
||||||
|
save_folder_name = f"{args.save}/epoch-{epoch}"
|
||||||
|
model_fname = os.path.join(save_folder_name, f"{args.save_name}.pt")
|
||||||
|
os.makedirs(os.path.dirname(model_fname), exist_ok=True)
|
||||||
|
state_dict = model.state_dict()
|
||||||
|
if args.delta_type == None or args.save_origin_model == True :
|
||||||
|
print("saving base model...")
|
||||||
|
bmt.save(model, model_fname)
|
||||||
|
if args.delta_type != None and bmt.rank() == 0:
|
||||||
|
print("saving delta model...")
|
||||||
|
torch.save(state_dict, os.path.join(save_folder_name, f"{args.save_name}-delta.pt"))
|
||||||
|
if bmt.rank() == 0:
|
||||||
|
shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
|
||||||
|
shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = initialize()
|
||||||
|
# To Be Specified
|
||||||
|
bin_file = args.dataset
|
||||||
|
bmt.print_rank(f"dataset: {bin_file}")
|
||||||
|
|
||||||
|
tokenizer = get_tokenizer(args)
|
||||||
|
dataloader = FinetuneDataset(
|
||||||
|
bin_file, args.batch_size, args.max_length, tokenizer, unpad=(args.flash == "cuda"), task_name="task", drop_last=True
|
||||||
|
)
|
||||||
|
bmt.print_rank(f"#batch: {len(dataloader)}")
|
||||||
|
total_steps = len(dataloader) * args.epoch
|
||||||
|
|
||||||
|
setattr(args, 'train_iters', int(total_steps))
|
||||||
|
setattr(args, 'warmup_iters', max(int(total_steps*0.02), args.warmup_iters))
|
||||||
|
bmt.print_rank(json.dumps(vars(args), indent=2, sort_keys=True))
|
||||||
|
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
|
||||||
|
|
||||||
|
bmt.print_rank("finish loading")
|
||||||
|
finetune(args, bin_file, tokenizer, model, optimizer, lr_scheduler)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -0,0 +1,63 @@
|
||||||
|
#! /bin/bash
|
||||||
|
#SBATCH --partition=gpu3-1
|
||||||
|
#SBATCH --nodes=1
|
||||||
|
#SBATCH --ntasks-per-node=8
|
||||||
|
#SBATCH --gres=gpu:8
|
||||||
|
#SBATCH --cpus-per-task=8
|
||||||
|
|
||||||
|
export MASTER_ADDR=`hostname`
|
||||||
|
export MASTER_PORT=12345
|
||||||
|
echo $MASTER_ADDR
|
||||||
|
|
||||||
|
CPM_PATH="./9G-Train"
|
||||||
|
CONFIG_NAME="${CPM_PATH}/apps/cpm9g/config/11b"
|
||||||
|
EXP_PATH=./exp
|
||||||
|
mkdir -p $EXP_PATH
|
||||||
|
MODEL_NAME="cpm9g-11b-sft"
|
||||||
|
|
||||||
|
OPTS=""
|
||||||
|
OPTS+=" --model-config ${CONFIG_NAME}/config.json"
|
||||||
|
OPTS+=" --vocab ${CONFIG_NAME}/vocab.txt"
|
||||||
|
|
||||||
|
OPTS+=" --train-iters 10000"
|
||||||
|
OPTS+=" --inspect-iters 200"
|
||||||
|
OPTS+=" --warmup-iters 500"
|
||||||
|
|
||||||
|
OPTS+=" --lr-decay-style cosine"
|
||||||
|
OPTS+=" --weight-decay 0.1"
|
||||||
|
OPTS+=" --clip-grad 1.0"
|
||||||
|
OPTS+=" --loss-scale 1048576"
|
||||||
|
OPTS+=" --max-loss-scale 33554432"
|
||||||
|
OPTS+=" --min-loss-scale 1"
|
||||||
|
OPTS+=" --loss-scale-steps 32"
|
||||||
|
|
||||||
|
OPTS+=" --offload"
|
||||||
|
OPTS+=" --batch-size 4"
|
||||||
|
OPTS+=" --max-length 4096"
|
||||||
|
OPTS+=" --lr 2e-5"
|
||||||
|
OPTS+=" --start-step 0"
|
||||||
|
OPTS+=" --epoch 8"
|
||||||
|
OPTS+=" --load /data/groups/QY_LLM_Core/models/20231010/11b-base/11b.pt"
|
||||||
|
OPTS+=" --dataset /data/groups/QY_LLM_Core/datasets/sft/20231025/merge_qy_sft_bin"
|
||||||
|
# TODO 这些 /data 在启元机器上需要改成 /home 下的路径
|
||||||
|
OPTS+=" --save ${EXP_PATH}/checkpoints"
|
||||||
|
OPTS+=" --save-name ${MODEL_NAME}"
|
||||||
|
# OPTS+=" --tensorboard /data/logs/tensorboard/${MODEL_NAME}/${CUR_DATE}/"
|
||||||
|
# OPTS+=" --flash triton"
|
||||||
|
# OPTS+=" --flash cuda"
|
||||||
|
# OPTS+=" --load-grad"
|
||||||
|
|
||||||
|
OPTS+=" --delta-tuning" #开启delta-tuning
|
||||||
|
OPTS+=" --delta-type lora" #目前仅支持lora
|
||||||
|
OPTS+=" --lora-r 64" #lora矩阵的维度,默认为8
|
||||||
|
OPTS+=" --lora-dropout 0.05" #默认为0
|
||||||
|
OPTS+=" --lora-alpha 64" #lora对原模型的影响比例,默认为8
|
||||||
|
OPTS+=" --lora-layer project_q project_v project_k w_0 w_1 w_out" #参与lora的线性层,默认为project_q,project_k
|
||||||
|
OPTS+=" --save-origin-model" #是否在每个epoch储存基座模型
|
||||||
|
|
||||||
|
OPTS+=" $@"
|
||||||
|
|
||||||
|
CMD="torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} ${CPM_PATH}/apps/cpm9g/sft_cpm9g_delta.py ${OPTS}"
|
||||||
|
|
||||||
|
echo "${CMD}"
|
||||||
|
$CMD
|
|
@ -1,18 +1,3 @@
|
||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The OpenBMB team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
@ -22,8 +7,6 @@ def add_model_config_args(parser: argparse.ArgumentParser):
|
||||||
group = parser.add_argument_group("model", "model configuration")
|
group = parser.add_argument_group("model", "model configuration")
|
||||||
group.add_argument("--model-config", type=str, help="model configuration file")
|
group.add_argument("--model-config", type=str, help="model configuration file")
|
||||||
group.add_argument("--vocab", type=str, default=None, help="model vocabulary file")
|
group.add_argument("--vocab", type=str, default=None, help="model vocabulary file")
|
||||||
group.add_argument("--eps", type=float, default=1e-5, help="eps in layernorm")
|
|
||||||
# group.add_argument("--qk_norm", action="store_true", default=False, help="qk layernorm")
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
@ -48,13 +31,6 @@ def add_training_args(parser: argparse.ArgumentParser):
|
||||||
help="Load the gradient states",
|
help="Load the gradient states",
|
||||||
)
|
)
|
||||||
|
|
||||||
group.add_argument(
|
|
||||||
"--grad-ckpt-num",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="grad file num (only work when --load-grad from files less than world-size )",
|
|
||||||
)
|
|
||||||
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--load-start-step",
|
"--load-start-step",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
@ -90,9 +66,7 @@ def add_training_args(parser: argparse.ArgumentParser):
|
||||||
|
|
||||||
group.add_argument("--inspect-iters", type=int, default=1000, help="number of inspecting")
|
group.add_argument("--inspect-iters", type=int, default=1000, help="number of inspecting")
|
||||||
group.add_argument("--batch-size", type=int, default=32, help="Data Loader batch size")
|
group.add_argument("--batch-size", type=int, default=32, help="Data Loader batch size")
|
||||||
group.add_argument("--num-micro-batches", type=int, default=16)
|
|
||||||
group.add_argument("--clip-grad", type=float, default=1.0, help="gradient clipping")
|
group.add_argument("--clip-grad", type=float, default=1.0, help="gradient clipping")
|
||||||
group.add_argument("--grad-accum", type=int, default=1, help="gradient accum steps")
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--train-iters",
|
"--train-iters",
|
||||||
type=int,
|
type=int,
|
||||||
|
@ -100,14 +74,11 @@ def add_training_args(parser: argparse.ArgumentParser):
|
||||||
help="total number of iterations to train over all training runs",
|
help="total number of iterations to train over all training runs",
|
||||||
)
|
)
|
||||||
group.add_argument("--max-length", type=int, default=512, help="max length of input")
|
group.add_argument("--max-length", type=int, default=512, help="max length of input")
|
||||||
group.add_argument("--min-length", type=int, default=None, help="only for speed test")
|
|
||||||
|
|
||||||
group.add_argument("--seed", type=int, default=1234, help="random seed for reproducibility")
|
group.add_argument("--seed", type=int, default=1234, help="random seed for reproducibility")
|
||||||
|
|
||||||
# Learning rate.
|
# Learning rate.
|
||||||
group.add_argument("--lr", type=float, default=1.0e-4, help="initial learning rate")
|
group.add_argument("--lr", type=float, default=1.0e-4, help="initial learning rate")
|
||||||
group.add_argument("--lr_scheduler", type=str, default="cosine", help=" learning rate scheduler")
|
|
||||||
|
|
||||||
group.add_argument("--weight-decay", type=float, default=1.0e-2, help="weight decay rate")
|
group.add_argument("--weight-decay", type=float, default=1.0e-2, help="weight decay rate")
|
||||||
group.add_argument("--loss-scale", type=float, default=65536, help="loss scale")
|
group.add_argument("--loss-scale", type=float, default=65536, help="loss scale")
|
||||||
group.add_argument("--max-loss-scale", type=float, default=float("inf"), help="loss scale")
|
group.add_argument("--max-loss-scale", type=float, default=float("inf"), help="loss scale")
|
||||||
|
@ -121,85 +92,21 @@ def add_training_args(parser: argparse.ArgumentParser):
|
||||||
help="percentage of data to warmup on (.01 = 1% of all " "training iters). Default 0.01",
|
help="percentage of data to warmup on (.01 = 1% of all " "training iters). Default 0.01",
|
||||||
)
|
)
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--drop-iters",
|
"--lr-decay-style",
|
||||||
type=float,
|
type=str,
|
||||||
default=0.01,
|
default="noam",
|
||||||
help="percentage of data to warmup on (.01 = 1% of all " "training iters). Default 0.01",
|
choices=["constant", "linear", "cosine", "exponential", "noam"],
|
||||||
|
help="learning rate decay function",
|
||||||
)
|
)
|
||||||
|
|
||||||
group.add_argument("--lr-decay-iters", type=int, default=None, help="lr decay steps")
|
group.add_argument("--lr-decay-iters", type=int, default=None, help="lr decay steps")
|
||||||
group.add_argument("--start-step", type=int, default=0, help="step to start or continue training")
|
group.add_argument("--start-step", type=int, default=0, help="step to start or continue training")
|
||||||
group.add_argument("--concat-data", action="store_true", help="whether we concatenate the dialogues")
|
group.add_argument("--concat-data", action="store_true", help="whether we concatenate the dialogues")
|
||||||
group.add_argument("--offload", action="store_true", help="whether we use offload_adam")
|
group.add_argument("--offload", action="store_true", help="whether we use offload_adam")
|
||||||
group.add_argument("--new-bmt", action="store_true", help="new bmt without ckpt")
|
group.add_argument("--new-bmt", action="store_true", help="new bmt without ckpt")
|
||||||
group.add_argument("--flash", default="none", choices=["none", "1d", "triton", "cuda"])
|
group.add_argument("--flash", default="none", choices=["none", "1d", "triton", "cuda"])
|
||||||
group.add_argument("--use-jfs-data", action="store_true", help="whether we use juicefs dataset")
|
group.add_argument("--tp", default=1, type=int, help="whether we use tensor parallelism")
|
||||||
group.add_argument("--tp-size", default=1, type=int)
|
|
||||||
group.add_argument("--pp-size", default=1, type=int)
|
|
||||||
group.add_argument("--bf16", action="store_true", help="whether we use bf16")
|
group.add_argument("--bf16", action="store_true", help="whether we use bf16")
|
||||||
group.add_argument("--dataloader_num_threads", default=3, type=int, help="Only useful in indexed dataest.")
|
group.add_argument("--gradient-accumulation-steps", type=int, default=1, help="gradient accumulation steps")
|
||||||
group.add_argument("--dataloader_prefetch", default=200, type=int, help="Only useful in indexed dataest.")
|
|
||||||
group.add_argument("--dataloader_num_workers", default=4, type=int, help="Only useful in indexed dataest.")
|
|
||||||
group.add_argument("--dataloader_prefetch_factor", default=50, type=int, help="Only useful in indexed dataest.")
|
|
||||||
group.add_argument(
|
|
||||||
"--dataloader",
|
|
||||||
default="indexed",
|
|
||||||
type=str,
|
|
||||||
help="dataloader type, 'indexed' for indexed dataset, 'normal' for normal dataset",
|
|
||||||
)
|
|
||||||
group.add_argument("--stop_when_end", default=0, type=int, help="Whether to stop training when we reach end_iter")
|
|
||||||
group.add_argument(
|
|
||||||
"--data_len_threshold",
|
|
||||||
default=512,
|
|
||||||
type=int,
|
|
||||||
help="If the average length of a sequence is less than this int, mean the sample is biased. ",
|
|
||||||
)
|
|
||||||
group.add_argument(
|
|
||||||
"--only_run_dataloader", default=0, type=int, help="Whether to only run dataloader to check data. "
|
|
||||||
)
|
|
||||||
group.add_argument(
|
|
||||||
"--only_load_model", default=0, type=int, help="Whether to only load a model ckpt, without anything else."
|
|
||||||
)
|
|
||||||
group.add_argument(
|
|
||||||
"--load_dataloader_ckpt", default=1, type=int, help="Whether to only load a model ckpt, without anything else."
|
|
||||||
)
|
|
||||||
|
|
||||||
group.add_argument(
|
|
||||||
"--resume_no_optimze",
|
|
||||||
default=0,
|
|
||||||
type=int,
|
|
||||||
help="The number of steps that does not add optimization after resume",
|
|
||||||
)
|
|
||||||
group.add_argument(
|
|
||||||
"--parallel_load_datastate",
|
|
||||||
default=256,
|
|
||||||
type=int,
|
|
||||||
help="The number of parallel workers to load dataset state",
|
|
||||||
)
|
|
||||||
group.add_argument(
|
|
||||||
"--async_save",
|
|
||||||
action="store_true",
|
|
||||||
help="whether to save artifacts asynchronously",
|
|
||||||
)
|
|
||||||
group.add_argument(
|
|
||||||
"--drop_begin",
|
|
||||||
default=-1,
|
|
||||||
type=int,
|
|
||||||
help="The number of steps that starts to drop lr"
|
|
||||||
)
|
|
||||||
group.add_argument(
|
|
||||||
"--drop_rate",
|
|
||||||
default=0.5,
|
|
||||||
type=float,
|
|
||||||
help="The number rate"
|
|
||||||
)
|
|
||||||
group.add_argument(
|
|
||||||
"--use_checkpoint",
|
|
||||||
default=1,
|
|
||||||
type=int,
|
|
||||||
help="Whether to use checkpointing."
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
@ -226,17 +133,6 @@ def add_pretrain_args(parser: argparse.ArgumentParser):
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def add_tokenizer_args(parser: argparse.ArgumentParser):
|
|
||||||
group = parser.add_argument_group("tokenizer", "tokenizer configurations")
|
|
||||||
group.add_argument(
|
|
||||||
"--tokenizer_path",
|
|
||||||
type=str,
|
|
||||||
default="",
|
|
||||||
help="tokenizer_path",
|
|
||||||
)
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def add_finetune_args(parser: argparse.ArgumentParser):
|
def add_finetune_args(parser: argparse.ArgumentParser):
|
||||||
group = parser.add_argument_group("finetune", "finetune configurations")
|
group = parser.add_argument_group("finetune", "finetune configurations")
|
||||||
group.add_argument("--epoch", type=int, default=1, help="number of training epochs")
|
group.add_argument("--epoch", type=int, default=1, help="number of training epochs")
|
||||||
|
@ -308,53 +204,14 @@ def add_feedback_learning_args(parser: argparse.ArgumentParser):
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def add_model_change_args(parser: argparse.ArgumentParser):
|
def add_delta_args(parser: argparse.ArgumentParser):
|
||||||
group = parser.add_argument_group("model_change", "model change during pretraining")
|
group = parser.add_argument_group("LoRA","LoRA configurations")
|
||||||
group.add_argument("--strict_state_dict", type=int, default=1, help="strict_state_dict")
|
group.add_argument("--delta-type", type=str, default=None, help="delta-tuning-type")
|
||||||
##
|
group.add_argument("--lora-r", type=int, default=8, help="lora-rank")
|
||||||
return parser
|
group.add_argument("--lora-alpha", type=int, default=8, help="lora-alpha")
|
||||||
|
group.add_argument("--lora-dropout", type=float, default=0.0, help="lora-dropout")
|
||||||
|
group.add_argument("--lora-layer", nargs='+', default=['project_q','project_k'], help="lora-layer")
|
||||||
def add_log_args(parser: argparse.ArgumentParser):
|
group.add_argument("--save-origin-model", action="store_true", default=False)
|
||||||
group = parser.add_argument_group("log", "log configurations")
|
|
||||||
group.add_argument("--tensorboard_all_tasks", type=int, default=0, help="log")
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def add_error_handle_args(parser: argparse.ArgumentParser):
|
|
||||||
group = parser.add_argument_group("error_handle", "error_handle configurations")
|
|
||||||
group.add_argument(
|
|
||||||
"--ignore_cuda_oom", type=int, default=1, help="continue training by ingore the batch that causes oom"
|
|
||||||
)
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def add_runtime_eval_args(parser: argparse.ArgumentParser):
|
|
||||||
group = parser.add_argument_group("runtime eval args", "runtime evaluation by submitting a job")
|
|
||||||
group.add_argument(
|
|
||||||
"--runtime_eval",
|
|
||||||
action="store_true",
|
|
||||||
help="whether to use runtime_eval. Only if this is set to True, the following variables will be useful",
|
|
||||||
)
|
|
||||||
group.add_argument("--eval_jeeves_auth", type=str, default="", help="auth, press f12 on jeeves platform to get")
|
|
||||||
group.add_argument("--eval_project_id", type=str, default=None, help="project id")
|
|
||||||
group.add_argument("--eval_run_cmd", type=str, default="", help="cmd for eval")
|
|
||||||
group.add_argument(
|
|
||||||
"--eval_git_path",
|
|
||||||
type=str,
|
|
||||||
default="git@git.in.zhihu.com:luca/llm-bench.git",
|
|
||||||
help="git path of evaluation code",
|
|
||||||
)
|
|
||||||
group.add_argument("--eval_git_branch", type=str, default="master", help="git branch of evaluation code")
|
|
||||||
group.add_argument("--eval_node_num", type=int, default=1, help="using 1 node to evaluate")
|
|
||||||
group.add_argument("--eval_gpu_num", type=int, default=1, help="using 1 gpu per node to evaluate")
|
|
||||||
group.add_argument("--eval_tasks_config", type=str, default="", help="evaluate tasks' config")
|
|
||||||
group.add_argument("--eval_model_backend", default="torch", type=str, help="model_backend")
|
|
||||||
|
|
||||||
group.add_argument(
|
|
||||||
"--eval_at_start", action="store_true", help="whether to eval at the first epoch, default to false"
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
@ -365,30 +222,6 @@ def add_reward_args(parser: argparse.ArgumentParser):
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def add_long_context_extend_args(parser: argparse.ArgumentParser):
|
|
||||||
"""long context extending arguments."""
|
|
||||||
group = parser.add_argument_group("long_context_extend", "long context extend configurations")
|
|
||||||
group.add_argument("--pose-prob", default=0.0, type=float, help="Sample-level PoSE probability")
|
|
||||||
group.add_argument(
|
|
||||||
"--pose-scaling-factor",
|
|
||||||
default=1.0,
|
|
||||||
type=float,
|
|
||||||
help="PoSE scaling factor, simulate input length = max_length * pose_scaling_factor",
|
|
||||||
)
|
|
||||||
group.add_argument(
|
|
||||||
"--rope-scaling-type",
|
|
||||||
default="",
|
|
||||||
type=str,
|
|
||||||
choices=["Linear", "NTK-aware", "Dynamic NTK", "NTK-by-parts", "YaRN", ""],
|
|
||||||
help="Context scaling type",
|
|
||||||
)
|
|
||||||
group.add_argument("--rope-scaling-factor", default=1, type=int, help="Context scaling factor")
|
|
||||||
group.add_argument(
|
|
||||||
"--orig-max-length", default=8192, type=int, help="Original context length before context extending"
|
|
||||||
)
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def get_args(
|
def get_args(
|
||||||
pretrain: bool = False,
|
pretrain: bool = False,
|
||||||
finetune: bool = False,
|
finetune: bool = False,
|
||||||
|
@ -402,14 +235,9 @@ def get_args(
|
||||||
parser = add_training_args(parser)
|
parser = add_training_args(parser)
|
||||||
if pretrain:
|
if pretrain:
|
||||||
parser = add_pretrain_args(parser)
|
parser = add_pretrain_args(parser)
|
||||||
parser = add_runtime_eval_args(parser)
|
|
||||||
parser = add_tokenizer_args(parser)
|
|
||||||
parser = add_log_args(parser)
|
|
||||||
parser = add_error_handle_args(parser)
|
|
||||||
parser = add_model_change_args(parser)
|
|
||||||
|
|
||||||
if finetune:
|
if finetune:
|
||||||
parser = add_finetune_args(parser)
|
parser = add_finetune_args(parser)
|
||||||
|
parser = add_delta_args(parser)
|
||||||
if rhlf:
|
if rhlf:
|
||||||
parser = add_rhlf_args(parser)
|
parser = add_rhlf_args(parser)
|
||||||
if simple_rlhf:
|
if simple_rlhf:
|
||||||
|
@ -418,7 +246,6 @@ def get_args(
|
||||||
parser = add_feedback_learning_args(parser)
|
parser = add_feedback_learning_args(parser)
|
||||||
if reward:
|
if reward:
|
||||||
parser = add_reward_args(parser)
|
parser = add_reward_args(parser)
|
||||||
parser = add_long_context_extend_args(parser)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
from .models import CPM9G
|
||||||
|
from .models import CPM9GConfig
|
||||||
|
|
||||||
|
from .tokenizers import CPM9GTokenizer
|
||||||
|
from .training_tasks import MixedDataset
|
|
@ -0,0 +1,16 @@
|
||||||
|
{
|
||||||
|
"vocab_size": 119696,
|
||||||
|
"dropout_p": 0.0,
|
||||||
|
"eps": 1e-05,
|
||||||
|
"half": true,
|
||||||
|
"use_flash_attn": false,
|
||||||
|
"flash_attn_mask_shape": "1d",
|
||||||
|
"dim_model": 4096,
|
||||||
|
"dim_ff": 12288,
|
||||||
|
"dim_head": 128,
|
||||||
|
"num_heads": 32,
|
||||||
|
"num_kv_heads": 32,
|
||||||
|
"num_layers": 32,
|
||||||
|
"activate_fn": "silu",
|
||||||
|
"scale": false
|
||||||
|
}
|
|
@ -0,0 +1,658 @@
|
||||||
|
from typing import Any
|
||||||
|
from typing import Dict
|
||||||
|
from typing import List
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from ...generation import apply_repetition_penalty
|
||||||
|
from ...generation import BeamHypotheses
|
||||||
|
from ...generation import top_k_top_p_filtering
|
||||||
|
|
||||||
|
from ...utils import pad
|
||||||
|
from ..models import CPM9GTorch
|
||||||
|
from ..tokenizers.cpm9g import CPM9GTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class CPM9GGeneration:
|
||||||
|
def __init__(
|
||||||
|
self, model: CPM9GTorch, tokenizer: CPM9GTokenizer, max_in_len=1024, use_nbce: bool = False
|
||||||
|
):
|
||||||
|
model.eval()
|
||||||
|
self.model = model
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.max_in_len = max_in_len
|
||||||
|
|
||||||
|
def _convert_to_tensors(self, data: Any):
|
||||||
|
input_ids = self.tokenizer.encode(
|
||||||
|
data["input"]
|
||||||
|
) # [self.tokenizer.bos_token_id] + self.tokenizer.encode(data["input"])
|
||||||
|
model_input = {}
|
||||||
|
|
||||||
|
model_input["input_ids"] = torch.tensor(input_ids[: self.max_in_len], dtype=torch.int32).unsqueeze(0)
|
||||||
|
model_input["context"] = torch.zeros(
|
||||||
|
(model_input["input_ids"].shape[0], model_input["input_ids"].shape[1]), dtype=torch.int16
|
||||||
|
)
|
||||||
|
model_input["span"] = torch.ones((model_input["input_ids"].shape[1],), dtype=torch.int16).unsqueeze(0)
|
||||||
|
model_input["length"] = torch.tensor([model_input["input_ids"].shape[1]], dtype=torch.int16).unsqueeze(0)
|
||||||
|
return model_input
|
||||||
|
|
||||||
|
def _process_list(self, data_list: List[Any]):
|
||||||
|
input_tensors = list(map(self._convert_to_tensors, data_list))
|
||||||
|
keys = set(input_tensors[0].keys())
|
||||||
|
padded = {}
|
||||||
|
for key in keys:
|
||||||
|
padded[key] = pad(input_tensors, key, padding_side="left").cuda()
|
||||||
|
return padded
|
||||||
|
|
||||||
|
def generate(self, data_list, **kwargs):
|
||||||
|
origin_data_list = data_list.copy()
|
||||||
|
model_inputs = self._process_list(data_list)
|
||||||
|
with torch.inference_mode():
|
||||||
|
result_ids = self._decode(model_inputs, **kwargs)
|
||||||
|
|
||||||
|
return result_ids
|
||||||
|
|
||||||
|
def _decode(self, model_inputs, **kwargs):
|
||||||
|
raise NotImplementedError("_decode is not implemented.")
|
||||||
|
|
||||||
|
|
||||||
|
class CPM9GBeamSearch(CPM9GGeneration):
|
||||||
|
def _decode(
|
||||||
|
self,
|
||||||
|
model_inputs,
|
||||||
|
beam_size=4,
|
||||||
|
max_length=100,
|
||||||
|
repetition_penalty=1.2,
|
||||||
|
repetition_window=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Beam search
|
||||||
|
Args:
|
||||||
|
model_inputs (dict): input ids.
|
||||||
|
beam_size (int, optional, defaults to 3): beam size of beam search.
|
||||||
|
generate_length (int, optional, defaults to 100): maximum generation length.
|
||||||
|
repetition_penalty (float, optional, defaults to 1.0): repetition penalty coefficient, 1.0 means no penalty.
|
||||||
|
repetition_window (int, optional, defaults to None): window size of repetition penalty, None means that all output tokens are penalized.
|
||||||
|
""" # noqa: E501
|
||||||
|
# generate_length + 1 for EOS token
|
||||||
|
max_length += 1
|
||||||
|
# expand dimmension
|
||||||
|
batch_size = model_inputs["input_ids"].size(0)
|
||||||
|
input: torch.Tensor = (
|
||||||
|
model_inputs["input_ids"]
|
||||||
|
.unsqueeze(1)
|
||||||
|
.expand(batch_size, beam_size, -1)
|
||||||
|
.contiguous()
|
||||||
|
.view(batch_size * beam_size, -1)
|
||||||
|
)
|
||||||
|
length = (
|
||||||
|
model_inputs["length"]
|
||||||
|
.squeeze(1)
|
||||||
|
.unsqueeze(1)
|
||||||
|
.expand(batch_size, beam_size)
|
||||||
|
.contiguous()
|
||||||
|
.view(
|
||||||
|
batch_size * beam_size,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
span: torch.Tensor = (
|
||||||
|
model_inputs["span"]
|
||||||
|
.unsqueeze(1)
|
||||||
|
.expand(batch_size, beam_size, -1)
|
||||||
|
.contiguous()
|
||||||
|
.view(batch_size * beam_size, -1)
|
||||||
|
)
|
||||||
|
context: torch.Tensor = (
|
||||||
|
model_inputs["context"]
|
||||||
|
.unsqueeze(1)
|
||||||
|
.expand(batch_size, beam_size, -1)
|
||||||
|
.contiguous()
|
||||||
|
.view(batch_size * beam_size, -1)
|
||||||
|
)
|
||||||
|
|
||||||
|
done = [False for _ in range(batch_size)]
|
||||||
|
|
||||||
|
beam_scores = torch.zeros((batch_size, beam_size), dtype=torch.float, device=input.device)
|
||||||
|
beam_scores[:, 1:] = -1e9
|
||||||
|
beam_scores = beam_scores.view(-1)
|
||||||
|
|
||||||
|
# generated hypotheses
|
||||||
|
generated_hyps = [
|
||||||
|
BeamHypotheses(beam_size, max_length, length_penalty=1, early_stopping=False) for _ in range(batch_size)
|
||||||
|
]
|
||||||
|
|
||||||
|
pred_start_index = input.size(-1)
|
||||||
|
past_key_values = None
|
||||||
|
|
||||||
|
for i in range(max_length + 1):
|
||||||
|
if i == 0:
|
||||||
|
logits, _, past_key_values = self.model.inference(
|
||||||
|
input=input,
|
||||||
|
context=context,
|
||||||
|
span=span,
|
||||||
|
length=length,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits, _, past_key_values = self.model.inference(
|
||||||
|
input=input[:, -1:],
|
||||||
|
context=context,
|
||||||
|
span=span,
|
||||||
|
length=length,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
)
|
||||||
|
# skip all steps when we are done with each sentence
|
||||||
|
if all(done):
|
||||||
|
break
|
||||||
|
|
||||||
|
# (batch * beam, seqlen, model_dim)
|
||||||
|
logits = logits[:, -1, :]
|
||||||
|
if i == 0:
|
||||||
|
logits[:, self.tokenizer.bos_token_id] = -float("inf")
|
||||||
|
# logits[:, self.tokenizer.newline_id] = -float("inf")
|
||||||
|
apply_repetition_penalty(
|
||||||
|
logits,
|
||||||
|
batch_size,
|
||||||
|
beam_size,
|
||||||
|
input,
|
||||||
|
repetition_penalty,
|
||||||
|
pred_start_index,
|
||||||
|
input.size(-1) - 1,
|
||||||
|
repetition_window,
|
||||||
|
)
|
||||||
|
|
||||||
|
scores = F.log_softmax(logits, dim=-1)
|
||||||
|
|
||||||
|
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * beam_size, vocab_size)
|
||||||
|
|
||||||
|
# re-organize to group the beam together (we are keeping top hypothesis accross beams)
|
||||||
|
next_scores = next_scores.view(batch_size, -1) # (batch_size, beam_size * vocab_size)
|
||||||
|
next_scores, next_words = torch.topk(next_scores, 2 * beam_size, dim=1, largest=True, sorted=True)
|
||||||
|
|
||||||
|
assert next_scores.size() == next_words.size() == (batch_size, 2 * beam_size)
|
||||||
|
next_batch_beam = []
|
||||||
|
|
||||||
|
for sent_id in range(batch_size):
|
||||||
|
# if we are done with this sentence
|
||||||
|
done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item(), i)
|
||||||
|
if done[sent_id]:
|
||||||
|
next_batch_beam.extend([(0, 0, 0)] * beam_size) # pad the batch
|
||||||
|
continue
|
||||||
|
|
||||||
|
# next sentence beam content
|
||||||
|
next_sent_beam = []
|
||||||
|
|
||||||
|
# next words for this sentence
|
||||||
|
for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
|
||||||
|
# get beam and word IDs
|
||||||
|
beam_id = torch.div(idx, scores.size(-1), rounding_mode="floor")
|
||||||
|
word_id = idx % scores.size(-1)
|
||||||
|
|
||||||
|
# end of sentence, or next word
|
||||||
|
if word_id == self.tokenizer.bos_token_id or i == max_length:
|
||||||
|
generated_hyps[sent_id].add(
|
||||||
|
input[sent_id * beam_size + beam_id, pred_start_index:].clone().cpu().tolist(),
|
||||||
|
value.item(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))
|
||||||
|
|
||||||
|
# the beam for next step is full
|
||||||
|
if len(next_sent_beam) == beam_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
# update next beam content
|
||||||
|
assert len(next_sent_beam) == 0 if i == max_length else beam_size
|
||||||
|
if len(next_sent_beam) == 0:
|
||||||
|
next_sent_beam = [(0, 0, 0)] * beam_size # pad the batch
|
||||||
|
next_batch_beam.extend(next_sent_beam)
|
||||||
|
assert len(next_batch_beam) == beam_size * (sent_id + 1)
|
||||||
|
|
||||||
|
# we have reached the last step
|
||||||
|
if i == max_length:
|
||||||
|
break
|
||||||
|
|
||||||
|
# sanity check / prepare next batch
|
||||||
|
assert len(next_batch_beam) == batch_size * beam_size
|
||||||
|
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
|
||||||
|
beam_words = input.new([x[1] for x in next_batch_beam])
|
||||||
|
beam_idx = torch.tensor([x[2] for x in next_batch_beam], device=input.device).long()
|
||||||
|
# re-order batch and internal states
|
||||||
|
input = input[beam_idx, :]
|
||||||
|
|
||||||
|
past_key_values["buffer"] = [list(each) if each is not None else each for each in past_key_values["buffer"]] # type: ignore # noqa: E501
|
||||||
|
for key_value_layer in past_key_values["buffer"]:
|
||||||
|
if key_value_layer is not None:
|
||||||
|
key_value_layer[0] = key_value_layer[0][beam_idx]
|
||||||
|
key_value_layer[1] = key_value_layer[1][beam_idx]
|
||||||
|
|
||||||
|
input = torch.cat([input, beam_words.unsqueeze(1)], dim=-1)
|
||||||
|
context = torch.cat(
|
||||||
|
[context, context[:, -1:]],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
length += 1
|
||||||
|
|
||||||
|
span = torch.cat([span, span[:, -1:]], dim=-1)
|
||||||
|
|
||||||
|
# select the best hypotheses
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for i, hypotheses in enumerate(generated_hyps):
|
||||||
|
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
|
||||||
|
results.append(best_hyp)
|
||||||
|
|
||||||
|
result_text = list(map(self.tokenizer.decode, results))
|
||||||
|
return result_text
|
||||||
|
|
||||||
|
|
||||||
|
class CPM9GRandomSampling(CPM9GGeneration):
|
||||||
|
def _decode(
|
||||||
|
self,
|
||||||
|
model_inputs,
|
||||||
|
max_length=100,
|
||||||
|
top_k=0,
|
||||||
|
top_p=0.9,
|
||||||
|
temperature=0.9,
|
||||||
|
repetition_penalty=1.0,
|
||||||
|
repetition_window=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Top-k and top-p sampling.
|
||||||
|
Args:
|
||||||
|
model_inputs (dict): input ids
|
||||||
|
generate_length (int, optional, defaults to 100): maximum generation length
|
||||||
|
top_k (int, optional, defaults to 0): keep only top k tokens with highest probability. 0 means keeping all tokens.
|
||||||
|
top_p (int, optional, defaults to 0.9): keep the top tokens with cumulative probability >= top_p.
|
||||||
|
temperature (int, optional, defaults to 0.9): the value that can cool down the logits distribution.
|
||||||
|
repetition_penalty (float, optional, defaults to 1.0): repetition penalty coefficient, 1.0 means no penalty.
|
||||||
|
repetition_window (int, optional, defaults to None): window size of repetition penalty, None means that all output tokens are penalized.
|
||||||
|
""" # noqa: E501
|
||||||
|
# generate_length + 1 for EOS token
|
||||||
|
max_length += 1
|
||||||
|
|
||||||
|
input = model_inputs["input_ids"]
|
||||||
|
context = model_inputs["context"]
|
||||||
|
|
||||||
|
length = model_inputs["length"].squeeze(1)
|
||||||
|
span = model_inputs["span"]
|
||||||
|
batch_size = input.size(0)
|
||||||
|
|
||||||
|
pred_start_index = input.size(-1)
|
||||||
|
past_key_values = None
|
||||||
|
done = [False for _ in range(batch_size)]
|
||||||
|
results = [None for _ in range(batch_size)]
|
||||||
|
for i in range(max_length):
|
||||||
|
if i == 0:
|
||||||
|
logits, _, past_key_values = self.model.inference(
|
||||||
|
input=input,
|
||||||
|
context=context,
|
||||||
|
length=length,
|
||||||
|
span=span,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits, _, past_key_values = self.model.inference(
|
||||||
|
input=input[:, -1:],
|
||||||
|
context=context,
|
||||||
|
length=length,
|
||||||
|
span=span,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = logits[:, -1, :]
|
||||||
|
|
||||||
|
if i == 0:
|
||||||
|
logits[:, self.tokenizer.bos_token_id] = -float("inf")
|
||||||
|
# logits[:, self.tokenizer.newline_id] = -float("inf")
|
||||||
|
|
||||||
|
apply_repetition_penalty(
|
||||||
|
logits,
|
||||||
|
batch_size,
|
||||||
|
1,
|
||||||
|
input,
|
||||||
|
repetition_penalty,
|
||||||
|
pred_start_index,
|
||||||
|
input.size(-1) - 1,
|
||||||
|
repetition_window,
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = logits / temperature
|
||||||
|
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
||||||
|
|
||||||
|
probs = F.softmax(logits, dim=-1)
|
||||||
|
next_token = torch.multinomial(probs, num_samples=1)
|
||||||
|
|
||||||
|
for idx in range(batch_size):
|
||||||
|
if not done[idx] and (next_token[idx].item() == self.tokenizer.bos_token_id or i == max_length - 1):
|
||||||
|
done[idx] = True
|
||||||
|
results[idx] = input[idx, pred_start_index:].clone().cpu().tolist() # type: ignore # noqa: E501
|
||||||
|
|
||||||
|
if sum(done) == batch_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
# update input ids
|
||||||
|
input = torch.cat([input, next_token], dim=-1)
|
||||||
|
length += 1
|
||||||
|
|
||||||
|
context = torch.cat(
|
||||||
|
[context, context[:, -1:]],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
span = torch.cat(
|
||||||
|
[span, span[:, -1:]],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
result_text = list(map(self.tokenizer.decode, results))
|
||||||
|
return result_text
|
||||||
|
|
||||||
|
|
||||||
|
class CPM9GBeamSearchNBCE(CPM9GGeneration):
|
||||||
|
def _decode(
|
||||||
|
self,
|
||||||
|
model_inputs,
|
||||||
|
beam_size=5,
|
||||||
|
max_length=100,
|
||||||
|
repetition_penalty=1.0,
|
||||||
|
repetition_window=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Beam search
|
||||||
|
Args:
|
||||||
|
model_inputs (dict): input ids.
|
||||||
|
beam_size (int, optional, defaults to 3): beam size of beam search.
|
||||||
|
generate_length (int, optional, defaults to 100): maximum generation length.
|
||||||
|
repetition_penalty (float, optional, defaults to 1.0): repetition penalty coefficient, 1.0 means no penalty.
|
||||||
|
repetition_window (int, optional, defaults to None): window size of repetition penalty, None means that all output tokens are penalized.
|
||||||
|
""" # noqa: E501
|
||||||
|
# generate_length + 1 for EOS token
|
||||||
|
max_length += 1
|
||||||
|
|
||||||
|
# expand dimmension
|
||||||
|
batch_size = model_inputs["input_ids"].size(0)
|
||||||
|
input: torch.Tensor = (
|
||||||
|
model_inputs["input_ids"]
|
||||||
|
.unsqueeze(1)
|
||||||
|
.expand(batch_size, beam_size, -1)
|
||||||
|
.contiguous()
|
||||||
|
.view(batch_size * beam_size, -1)
|
||||||
|
)
|
||||||
|
length = (
|
||||||
|
model_inputs["length"]
|
||||||
|
.squeeze(1)
|
||||||
|
.unsqueeze(1)
|
||||||
|
.expand(batch_size, beam_size)
|
||||||
|
.contiguous()
|
||||||
|
.view(
|
||||||
|
batch_size * beam_size,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
span: torch.Tensor = (
|
||||||
|
model_inputs["span"]
|
||||||
|
.unsqueeze(1)
|
||||||
|
.expand(batch_size, beam_size, -1)
|
||||||
|
.contiguous()
|
||||||
|
.view(batch_size * beam_size, -1)
|
||||||
|
)
|
||||||
|
context: torch.Tensor = (
|
||||||
|
model_inputs["context"]
|
||||||
|
.unsqueeze(1)
|
||||||
|
.expand(batch_size, beam_size, -1)
|
||||||
|
.contiguous()
|
||||||
|
.view(batch_size * beam_size, -1)
|
||||||
|
)
|
||||||
|
|
||||||
|
done = [False]
|
||||||
|
|
||||||
|
beam_scores = torch.zeros((1, beam_size), dtype=torch.float, device=input.device)
|
||||||
|
beam_scores[:, 1:] = -1e9
|
||||||
|
beam_scores = beam_scores.view(-1)
|
||||||
|
|
||||||
|
# generated hypotheses
|
||||||
|
generated_hyps = [
|
||||||
|
BeamHypotheses(beam_size, max_length, length_penalty=1, early_stopping=False) for _ in range(1)
|
||||||
|
]
|
||||||
|
|
||||||
|
pred_start_index = input.size(-1)
|
||||||
|
past_key_values = None
|
||||||
|
|
||||||
|
for i in range(max_length + 1):
|
||||||
|
if i == 0:
|
||||||
|
logits, _, past_key_values = self.model.inference(
|
||||||
|
input=input,
|
||||||
|
context=context,
|
||||||
|
span=span,
|
||||||
|
length=length,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits, _, past_key_values = self.model.inference(
|
||||||
|
input=input[:, -1:],
|
||||||
|
context=context,
|
||||||
|
span=span,
|
||||||
|
length=length,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
)
|
||||||
|
# skip all steps when we are done with each sentence
|
||||||
|
|
||||||
|
if all(done):
|
||||||
|
break
|
||||||
|
|
||||||
|
# (batch * beam, seqlen, model_dim)
|
||||||
|
assert logits.size(0) > 1, "nbce needs to ensure that the length of logits 0 is greater than 1"
|
||||||
|
logits = NBCE(logits) # [vocab_size]
|
||||||
|
logits = logits.tile(beam_size, 1)
|
||||||
|
|
||||||
|
if i == 0:
|
||||||
|
logits[:, self.tokenizer.bos_token_id] = -float("inf")
|
||||||
|
apply_repetition_penalty(
|
||||||
|
logits,
|
||||||
|
1,
|
||||||
|
beam_size,
|
||||||
|
input,
|
||||||
|
repetition_penalty,
|
||||||
|
pred_start_index,
|
||||||
|
input.size(-1) - 1,
|
||||||
|
repetition_window,
|
||||||
|
)
|
||||||
|
|
||||||
|
scores = F.log_softmax(logits, dim=-1)
|
||||||
|
|
||||||
|
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * beam_size, vocab_size)
|
||||||
|
|
||||||
|
# re-organize to group the beam together (we are keeping top hypothesis accross beams)
|
||||||
|
next_scores = next_scores.view(1, -1) # (batch_size, beam_size * vocab_size)
|
||||||
|
next_scores, next_words = torch.topk(next_scores, 2 * beam_size, dim=1, largest=True, sorted=True)
|
||||||
|
|
||||||
|
assert next_scores.size() == next_words.size() == (1, 2 * beam_size)
|
||||||
|
next_batch_beam = []
|
||||||
|
|
||||||
|
for sent_id in range(1):
|
||||||
|
# if we are done with this sentence
|
||||||
|
done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item(), i)
|
||||||
|
if done[sent_id]:
|
||||||
|
next_batch_beam.extend([(0, 0, 0)] * beam_size) # pad the batch
|
||||||
|
continue
|
||||||
|
|
||||||
|
# next sentence beam content
|
||||||
|
next_sent_beam = []
|
||||||
|
|
||||||
|
# next words for this sentence
|
||||||
|
for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
|
||||||
|
# get beam and word IDs
|
||||||
|
beam_id = torch.div(idx, scores.size(-1), rounding_mode="floor")
|
||||||
|
word_id = idx % scores.size(-1)
|
||||||
|
|
||||||
|
# end of sentence, or next word
|
||||||
|
if word_id == self.tokenizer.bos_token_id or i == max_length:
|
||||||
|
generated_hyps[sent_id].add(
|
||||||
|
input[sent_id * beam_size + beam_id, pred_start_index:].clone().cpu().tolist(),
|
||||||
|
value.item(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))
|
||||||
|
|
||||||
|
# the beam for next step is full
|
||||||
|
if len(next_sent_beam) == beam_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
# update next beam content
|
||||||
|
assert len(next_sent_beam) == 0 if i == max_length else beam_size
|
||||||
|
if len(next_sent_beam) == 0:
|
||||||
|
next_sent_beam = [(0, 0, 0)] * beam_size # pad the batch
|
||||||
|
next_batch_beam.extend(next_sent_beam)
|
||||||
|
assert len(next_batch_beam) == beam_size * (sent_id + 1)
|
||||||
|
|
||||||
|
# we have reached the last step
|
||||||
|
if i == max_length:
|
||||||
|
break
|
||||||
|
|
||||||
|
# sanity check / prepare next batch
|
||||||
|
assert len(next_batch_beam) == batch_size * beam_size
|
||||||
|
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
|
||||||
|
beam_words = input.new([x[1] for x in next_batch_beam])
|
||||||
|
beam_idx = torch.tensor([x[2] for x in next_batch_beam], device=input.device).long()
|
||||||
|
beam_idx *= batch_size
|
||||||
|
# re-order batch and internal states
|
||||||
|
input = input[beam_idx, :]
|
||||||
|
|
||||||
|
past_key_values["buffer"] = [list(each) if each is not None else each for each in past_key_values["buffer"]] # type: ignore # noqa: E501
|
||||||
|
for key_value_layer in past_key_values["buffer"]:
|
||||||
|
if key_value_layer is not None:
|
||||||
|
key_value_layer[0] = key_value_layer[0][beam_idx]
|
||||||
|
key_value_layer[1] = key_value_layer[1][beam_idx]
|
||||||
|
|
||||||
|
input = torch.cat([input, beam_words.unsqueeze(1)], dim=-1)
|
||||||
|
context = torch.cat(
|
||||||
|
[context, torch.zeros((context.size(0), 1), dtype=torch.int16, device=context.device)],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
length = past_key_values["buffer_length"]
|
||||||
|
length = torch.cat(
|
||||||
|
[length, torch.ones((length.size(0), 1), dtype=torch.int16, device=length.device)],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
span = torch.cat([span, span[:, -1:]], dim=-1)
|
||||||
|
|
||||||
|
# select the best hypotheses
|
||||||
|
results = []
|
||||||
|
for i, hypotheses in enumerate(generated_hyps):
|
||||||
|
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
|
||||||
|
results.append(best_hyp)
|
||||||
|
|
||||||
|
result_text = list(map(self.tokenizer.decode, results))
|
||||||
|
return result_text
|
||||||
|
|
||||||
|
|
||||||
|
class CPM9GRandomSamplingNBCE(CPM9GGeneration):
|
||||||
|
def _decode(
|
||||||
|
self,
|
||||||
|
model_inputs,
|
||||||
|
max_length=100,
|
||||||
|
top_k=0,
|
||||||
|
top_p=0.9,
|
||||||
|
temperature=0.9,
|
||||||
|
repetition_penalty=1.0,
|
||||||
|
repetition_window=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Top-k and top-p sampling.
|
||||||
|
Args:
|
||||||
|
model_inputs (dict): input ids
|
||||||
|
generate_length (int, optional, defaults to 100): maximum generation length
|
||||||
|
top_k (int, optional, defaults to 0): keep only top k tokens with highest probability. 0 means keeping all tokens.
|
||||||
|
top_p (int, optional, defaults to 0.9): keep the top tokens with cumulative probability >= top_p.
|
||||||
|
temperature (int, optional, defaults to 0.9): the value that can cool down the logits distribution.
|
||||||
|
repetition_penalty (float, optional, defaults to 1.0): repetition penalty coefficient, 1.0 means no penalty.
|
||||||
|
repetition_window (int, optional, defaults to None): window size of repetition penalty, None means that all output tokens are penalized.
|
||||||
|
""" # noqa: E501
|
||||||
|
# generate_length + 1 for EOS token
|
||||||
|
max_length += 1
|
||||||
|
|
||||||
|
input = model_inputs["input_ids"]
|
||||||
|
context = model_inputs["context"]
|
||||||
|
|
||||||
|
length = model_inputs["length"].squeeze(1)
|
||||||
|
span = model_inputs["span"]
|
||||||
|
batch_size = input.size(0)
|
||||||
|
|
||||||
|
pred_start_index = input.size(-1)
|
||||||
|
past_key_values = None
|
||||||
|
done = [False]
|
||||||
|
results = [None]
|
||||||
|
for i in range(max_length):
|
||||||
|
if i == 0:
|
||||||
|
logits, _, past_key_values = self.model.inference(
|
||||||
|
input=input,
|
||||||
|
context=context,
|
||||||
|
length=length,
|
||||||
|
span=span,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits, _, past_key_values = self.model.inference(
|
||||||
|
input=input[:, -1:],
|
||||||
|
context=context,
|
||||||
|
length=length,
|
||||||
|
span=span,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert logits.size(0) > 1, "nbce needs to ensure that the length of logits 0 is greater than 1"
|
||||||
|
logits = NBCE(logits) # [vocab_size]
|
||||||
|
logits = logits[None]
|
||||||
|
|
||||||
|
if i == 0:
|
||||||
|
logits[:, self.tokenizer.bos_token_id] = -float("inf")
|
||||||
|
# logits[:, self.tokenizer.newline_id] = -float("inf")
|
||||||
|
|
||||||
|
apply_repetition_penalty(
|
||||||
|
logits,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
input,
|
||||||
|
repetition_penalty,
|
||||||
|
pred_start_index,
|
||||||
|
input.size(-1) - 1,
|
||||||
|
repetition_window,
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = logits / temperature
|
||||||
|
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
||||||
|
|
||||||
|
probs = F.softmax(logits, dim=-1)
|
||||||
|
next_token = torch.multinomial(probs, num_samples=1)
|
||||||
|
|
||||||
|
for idx in range(1):
|
||||||
|
if not done[idx] and (next_token[idx].item() == self.tokenizer.bos_token_id or i == max_length - 1):
|
||||||
|
done[idx] = True
|
||||||
|
results[idx] = input[idx, pred_start_index:].clone().cpu().tolist() # type: ignore # noqa: E501
|
||||||
|
|
||||||
|
if sum(done) == 1:
|
||||||
|
break
|
||||||
|
next_token = next_token.tile(batch_size, 1)
|
||||||
|
# update input ids
|
||||||
|
input = torch.cat([input, next_token], dim=-1)
|
||||||
|
length = past_key_values["buffer_length"]
|
||||||
|
length = torch.cat(
|
||||||
|
[length, torch.ones((length.size(0), 1), dtype=torch.int32, device=length.device)],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
# length += 1
|
||||||
|
context = torch.cat(
|
||||||
|
[context, torch.zeros((context.size(0), 1), dtype=torch.int16, device=context.device)],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
span = torch.cat(
|
||||||
|
[span, torch.zeros((span.size(0), 1), dtype=torch.int32, device=span.device)],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
result_text = list(map(self.tokenizer.decode, results))
|
||||||
|
return result_text
|
|
@ -0,0 +1,3 @@
|
||||||
|
from .cpm9g import CPM9G
|
||||||
|
from .cpm9g import CPM9GConfig
|
||||||
|
from .cpm9g_torch import CPM9GTorch
|
|
@ -0,0 +1,272 @@
|
||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import bmtrain as bmt
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
from ...layers import Embedding
|
||||||
|
from ...layers import Encoder
|
||||||
|
from ...layers import RotaryEmbeddingESM
|
||||||
|
from ...utils import Config
|
||||||
|
from ...utils import gradient_shrink
|
||||||
|
|
||||||
|
|
||||||
|
class CPM9GInferenceState(TypedDict):
|
||||||
|
buffer_context: torch.Tensor
|
||||||
|
buffer_sample_ids: torch.Tensor
|
||||||
|
buffer: List[Tuple[torch.Tensor, torch.Tensor]]
|
||||||
|
|
||||||
|
|
||||||
|
class CPM9GConfig(Config):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=32000,
|
||||||
|
dim_model=4096,
|
||||||
|
num_heads=32,
|
||||||
|
num_kv_heads=32,
|
||||||
|
dim_head=128,
|
||||||
|
dim_ff=11008,
|
||||||
|
num_layers=32,
|
||||||
|
dropout_p=0.0,
|
||||||
|
activate_fn="silu",
|
||||||
|
scale=True,
|
||||||
|
eps=1e-5,
|
||||||
|
half: bool = True,
|
||||||
|
bf16: bool = False,
|
||||||
|
mask_modules: Optional[List[Tuple[bool, bool]]] = None,
|
||||||
|
use_flash_attn: bool = True,
|
||||||
|
flash_attn_mask_shape="1d",
|
||||||
|
flash_impl="cuda",
|
||||||
|
base=10000,
|
||||||
|
tp=0,
|
||||||
|
disabled_checkpoint=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.dim_model = dim_model
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_kv_heads = num_kv_heads
|
||||||
|
self.dim_head = dim_head
|
||||||
|
self.dim_ff = dim_ff
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.dropout_p = dropout_p
|
||||||
|
self.activate_fn = activate_fn
|
||||||
|
self.scale = scale
|
||||||
|
self.eps = eps
|
||||||
|
if half:
|
||||||
|
if bf16:
|
||||||
|
self.dtype = torch.bfloat16
|
||||||
|
else:
|
||||||
|
self.dtype = torch.half
|
||||||
|
else:
|
||||||
|
self.dtype = torch.float
|
||||||
|
self.flash_impl = flash_impl
|
||||||
|
self.mask_modules = mask_modules
|
||||||
|
self.use_flash_attn = use_flash_attn
|
||||||
|
self.flash_attn_mask_shape = flash_attn_mask_shape
|
||||||
|
self.base = base
|
||||||
|
self.tp = tp
|
||||||
|
self.disabled_checkpoint = disabled_checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
class CPM9G(bmt.DistributedModule):
|
||||||
|
def __init__(self, config: CPM9GConfig):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.encoder = Encoder(
|
||||||
|
num_layers=config.num_layers,
|
||||||
|
dim_model=config.dim_model,
|
||||||
|
dim_ff=config.dim_ff,
|
||||||
|
num_heads=config.num_heads,
|
||||||
|
num_kv_heads=config.num_kv_heads,
|
||||||
|
dim_head=config.dim_head,
|
||||||
|
activate_fn=config.activate_fn,
|
||||||
|
dtype=config.dtype,
|
||||||
|
eps=config.eps,
|
||||||
|
dropout_p=config.dropout_p,
|
||||||
|
scale=config.scale,
|
||||||
|
mask_modules=config.mask_modules,
|
||||||
|
use_flash_attn=config.use_flash_attn,
|
||||||
|
tp=config.tp,
|
||||||
|
disabled_checkpoint=config.disabled_checkpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.input_embedding = Embedding(
|
||||||
|
vocab_size=config.vocab_size,
|
||||||
|
embedding_size=config.dim_model,
|
||||||
|
scale=config.scale,
|
||||||
|
dtype=config.dtype,
|
||||||
|
init_std=0.02,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.position_bias = RotaryEmbeddingESM(
|
||||||
|
dim=config.dim_head, dtype=config.dtype, base=config.base, persistent=False, mixed_precision=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.lm_head = Embedding(
|
||||||
|
vocab_size=config.vocab_size,
|
||||||
|
embedding_size=config.dim_model,
|
||||||
|
scale=config.scale,
|
||||||
|
dtype=config.dtype,
|
||||||
|
init_std=0.02,
|
||||||
|
)
|
||||||
|
self.flash_impl = config.flash_impl
|
||||||
|
self.use_flash_attn = config.use_flash_attn
|
||||||
|
self.flash_attn_mask_shape = config.flash_attn_mask_shape
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input: torch.Tensor, # (batch, seqlen) int32
|
||||||
|
length: torch.Tensor = None, # (batch) int32
|
||||||
|
context: torch.Tensor = None, # (batch, seqlen) bool
|
||||||
|
span: torch.Tensor = None, # (batch, seqlen) int32
|
||||||
|
cu_seqlens: torch.Tensor = None, # (real_batch+2) int32
|
||||||
|
max_seqlen: int = None,
|
||||||
|
position_ids: torch.Tensor = None, # (batch, seqlen) int32
|
||||||
|
):
|
||||||
|
batch = input.size(0)
|
||||||
|
seqlen = input.size(1)
|
||||||
|
device = input.device
|
||||||
|
|
||||||
|
if length is not None and length.dim() == 1:
|
||||||
|
length = torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) < length[:, None]
|
||||||
|
|
||||||
|
# processing masks and position bias bucket
|
||||||
|
if not self.use_flash_attn or (self.flash_attn_mask_shape == "2d" and self.flash_impl == "triton"):
|
||||||
|
with torch.no_grad():
|
||||||
|
# directional mask
|
||||||
|
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(
|
||||||
|
-1, 1
|
||||||
|
)
|
||||||
|
# context mask
|
||||||
|
attention_mask = context[:, None, :] | (
|
||||||
|
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
|
||||||
|
)
|
||||||
|
# span mask
|
||||||
|
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
|
||||||
|
# length mask
|
||||||
|
attention_mask = length.view(batch, seqlen, 1) & length.view(batch, 1, seqlen) & attention_mask
|
||||||
|
else:
|
||||||
|
attention_mask = None
|
||||||
|
|
||||||
|
hidden_states = self.input_embedding(input)
|
||||||
|
|
||||||
|
if self.config.tp:
|
||||||
|
with torch.no_grad():
|
||||||
|
if length is not None:
|
||||||
|
length = bmt.distributed.all_gather(length, comm=bmt.config["tp_comm"]).flatten(0, 1)
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = bmt.distributed.all_gather(attention_mask, comm=bmt.config["tp_comm"]).flatten(
|
||||||
|
0, 1
|
||||||
|
)
|
||||||
|
if cu_seqlens is not None:
|
||||||
|
lens = bmt.distributed.all_gather(
|
||||||
|
torch.tensor([cu_seqlens.numel()]).cuda(), comm=bmt.config["tp_comm"]
|
||||||
|
)
|
||||||
|
mx = lens.max().item()
|
||||||
|
cu_seq = torch.zeros(mx).int().cuda()
|
||||||
|
cu_seq[: cu_seqlens.numel()] = cu_seqlens
|
||||||
|
all_seq = bmt.distributed.all_gather(cu_seq, comm=bmt.config["tp_comm"])
|
||||||
|
cu_seqlens = [0]
|
||||||
|
for i, l in enumerate(lens):
|
||||||
|
for j in range(1, l - 1):
|
||||||
|
cu_seqlens.append(all_seq[i][j] + i * seqlen)
|
||||||
|
cu_seqlens.append((i + 1) * seqlen)
|
||||||
|
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).cuda()
|
||||||
|
if max_seqlen is not None:
|
||||||
|
max_seqlen = bmt.distributed.all_reduce(
|
||||||
|
torch.tensor([max_seqlen]).cuda(), "max", comm=bmt.config["tp_comm"]
|
||||||
|
)[0].item()
|
||||||
|
if position_ids is not None:
|
||||||
|
position_ids = bmt.distributed.all_gather(position_ids, comm=bmt.config["tp_comm"]).flatten(0, 1)
|
||||||
|
|
||||||
|
if self.use_flash_attn:
|
||||||
|
if self.flash_attn_mask_shape == "1d":
|
||||||
|
hidden_states = self.encoder(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
position_bias=self.position_bias,
|
||||||
|
pos_bias_type="rotary",
|
||||||
|
length_mask=length,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if self.flash_impl == "triton":
|
||||||
|
mask = attention_mask.unsqueeze(dim=1).contiguous()
|
||||||
|
attention_mask_bias = torch.zeros_like(mask, device="cuda", dtype=torch.float16)
|
||||||
|
attention_mask_bias[mask == False] -= torch.inf
|
||||||
|
else:
|
||||||
|
attention_mask_bias = None
|
||||||
|
assert cu_seqlens is not None, "cu_seqlens are needed in Flash Attention cuda impl"
|
||||||
|
hidden_states = self.encoder(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
position_bias=self.position_bias,
|
||||||
|
pos_bias_type="rotary",
|
||||||
|
length_mask=None,
|
||||||
|
attention_mask_bias=attention_mask_bias,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
position_ids=position_ids,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = self.encoder(
|
||||||
|
hidden_states, attention_mask=attention_mask, position_bias=self.position_bias, pos_bias_type="rotary"
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = self.lm_head.projection(hidden_states)
|
||||||
|
|
||||||
|
return logits, hidden_states
|
||||||
|
|
||||||
|
def inference(
|
||||||
|
self,
|
||||||
|
input: torch.Tensor, # (batch, len_q) int32
|
||||||
|
length: torch.Tensor, # (batch) int32
|
||||||
|
context: torch.Tensor, # (batch, seqlen) int16
|
||||||
|
span: torch.Tensor, # (batch, seqlen) int32
|
||||||
|
past_key_values: Optional[CPM9GInferenceState] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, CPM9GInferenceState]:
|
||||||
|
batch = input.size(0)
|
||||||
|
len_q = input.size(1)
|
||||||
|
len_buffer = 0
|
||||||
|
if past_key_values is None:
|
||||||
|
present_buffer = None
|
||||||
|
else:
|
||||||
|
present_buffer = past_key_values["buffer"]
|
||||||
|
len_buffer = present_buffer[0][0].shape[-2]
|
||||||
|
seqlen = len_buffer + len_q
|
||||||
|
with torch.no_grad():
|
||||||
|
device = input.device
|
||||||
|
if length.dim() == 1:
|
||||||
|
length = torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) < length[:, None]
|
||||||
|
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(-1, 1)
|
||||||
|
# context mask
|
||||||
|
attention_mask = context[:, None, :] | (
|
||||||
|
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
|
||||||
|
)
|
||||||
|
# span mask
|
||||||
|
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
|
||||||
|
# length mask
|
||||||
|
attention_mask = length.view(batch, seqlen, 1) & length.view(batch, 1, seqlen) & attention_mask
|
||||||
|
|
||||||
|
hidden_states = self.input_embedding(input)
|
||||||
|
|
||||||
|
hidden_states, present_key_values, _ = self.encoder(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=attention_mask[:, len_buffer:],
|
||||||
|
position_bias=self.position_bias,
|
||||||
|
use_cache=True,
|
||||||
|
past_key_values=present_buffer,
|
||||||
|
pos_bias_type="rotary",
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = self.lm_head.projection(hidden_states)
|
||||||
|
|
||||||
|
return (
|
||||||
|
logits,
|
||||||
|
hidden_states,
|
||||||
|
{"buffer": present_key_values},
|
||||||
|
)
|
|
@ -0,0 +1,186 @@
|
||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
from ...native_layers import Embedding
|
||||||
|
from ...native_layers import Encoder
|
||||||
|
from ...native_layers import RotaryEmbeddingESM
|
||||||
|
from ...utils import Config
|
||||||
|
from ...utils import gradient_shrink
|
||||||
|
from .cpm9g import CPM9GConfig
|
||||||
|
|
||||||
|
|
||||||
|
class CPM9GInferenceState(TypedDict):
|
||||||
|
buffer_context: torch.Tensor
|
||||||
|
buffer_sample_ids: torch.Tensor
|
||||||
|
buffer: List[Tuple[torch.Tensor, torch.Tensor]]
|
||||||
|
|
||||||
|
|
||||||
|
class CPM9GTorch(torch.nn.Module):
|
||||||
|
def __init__(self, config: CPM9GConfig):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.encoder = Encoder(
|
||||||
|
num_layers=config.num_layers,
|
||||||
|
dim_model=config.dim_model,
|
||||||
|
dim_ff=config.dim_ff,
|
||||||
|
num_heads=config.num_heads,
|
||||||
|
num_kv_heads=config.num_kv_heads,
|
||||||
|
dim_head=config.dim_head,
|
||||||
|
activate_fn=config.activate_fn,
|
||||||
|
dtype=config.dtype,
|
||||||
|
eps=config.eps,
|
||||||
|
dropout_p=config.dropout_p,
|
||||||
|
scale=config.scale,
|
||||||
|
mask_modules=config.mask_modules,
|
||||||
|
use_flash_attn=config.use_flash_attn,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.input_embedding = Embedding(
|
||||||
|
vocab_size=config.vocab_size,
|
||||||
|
embedding_size=config.dim_model,
|
||||||
|
scale=config.scale,
|
||||||
|
dtype=config.dtype,
|
||||||
|
init_std=0.02,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.position_bias = RotaryEmbeddingESM(
|
||||||
|
dim=config.dim_head, dtype=config.dtype, base=config.base, persistent=False, mixed_precision=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.lm_head = Embedding(
|
||||||
|
vocab_size=config.vocab_size,
|
||||||
|
embedding_size=config.dim_model,
|
||||||
|
scale=config.scale,
|
||||||
|
dtype=config.dtype,
|
||||||
|
init_std=0.02,
|
||||||
|
)
|
||||||
|
self.flash_impl = False
|
||||||
|
self.use_flash_attn = False
|
||||||
|
self.flash_attn_mask_shape = "1d"
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input: torch.Tensor, # (batch, seqlen) int32
|
||||||
|
length: torch.Tensor = None, # (batch) int32
|
||||||
|
context: torch.Tensor = None, # (batch, seqlen) bool
|
||||||
|
span: torch.Tensor = None, # (batch, seqlen) int32
|
||||||
|
cu_seqlens: torch.Tensor = None, # (real_batch+2) int32
|
||||||
|
max_seqlen: int = None,
|
||||||
|
position_ids: torch.Tensor = None, # (batch, seqlen) int32
|
||||||
|
):
|
||||||
|
batch = input.size(0)
|
||||||
|
seqlen = input.size(1)
|
||||||
|
device = input.device
|
||||||
|
|
||||||
|
if length is not None and length.dim() == 1:
|
||||||
|
length = torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) < length[:, None]
|
||||||
|
|
||||||
|
# processing masks and position bias bucket
|
||||||
|
if not self.use_flash_attn or (self.flash_attn_mask_shape == "2d" and self.flash_impl == "triton"):
|
||||||
|
with torch.no_grad():
|
||||||
|
# directional mask
|
||||||
|
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(
|
||||||
|
-1, 1
|
||||||
|
)
|
||||||
|
# context mask
|
||||||
|
attention_mask = context[:, None, :] | (
|
||||||
|
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
|
||||||
|
)
|
||||||
|
# span mask
|
||||||
|
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
|
||||||
|
# length mask
|
||||||
|
attention_mask = length.view(batch, seqlen, 1) & length.view(batch, 1, seqlen) & attention_mask
|
||||||
|
|
||||||
|
hidden_states = self.input_embedding(input)
|
||||||
|
|
||||||
|
if self.use_flash_attn:
|
||||||
|
if self.flash_attn_mask_shape == "1d":
|
||||||
|
hidden_states = self.encoder(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
position_bias=self.position_bias,
|
||||||
|
pos_bias_type="rotary",
|
||||||
|
length_mask=length,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if self.flash_impl == "triton":
|
||||||
|
mask = attention_mask.unsqueeze(dim=1).contiguous()
|
||||||
|
attention_mask_bias = torch.zeros_like(mask, device="cuda", dtype=torch.float16)
|
||||||
|
attention_mask_bias[mask == False] -= torch.inf
|
||||||
|
else:
|
||||||
|
attention_mask_bias = None
|
||||||
|
assert cu_seqlens is not None, "cu_seqlens are needed in Flash Attention cuda impl"
|
||||||
|
hidden_states = self.encoder(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
position_bias=self.position_bias,
|
||||||
|
pos_bias_type="rotary",
|
||||||
|
length_mask=None,
|
||||||
|
attention_mask_bias=attention_mask_bias,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
position_ids=position_ids,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = self.encoder(
|
||||||
|
hidden_states, attention_mask=attention_mask, position_bias=self.position_bias, pos_bias_type="rotary"
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = self.lm_head.projection(hidden_states)
|
||||||
|
|
||||||
|
return logits, hidden_states
|
||||||
|
|
||||||
|
def inference(
|
||||||
|
self,
|
||||||
|
input: torch.Tensor, # (batch, len_q) int32
|
||||||
|
length: torch.Tensor, # (batch) int32
|
||||||
|
context: torch.Tensor, # (batch, seqlen) int16
|
||||||
|
span: torch.Tensor, # (batch, seqlen) int32
|
||||||
|
past_key_values: Optional[CPM9GInferenceState] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, CPM9GInferenceState]:
|
||||||
|
batch = input.size(0)
|
||||||
|
len_q = input.size(1)
|
||||||
|
len_buffer = 0
|
||||||
|
if past_key_values is None:
|
||||||
|
present_buffer = None
|
||||||
|
else:
|
||||||
|
present_buffer = past_key_values["buffer"]
|
||||||
|
len_buffer = present_buffer[0][0].shape[-2]
|
||||||
|
seqlen = len_buffer + len_q
|
||||||
|
with torch.no_grad():
|
||||||
|
device = input.device
|
||||||
|
if length.dim() == 1:
|
||||||
|
length = (torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) + length[:, None]) >= seqlen
|
||||||
|
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(-1, 1)
|
||||||
|
# context mask
|
||||||
|
attention_mask = context[:, None, :] | (
|
||||||
|
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
|
||||||
|
)
|
||||||
|
# span mask
|
||||||
|
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
|
||||||
|
# length mask
|
||||||
|
attention_mask = length.view(batch, seqlen, 1) & length.view(batch, 1, seqlen) & attention_mask
|
||||||
|
|
||||||
|
hidden_states = self.input_embedding(input)
|
||||||
|
|
||||||
|
hidden_states, present_key_values, _ = self.encoder(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=attention_mask[:, len_buffer:],
|
||||||
|
position_bias=self.position_bias,
|
||||||
|
use_cache=True,
|
||||||
|
past_key_values=present_buffer,
|
||||||
|
pos_bias_type="rotary",
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = self.lm_head.projection(hidden_states)
|
||||||
|
|
||||||
|
return (
|
||||||
|
logits,
|
||||||
|
hidden_states,
|
||||||
|
{"buffer": present_key_values},
|
||||||
|
)
|
|
@ -0,0 +1 @@
|
||||||
|
from .cpm9g import CPM9GTokenizer
|
|
@ -22,7 +22,7 @@ def load_vocab(fp: IO[bytes]) -> Dict[str, int]:
|
||||||
return vocab
|
return vocab
|
||||||
|
|
||||||
|
|
||||||
class FM9GTokenizer(object):
|
class CPM9GTokenizer(object):
|
||||||
def __init__(self, path=None):
|
def __init__(self, path=None):
|
||||||
self.unk_token = "<unk>"
|
self.unk_token = "<unk>"
|
||||||
self.bos_token = "<s>"
|
self.bos_token = "<s>"
|
||||||
|
@ -36,7 +36,7 @@ class FM9GTokenizer(object):
|
||||||
if path:
|
if path:
|
||||||
all_tokens = load_vocab(io.FileIO(path, "rb"))
|
all_tokens = load_vocab(io.FileIO(path, "rb"))
|
||||||
else:
|
else:
|
||||||
all_tokens = load_vocab(pkg_resources.resource_stream("fm9g", "/fm9g/vocabs/fm9g.txt"))
|
all_tokens = load_vocab(pkg_resources.resource_stream("cpm", "/cpm9g/vocabs/cpm9g.txt"))
|
||||||
|
|
||||||
self.encoder: Dict[str, int] = {}
|
self.encoder: Dict[str, int] = {}
|
||||||
self._special_encoder: Dict[str, int] = {}
|
self._special_encoder: Dict[str, int] = {}
|
||||||
|
@ -106,8 +106,8 @@ class FM9GTokenizer(object):
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def encode(self, text: str) -> List[int]:
|
def encode(self, text: str) -> List[int]:
|
||||||
#if len(text) > 20480:
|
if len(text) > 20480:
|
||||||
# return [0 for _ in range(20480)]
|
return [0 for _ in range(20480)]
|
||||||
ret = []
|
ret = []
|
||||||
for x in self.tokenize(text):
|
for x in self.tokenize(text):
|
||||||
if x in self.encoder:
|
if x in self.encoder:
|
||||||
|
@ -136,14 +136,9 @@ class FM9GTokenizer(object):
|
||||||
plane_id = self._byte_decoder[tokens[st + 1]]
|
plane_id = self._byte_decoder[tokens[st + 1]]
|
||||||
row_id = self._byte_decoder[tokens[st + 2]]
|
row_id = self._byte_decoder[tokens[st + 2]]
|
||||||
cell_id = self._byte_decoder[tokens[st + 3]]
|
cell_id = self._byte_decoder[tokens[st + 3]]
|
||||||
int_bytes = int.to_bytes(first_id << 24 | plane_id << 16 | row_id << 8 | cell_id, 4, "big")
|
ret.append(
|
||||||
try:
|
int.to_bytes(first_id << 24 | plane_id << 16 | row_id << 8 | cell_id, 4, "big").decode("utf-8")
|
||||||
decoded_str = int_bytes.decode("utf-8", errors="replace")
|
)
|
||||||
ret.append(decoded_str)
|
|
||||||
#print(decoded_str)
|
|
||||||
except UnicodeDecodeError as e:
|
|
||||||
print(f"UnicodeDecodeError: {e}")
|
|
||||||
|
|
||||||
st += 4
|
st += 4
|
||||||
elif (
|
elif (
|
||||||
st + 2 < len(tokens)
|
st + 2 < len(tokens)
|
||||||
|
@ -153,33 +148,16 @@ class FM9GTokenizer(object):
|
||||||
plane_id = self._byte_decoder[tokens[st]]
|
plane_id = self._byte_decoder[tokens[st]]
|
||||||
row_id = self._byte_decoder[tokens[st + 1]]
|
row_id = self._byte_decoder[tokens[st + 1]]
|
||||||
cell_id = self._byte_decoder[tokens[st + 2]]
|
cell_id = self._byte_decoder[tokens[st + 2]]
|
||||||
int_bytes = int.to_bytes(plane_id << 16 | row_id << 8 | cell_id, 3, "big")
|
ret.append(int.to_bytes(plane_id << 16 | row_id << 8 | cell_id, 3, "big").decode("utf-8"))
|
||||||
try:
|
|
||||||
decoded_str = int_bytes.decode("utf-8", errors="replace")
|
|
||||||
ret.append(decoded_str)
|
|
||||||
except UnicodeDecodeError as e:
|
|
||||||
print(f"UnicodeDecodeError: {e}")
|
|
||||||
st += 3
|
st += 3
|
||||||
elif st + 1 < len(tokens) and tokens[st + 1] in self._byte_decoder:
|
elif st + 1 < len(tokens) and tokens[st + 1] in self._byte_decoder:
|
||||||
row_id = self._byte_decoder[tokens[st]]
|
row_id = self._byte_decoder[tokens[st]]
|
||||||
cell_id = self._byte_decoder[tokens[st + 1]]
|
cell_id = self._byte_decoder[tokens[st + 1]]
|
||||||
int_bytes = int.to_bytes(row_id << 8 | cell_id, 2, "big")
|
ret.append(int.to_bytes(row_id << 8 | cell_id, 2, "big").decode("utf-8"))
|
||||||
try:
|
|
||||||
decoded_str = int_bytes.decode("utf-8", errors="replace")
|
|
||||||
ret.append(decoded_str)
|
|
||||||
except UnicodeDecodeError as e:
|
|
||||||
print(f"UnicodeDecodeError: {e}")
|
|
||||||
#ret.append(int.to_bytes(row_id << 8 | cell_id, 2, "big").decode("utf-8"))
|
|
||||||
st += 2
|
st += 2
|
||||||
else:
|
else:
|
||||||
cell_id = self._byte_decoder[tokens[st]]
|
cell_id = self._byte_decoder[tokens[st]]
|
||||||
int_bytes = int.to_bytes(cell_id, 1, "big")
|
ret.append(int.to_bytes(cell_id, 1, "big").decode("utf-8"))
|
||||||
try:
|
|
||||||
decoded_str = int_bytes.decode("utf-8", errors="replace")
|
|
||||||
ret.append(decoded_str)
|
|
||||||
except UnicodeDecodeError as e:
|
|
||||||
print(f"UnicodeDecodeError: {e}")
|
|
||||||
#ret.append(int.to_bytes(cell_id, 1, "big").decode("utf-8"))
|
|
||||||
st += 1
|
st += 1
|
||||||
elif tokens[st] == self.eos_id:
|
elif tokens[st] == self.eos_id:
|
||||||
ret.append(self.eos_token)
|
ret.append(self.eos_token)
|
|
@ -0,0 +1,2 @@
|
||||||
|
from .pretrain import MixedDataset
|
||||||
|
from .finetune import FinetuneDataset
|
|
@ -0,0 +1,87 @@
|
||||||
|
import bmtrain as bmt
|
||||||
|
|
||||||
|
from ...dataset import SimpleDataset
|
||||||
|
from ..tokenizers import CPM9GTokenizer
|
||||||
|
from .pretrain import _MixedDatasetBatchPacker
|
||||||
|
from .pretrain import _MixedDatasetConfig
|
||||||
|
from .pretrain import CPM9GBatch
|
||||||
|
|
||||||
|
|
||||||
|
class FinetuneDataset:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset_path: str,
|
||||||
|
batch_size: int,
|
||||||
|
max_length: int,
|
||||||
|
tokenizer: CPM9GTokenizer,
|
||||||
|
unpad: bool = False,
|
||||||
|
task_name: str = "task",
|
||||||
|
drop_last: bool = False,
|
||||||
|
) -> None:
|
||||||
|
self._world_size = bmt.world_size()
|
||||||
|
self._rank = bmt.rank()
|
||||||
|
self._batch_size = batch_size
|
||||||
|
self._unpad = unpad
|
||||||
|
self._max_length = max_length
|
||||||
|
|
||||||
|
self._packer = _MixedDatasetBatchPacker(batch_size * self._world_size, max_length, tokenizer, unpad)
|
||||||
|
self._drop_last = drop_last
|
||||||
|
|
||||||
|
ds = SimpleDataset(dataset_path, shuffle=False)
|
||||||
|
self._ds_cfg: _MixedDatasetConfig = {
|
||||||
|
"weight": 1.0,
|
||||||
|
"path": dataset_path,
|
||||||
|
"transforms": [],
|
||||||
|
"task_name": task_name,
|
||||||
|
"dataset_name": "finetune",
|
||||||
|
"incontext_weight": [1.0],
|
||||||
|
"lines": len(ds),
|
||||||
|
"dataset": ds,
|
||||||
|
}
|
||||||
|
self._nlines = len(ds)
|
||||||
|
self._nbytes = ds.get_bytes()
|
||||||
|
|
||||||
|
def __batch_iter(self):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
batch = self._packer.add_data(self._ds_cfg)
|
||||||
|
except EOFError:
|
||||||
|
break
|
||||||
|
if batch is None:
|
||||||
|
continue
|
||||||
|
yield batch
|
||||||
|
if len(self._packer) > 0:
|
||||||
|
batch = self._packer.pack_batch(force=True)
|
||||||
|
if not self._drop_last:
|
||||||
|
yield batch
|
||||||
|
self._ds_cfg["dataset"]._repeat_times = 0
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
if not self._unpad:
|
||||||
|
batch_st = self._batch_size * self._rank
|
||||||
|
batch_end = self._batch_size * (self._rank + 1)
|
||||||
|
for batch in self.__batch_iter():
|
||||||
|
batch_size = batch["inputs"].shape[0]
|
||||||
|
if batch_size <= batch_st:
|
||||||
|
yield None
|
||||||
|
else:
|
||||||
|
ret: CPM9GBatch = {
|
||||||
|
kw: val[batch_st:batch_end] # type: ignore
|
||||||
|
for kw, val in batch.items()
|
||||||
|
if kw not in ["task_names", "raw_data", "cu_seqlens", "max_seqlen"]
|
||||||
|
} # type: ignore
|
||||||
|
ret["task_names"] = batch["task_names"]
|
||||||
|
ret["raw_data"] = batch["raw_data"]
|
||||||
|
ret["cu_seqlens"] = batch["cu_seqlens"]
|
||||||
|
ret["max_seqlen"] = batch["max_seqlen"]
|
||||||
|
yield ret
|
||||||
|
else:
|
||||||
|
for batch in self.__batch_iter():
|
||||||
|
assert batch["inputs"].shape[0] == 1
|
||||||
|
yield batch
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
iter_count = int(self._nbytes // (3.6 * self._batch_size * self._world_size * self._max_length))
|
||||||
|
if not self._drop_last:
|
||||||
|
iter_count += 1
|
||||||
|
return iter_count
|
|
@ -0,0 +1,736 @@
|
||||||
|
import importlib.machinery
|
||||||
|
import importlib.util
|
||||||
|
import json
|
||||||
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
import types
|
||||||
|
from collections import OrderedDict
|
||||||
|
from queue import Empty
|
||||||
|
from typing import Any
|
||||||
|
from typing import Callable
|
||||||
|
from typing import Dict
|
||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Set
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import bmtrain as bmt
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
from ...dataset import DistributedDataset
|
||||||
|
from ..tokenizers import CPM9GTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class _MixedDatasetConfig(TypedDict):
|
||||||
|
weight: float
|
||||||
|
path: str
|
||||||
|
transforms: Union[List[Dict[str, Any]], str]
|
||||||
|
task_name: str
|
||||||
|
dataset_name: str
|
||||||
|
incontext_weight: List[float]
|
||||||
|
lines: int
|
||||||
|
dataset: DistributedDataset
|
||||||
|
|
||||||
|
|
||||||
|
CPM9GInputType = Union[str, Dict[str, "CPM9GInputType"]]
|
||||||
|
|
||||||
|
|
||||||
|
class _TransformFuncDict(TypedDict):
|
||||||
|
loader: importlib.machinery.SourceFileLoader
|
||||||
|
module: types.ModuleType
|
||||||
|
last_m: float
|
||||||
|
|
||||||
|
|
||||||
|
_TransformFunction = Callable[[CPM9GInputType, int, random.Random], CPM9GInputType]
|
||||||
|
|
||||||
|
|
||||||
|
class CPM9GBatch(TypedDict):
|
||||||
|
inputs: NDArray[np.int32]
|
||||||
|
length: NDArray[np.int32]
|
||||||
|
context: NDArray[np.bool_]
|
||||||
|
sample_ids: NDArray[np.int32]
|
||||||
|
spans: NDArray[np.int32]
|
||||||
|
target: NDArray[np.int32]
|
||||||
|
task_ids: NDArray[np.int32]
|
||||||
|
task_names: List[str]
|
||||||
|
raw_data: List[Any]
|
||||||
|
|
||||||
|
|
||||||
|
def convert_data_to_id(tokenizer: CPM9GTokenizer, data: Any):
|
||||||
|
input_ids = tokenizer.encode(data["input"])
|
||||||
|
output_ids = tokenizer.encode(data["output"])
|
||||||
|
ids = [tokenizer.bos_id] + input_ids + output_ids + [tokenizer.eos_id]
|
||||||
|
ids = np.array(ids, dtype=np.int32)
|
||||||
|
context = np.zeros((ids.shape[0],), dtype=np.int8)
|
||||||
|
context[: len(input_ids) + 1] = 1
|
||||||
|
return ids, context
|
||||||
|
|
||||||
|
|
||||||
|
def _dataset_identity(c: _MixedDatasetConfig):
|
||||||
|
return "{}.{}".format(c["task_name"], c["dataset_name"])
|
||||||
|
|
||||||
|
|
||||||
|
class _MixedDatasetBatchPacker:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
batch_size: int,
|
||||||
|
max_length: int,
|
||||||
|
tokenizer: CPM9GTokenizer,
|
||||||
|
unpad: bool = False,
|
||||||
|
) -> None:
|
||||||
|
self._batch_size = batch_size
|
||||||
|
self._max_length = max_length
|
||||||
|
self._clip_length = max_length
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self._transform_func_table: Dict[str, _TransformFuncDict] = {}
|
||||||
|
self._unpad = unpad
|
||||||
|
if unpad:
|
||||||
|
self._max_length = max_length * batch_size
|
||||||
|
self._batch_size = 1
|
||||||
|
|
||||||
|
self._inputs: List[NDArray[np.int32]] = []
|
||||||
|
self._context: List[NDArray[np.int8]] = []
|
||||||
|
self._sample_ids: List[NDArray[np.int32]] = []
|
||||||
|
self._spans: List[List[int]] = []
|
||||||
|
self._task_ids: List[List[str]] = []
|
||||||
|
self._raw_data: List[List[Any]] = []
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._inputs)
|
||||||
|
|
||||||
|
def apply_transform(
|
||||||
|
self,
|
||||||
|
data: CPM9GInputType,
|
||||||
|
transform: Union[Dict[str, Any], Callable[[CPM9GInputType], CPM9GInputType], None],
|
||||||
|
) -> CPM9GInputType:
|
||||||
|
if transform is None:
|
||||||
|
return data
|
||||||
|
if not isinstance(transform, dict):
|
||||||
|
return transform(data) # transform function
|
||||||
|
|
||||||
|
mapping_list: List[Tuple[str, str]] = []
|
||||||
|
|
||||||
|
def _walk_transform_dict(data: Union[Dict[str, Any], str], prefix: str = ""):
|
||||||
|
if isinstance(data, dict):
|
||||||
|
for k, v in data.items():
|
||||||
|
if len(prefix) > 0:
|
||||||
|
_walk_transform_dict(v, prefix + "." + k)
|
||||||
|
else:
|
||||||
|
_walk_transform_dict(v, k)
|
||||||
|
else:
|
||||||
|
assert isinstance(data, str), "Invalid transform {}".format(data)
|
||||||
|
mapping_list.append((prefix, data))
|
||||||
|
|
||||||
|
_walk_transform_dict(transform)
|
||||||
|
|
||||||
|
expanded_mapping_list: List[Tuple[str, Any]] = []
|
||||||
|
|
||||||
|
def _expand_mapping(data: CPM9GInputType, stars: List[str], path: List[str], target: List[str]):
|
||||||
|
if len(path) == 0:
|
||||||
|
num_stars = 0
|
||||||
|
for it in target:
|
||||||
|
if it == "*":
|
||||||
|
num_stars += 1
|
||||||
|
if num_stars != len(stars):
|
||||||
|
raise ValueError("Invalid transform {}".format(".".join(target)))
|
||||||
|
|
||||||
|
nw_tgt = []
|
||||||
|
num_stars = 0
|
||||||
|
for it in target:
|
||||||
|
if it == "*":
|
||||||
|
nw_tgt.append(stars[num_stars])
|
||||||
|
num_stars += 1
|
||||||
|
else:
|
||||||
|
nw_tgt.append(it)
|
||||||
|
expanded_mapping_list.append((".".join(nw_tgt), data))
|
||||||
|
else:
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
raise ValueError("Invalid data {}".format(data))
|
||||||
|
if path[0] == "*":
|
||||||
|
for k, v in data.items():
|
||||||
|
_expand_mapping(v, stars + [k], path[1:], target)
|
||||||
|
else:
|
||||||
|
_expand_mapping(data[path[0]], stars, path[1:], target)
|
||||||
|
|
||||||
|
# expand mapping list
|
||||||
|
for tgt, src in mapping_list:
|
||||||
|
if src.startswith("$"):
|
||||||
|
# copy from src
|
||||||
|
_expand_mapping(data, [], src[1:].split("."), tgt.split("."))
|
||||||
|
else:
|
||||||
|
if "*" in tgt:
|
||||||
|
raise ValueError("Constant value is not allowed to have `*` in prefix")
|
||||||
|
expanded_mapping_list.append((tgt, src))
|
||||||
|
|
||||||
|
ret = {}
|
||||||
|
for tgt, val in expanded_mapping_list:
|
||||||
|
tgt = tgt.split(".")
|
||||||
|
cur = ret
|
||||||
|
while len(tgt) > 1:
|
||||||
|
cur = cur[tgt[0]]
|
||||||
|
tgt = tgt[1:]
|
||||||
|
cur[tgt[0]] = val
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def data_to_id(self, data: Any):
|
||||||
|
return convert_data_to_id(self.tokenizer, data)
|
||||||
|
|
||||||
|
def _ensure_transform_function(self, module_name: str, transform_script_path: str) -> _TransformFunction:
|
||||||
|
module_name = "cpm_live.transforms.{}".format(module_name)
|
||||||
|
if transform_script_path not in self._transform_func_table:
|
||||||
|
loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
|
||||||
|
spec = importlib.util.spec_from_loader(loader.name, loader)
|
||||||
|
if spec is None:
|
||||||
|
raise RuntimeError("spec is none! {}".format(module_name))
|
||||||
|
mod = importlib.util.module_from_spec(spec)
|
||||||
|
self._transform_func_table[transform_script_path] = {
|
||||||
|
"loader": loader,
|
||||||
|
"module": mod,
|
||||||
|
"last_m": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
transform_script_info = self._transform_func_table[transform_script_path]
|
||||||
|
curr_m_time = float(transform_script_info["loader"].path_stats(transform_script_path)["mtime"])
|
||||||
|
if curr_m_time > transform_script_info["last_m"]:
|
||||||
|
transform_script_info["last_m"] = curr_m_time
|
||||||
|
transform_script_info["loader"].exec_module(transform_script_info["module"])
|
||||||
|
transform_func = getattr(transform_script_info["module"], "transform", None)
|
||||||
|
if transform_func is None:
|
||||||
|
|
||||||
|
def _empty_transform_func(data: CPM9GInputType, num_sample: int, r: random.Random):
|
||||||
|
raise NotImplementedError("Transform func for dataset {} not implemented".format(module_name))
|
||||||
|
|
||||||
|
return _empty_transform_func
|
||||||
|
else:
|
||||||
|
return transform_func
|
||||||
|
|
||||||
|
def build_instance(self, config: _MixedDatasetConfig):
|
||||||
|
_sample_weight = np.array(config["incontext_weight"], dtype=np.float32)
|
||||||
|
_sample_weight = _sample_weight / _sample_weight.sum()
|
||||||
|
num_incontext = np.random.choice(_sample_weight.shape[0], p=_sample_weight)
|
||||||
|
ds = config["dataset"]
|
||||||
|
transforms = config["transforms"]
|
||||||
|
if isinstance(transforms, str):
|
||||||
|
if not os.path.exists(transforms):
|
||||||
|
raise RuntimeError("transform script {} file not exists".format(transforms))
|
||||||
|
# load transform script
|
||||||
|
transform_func = self._ensure_transform_function(_dataset_identity(config), transforms)
|
||||||
|
seed = random.random()
|
||||||
|
|
||||||
|
def _transform(data: CPM9GInputType):
|
||||||
|
r = random.Random(seed)
|
||||||
|
return transform_func(data, num_incontext, r)
|
||||||
|
|
||||||
|
transform = _transform
|
||||||
|
elif len(transforms) == 0:
|
||||||
|
transform = None
|
||||||
|
else:
|
||||||
|
transform = transforms[np.random.choice(len(transforms))]
|
||||||
|
|
||||||
|
raw_data = {}
|
||||||
|
while True:
|
||||||
|
inp = ds.read()
|
||||||
|
inp = self.apply_transform(inp, transform)
|
||||||
|
# if None, skip this one
|
||||||
|
if inp is None:
|
||||||
|
continue
|
||||||
|
input_ids, context = self.data_to_id(inp)
|
||||||
|
if input_ids.shape[0] > self._clip_length:
|
||||||
|
continue # too long
|
||||||
|
input_ids = input_ids[: self._clip_length]
|
||||||
|
context = context[: self._clip_length]
|
||||||
|
raw_data["input"] = inp
|
||||||
|
raw_data["samples"] = []
|
||||||
|
break
|
||||||
|
|
||||||
|
sample_ids = np.zeros(input_ids.shape, dtype=np.int32)
|
||||||
|
i = 0
|
||||||
|
while True:
|
||||||
|
if i == num_incontext or input_ids.shape[0] >= self._clip_length:
|
||||||
|
break # early break
|
||||||
|
sample = ds.read()
|
||||||
|
sample = self.apply_transform(sample, transform)
|
||||||
|
# if sample is None, skip this one
|
||||||
|
if sample is None:
|
||||||
|
continue
|
||||||
|
sample_input_ids, _ = self.data_to_id(sample)
|
||||||
|
if input_ids.shape[0] + sample_input_ids.shape[0] > self._clip_length:
|
||||||
|
break # too long, break
|
||||||
|
raw_data["samples"].append(sample)
|
||||||
|
input_ids = np.concatenate([input_ids, sample_input_ids], axis=0)
|
||||||
|
context = np.concatenate([context, np.ones(sample_input_ids.shape, dtype=np.int8)], axis=0)
|
||||||
|
sample_ids = np.concatenate([sample_ids, np.full(sample_input_ids.shape, i + 1, dtype=np.int32)], axis=0)
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return (
|
||||||
|
input_ids,
|
||||||
|
context,
|
||||||
|
sample_ids,
|
||||||
|
raw_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
def pack_batch(self, force: bool = False, unilm=False) -> CPM9GBatch:
|
||||||
|
# pack batch
|
||||||
|
if len(self._inputs) < self._batch_size:
|
||||||
|
if not force:
|
||||||
|
raise RuntimeError("Batch insufficient")
|
||||||
|
batch_size = len(self._inputs)
|
||||||
|
else:
|
||||||
|
batch_size = self._batch_size
|
||||||
|
max_length = self._max_length # self._spans[0][-1] if self._unpad else self._max_length
|
||||||
|
inputs = np.zeros((batch_size, max_length), dtype=np.int32)
|
||||||
|
context = np.zeros((batch_size, max_length), dtype=np.int8)
|
||||||
|
sample_ids = np.zeros((batch_size, max_length), dtype=np.int32)
|
||||||
|
tgt = np.full((batch_size, max_length), -100, dtype=np.int32)
|
||||||
|
|
||||||
|
spans = np.zeros((batch_size, max_length), dtype=np.int32)
|
||||||
|
length = np.zeros((batch_size,), dtype=np.int32)
|
||||||
|
task_ids = np.full((batch_size, max_length), -1, dtype=np.int32)
|
||||||
|
if self._spans[0][-1] != max_length:
|
||||||
|
cu_seqlens = np.array([0] + self._spans[0] + [max_length], dtype=np.int32)
|
||||||
|
else:
|
||||||
|
cu_seqlens = np.array([0] + self._spans[0], dtype=np.int32)
|
||||||
|
max_seqlen = int(np.max(cu_seqlens[1:] - cu_seqlens[:-1]))
|
||||||
|
position_ids = np.zeros((batch_size, max_length), dtype=np.int32)
|
||||||
|
|
||||||
|
all_task_names: Set[str] = set()
|
||||||
|
for i in range(batch_size):
|
||||||
|
for task_name in self._task_ids[i]:
|
||||||
|
all_task_names.add(task_name)
|
||||||
|
task_names: List[str] = list(all_task_names)
|
||||||
|
task_name_to_id = {name: i for i, name in enumerate(task_names)}
|
||||||
|
|
||||||
|
raw_data_list: List[Any] = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
instance_length = self._inputs[i].shape[0]
|
||||||
|
inputs[i, :instance_length] = self._inputs[i]
|
||||||
|
sample_ids[i, :instance_length] = self._sample_ids[i]
|
||||||
|
if unilm:
|
||||||
|
context[i, :instance_length] = self._context[i]
|
||||||
|
|
||||||
|
span_begin = 0
|
||||||
|
for span_id, (span_end, task_name) in enumerate(zip(self._spans[i], self._task_ids[i])):
|
||||||
|
spans[i, span_begin:span_end] = span_id
|
||||||
|
position_ids[i, span_begin:span_end] = np.arange(span_end - span_begin)
|
||||||
|
task_ids[i, span_begin:span_end] = task_name_to_id[task_name]
|
||||||
|
span_begin = span_end
|
||||||
|
length[i] = instance_length
|
||||||
|
raw_data_list.extend(self._raw_data[i])
|
||||||
|
|
||||||
|
for j in range(instance_length):
|
||||||
|
if j > 1 and self._context[i][j] == 0:
|
||||||
|
if self._inputs[i][j] != self.tokenizer.bos_id and self._inputs[i][j - 1] != self.tokenizer.eos_id:
|
||||||
|
tgt[i, j - 1] = self._inputs[i][j]
|
||||||
|
|
||||||
|
self._inputs = self._inputs[batch_size:]
|
||||||
|
self._context = self._context[batch_size:]
|
||||||
|
self._sample_ids = self._sample_ids[batch_size:]
|
||||||
|
self._spans = self._spans[batch_size:]
|
||||||
|
self._task_ids = self._task_ids[batch_size:]
|
||||||
|
self._raw_data = self._raw_data[batch_size:]
|
||||||
|
return {
|
||||||
|
"inputs": inputs,
|
||||||
|
"length": length,
|
||||||
|
"context": context > 0,
|
||||||
|
"sample_ids": sample_ids,
|
||||||
|
"spans": spans,
|
||||||
|
"cu_seqlens": cu_seqlens,
|
||||||
|
"max_seqlen": max_seqlen,
|
||||||
|
"position_ids": position_ids,
|
||||||
|
"target": tgt,
|
||||||
|
"task_ids": task_ids,
|
||||||
|
"task_names": task_names,
|
||||||
|
"raw_data": raw_data_list,
|
||||||
|
}
|
||||||
|
|
||||||
|
def add_data(self, config: _MixedDatasetConfig) -> Optional[CPM9GBatch]:
|
||||||
|
(
|
||||||
|
input_ids,
|
||||||
|
context,
|
||||||
|
sample_ids,
|
||||||
|
raw_data,
|
||||||
|
) = self.build_instance(config)
|
||||||
|
|
||||||
|
# add to batch
|
||||||
|
best_fit: Union[None, int] = None
|
||||||
|
best_fit_space: Union[None, int] = None
|
||||||
|
for i in range(len(self._inputs)):
|
||||||
|
space = self._max_length - self._inputs[i].shape[0]
|
||||||
|
if input_ids.shape[0] <= space:
|
||||||
|
if best_fit_space is None:
|
||||||
|
best_fit = i
|
||||||
|
best_fit_space = space
|
||||||
|
elif best_fit_space > space:
|
||||||
|
best_fit = i
|
||||||
|
best_fit_space = space
|
||||||
|
if best_fit is None:
|
||||||
|
# add a new instance
|
||||||
|
self._inputs.append(input_ids)
|
||||||
|
self._context.append(context)
|
||||||
|
self._sample_ids.append(sample_ids)
|
||||||
|
self._spans.append([input_ids.shape[0]])
|
||||||
|
self._task_ids.append([config["task_name"]])
|
||||||
|
self._raw_data.append([raw_data])
|
||||||
|
else:
|
||||||
|
# add to existing instance
|
||||||
|
self._inputs[best_fit] = np.concatenate([self._inputs[best_fit], input_ids], axis=0)
|
||||||
|
self._context[best_fit] = np.concatenate([self._context[best_fit], context], axis=0)
|
||||||
|
self._sample_ids[best_fit] = np.concatenate([self._sample_ids[best_fit], sample_ids], axis=0)
|
||||||
|
self._spans[best_fit].append(self._inputs[best_fit].shape[0])
|
||||||
|
self._task_ids[best_fit].append(config["task_name"])
|
||||||
|
self._raw_data[best_fit].append(raw_data)
|
||||||
|
|
||||||
|
if len(self._inputs) > self._batch_size:
|
||||||
|
return self.pack_batch()
|
||||||
|
else:
|
||||||
|
return None # not ready
|
||||||
|
|
||||||
|
|
||||||
|
class _MixedDatasetConfigMananger:
|
||||||
|
def __init__(self, config_path: str) -> None:
|
||||||
|
self._config_path: str = config_path
|
||||||
|
self._config: Union[List[_MixedDatasetConfig], None] = None
|
||||||
|
self._last_m = 0
|
||||||
|
|
||||||
|
def changed(self):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
m_time = os.stat(self._config_path).st_mtime
|
||||||
|
if m_time > self._last_m:
|
||||||
|
# try to load new config
|
||||||
|
try:
|
||||||
|
self._config = json.load(open(self._config_path, "r", encoding="utf-8"))
|
||||||
|
except Exception:
|
||||||
|
# failed to load config
|
||||||
|
return False
|
||||||
|
# new config loaded
|
||||||
|
self._last_m = m_time
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
print("Error: reading info list! {}".format(self._config_path))
|
||||||
|
time.sleep(30)
|
||||||
|
|
||||||
|
def get_config(self) -> List[_MixedDatasetConfig]:
|
||||||
|
if self._config is None:
|
||||||
|
if not self.changed():
|
||||||
|
raise RuntimeError("Failed to load config")
|
||||||
|
if self._config is None:
|
||||||
|
raise RuntimeError("Failed to load config")
|
||||||
|
return self._config
|
||||||
|
|
||||||
|
|
||||||
|
def _mixed_dataset_process(
|
||||||
|
config_path: str,
|
||||||
|
q_cmd: multiprocessing.Queue,
|
||||||
|
q_cmd_out: multiprocessing.Queue,
|
||||||
|
q_data: multiprocessing.Queue,
|
||||||
|
rank: int,
|
||||||
|
world_size: int,
|
||||||
|
packer: _MixedDatasetBatchPacker,
|
||||||
|
max_repeat_times: int = None,
|
||||||
|
random_state=None,
|
||||||
|
):
|
||||||
|
for _ in range(rank):
|
||||||
|
np.random.random()
|
||||||
|
random.setstate(random_state)
|
||||||
|
# ignore SIGINT
|
||||||
|
import signal
|
||||||
|
|
||||||
|
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||||
|
config_base_path = os.path.dirname(os.path.abspath(config_path))
|
||||||
|
|
||||||
|
# fix libgomp threading deadlock after fork(), use single-thread in data process.
|
||||||
|
# REF: https://github.com/pytorch/pytorch/issues/17199
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
|
||||||
|
def _convert_to_abs_path(transform_path: str):
|
||||||
|
if transform_path.startswith("/"):
|
||||||
|
return transform_path
|
||||||
|
else:
|
||||||
|
return os.path.join(config_base_path, transform_path)
|
||||||
|
|
||||||
|
def _build_sample_weights(config: List[_MixedDatasetConfig]):
|
||||||
|
if len(config) == 0:
|
||||||
|
return np.array([], dtype=np.float32)
|
||||||
|
weights = [c["weight"] * c["lines"] for c in config]
|
||||||
|
weights = np.array(weights, dtype=np.float32)
|
||||||
|
sm_weight = weights.sum()
|
||||||
|
if sm_weight > 0:
|
||||||
|
weights = weights / sm_weight
|
||||||
|
return weights
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Empty datasets")
|
||||||
|
|
||||||
|
cfg_mgr = _MixedDatasetConfigMananger(config_path)
|
||||||
|
config = cfg_mgr.get_config()
|
||||||
|
|
||||||
|
for c in config:
|
||||||
|
ds = DistributedDataset(
|
||||||
|
_convert_to_abs_path(c["path"]),
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
max_repeat_times=max_repeat_times,
|
||||||
|
)
|
||||||
|
|
||||||
|
c["lines"] = ds._nlines
|
||||||
|
c["dataset"] = ds
|
||||||
|
if "weight" not in c:
|
||||||
|
c["weight"] = 1.0
|
||||||
|
if "transforms" not in c:
|
||||||
|
c["transforms"] = []
|
||||||
|
elif isinstance(c["transforms"], str):
|
||||||
|
c["transforms"] = _convert_to_abs_path(c["transforms"])
|
||||||
|
if "incontext_weight" not in c:
|
||||||
|
c["incontext_weight"] = [1.0]
|
||||||
|
|
||||||
|
weights = _build_sample_weights(config)
|
||||||
|
|
||||||
|
should_stop = False
|
||||||
|
should_start = False
|
||||||
|
|
||||||
|
while not should_stop:
|
||||||
|
# update config first
|
||||||
|
if cfg_mgr.changed():
|
||||||
|
path_ds_map: Dict[str, _MixedDatasetConfig] = {}
|
||||||
|
nw_path_set: Set[str] = set()
|
||||||
|
|
||||||
|
# load new config
|
||||||
|
nw_config = cfg_mgr.get_config()
|
||||||
|
|
||||||
|
# build path -> dataset map
|
||||||
|
for c in config:
|
||||||
|
path_ds_map[_dataset_identity(c)] = c
|
||||||
|
|
||||||
|
# add new datasets
|
||||||
|
for c in nw_config:
|
||||||
|
if _dataset_identity(c) in path_ds_map:
|
||||||
|
# update values only
|
||||||
|
if "weight" in c:
|
||||||
|
path_ds_map[_dataset_identity(c)]["weight"] = c["weight"]
|
||||||
|
if "transform" in c:
|
||||||
|
if isinstance(c["transforms"], str):
|
||||||
|
path_ds_map[_dataset_identity(c)]["transforms"] = _convert_to_abs_path(c["transforms"])
|
||||||
|
else:
|
||||||
|
path_ds_map[_dataset_identity(c)]["transforms"] = c["transforms"]
|
||||||
|
if "incontext_weight" in c:
|
||||||
|
path_ds_map[_dataset_identity(c)]["incontext_weight"] = c["incontext_weight"]
|
||||||
|
else:
|
||||||
|
# new dataset
|
||||||
|
ds = DistributedDataset(
|
||||||
|
_convert_to_abs_path(c["path"]),
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
max_repeat_times=max_repeat_times,
|
||||||
|
)
|
||||||
|
c["lines"] = ds._nlines
|
||||||
|
c["dataset"] = ds
|
||||||
|
if "weight" not in c:
|
||||||
|
c["weight"] = 1.0
|
||||||
|
if "transforms" not in c:
|
||||||
|
c["transforms"] = []
|
||||||
|
elif isinstance(c["transforms"], str):
|
||||||
|
c["transforms"] = _convert_to_abs_path(c["transforms"])
|
||||||
|
if "incontext_weight" not in c:
|
||||||
|
c["incontext_weight"] = [1.0]
|
||||||
|
path_ds_map[_dataset_identity(c)] = c
|
||||||
|
nw_path_set.add(_dataset_identity(c))
|
||||||
|
|
||||||
|
# remove unused datasets
|
||||||
|
for c in config:
|
||||||
|
if _dataset_identity(c) not in nw_path_set:
|
||||||
|
del path_ds_map[_dataset_identity(c)]
|
||||||
|
|
||||||
|
config: List[_MixedDatasetConfig] = []
|
||||||
|
for c in nw_config:
|
||||||
|
config.append(path_ds_map[_dataset_identity(c)])
|
||||||
|
del path_ds_map
|
||||||
|
del nw_path_set
|
||||||
|
del nw_config
|
||||||
|
|
||||||
|
weights = _build_sample_weights(config)
|
||||||
|
|
||||||
|
# get cmds
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
cmd = q_cmd.get_nowait()
|
||||||
|
except Empty:
|
||||||
|
break
|
||||||
|
if cmd == "stop":
|
||||||
|
should_stop = True
|
||||||
|
q_cmd_out.put(True)
|
||||||
|
break
|
||||||
|
elif cmd == "state_dict":
|
||||||
|
ret = OrderedDict()
|
||||||
|
for c in config:
|
||||||
|
ds_name = _dataset_identity(c)
|
||||||
|
ret[ds_name] = c["dataset"]._state_dict()
|
||||||
|
q_cmd_out.put(ret)
|
||||||
|
elif cmd == "load_state_dict":
|
||||||
|
state_dict = q_cmd.get()
|
||||||
|
missing = []
|
||||||
|
for idx, c in enumerate(config):
|
||||||
|
ds_name = _dataset_identity(c)
|
||||||
|
print(f"loading {idx}/{len(config)} {ds_name}")
|
||||||
|
if ds_name in state_dict:
|
||||||
|
c["dataset"].load_state_dict(state_dict[ds_name], strict=False)
|
||||||
|
else:
|
||||||
|
# new dataset
|
||||||
|
missing.append(ds_name)
|
||||||
|
q_cmd_out.put(missing)
|
||||||
|
elif cmd == "start":
|
||||||
|
should_start = True
|
||||||
|
q_cmd_out.put(True)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Unknown command: {}".format(cmd))
|
||||||
|
|
||||||
|
if should_stop:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not should_start:
|
||||||
|
# wait for start cmd
|
||||||
|
time.sleep(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(config) == 0:
|
||||||
|
# no dataset available
|
||||||
|
time.sleep(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if q_data.full():
|
||||||
|
# queue full
|
||||||
|
time.sleep(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# sample a dataset
|
||||||
|
ds_id: int = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
ds_id = np.random.choice(weights.shape[0], p=weights)
|
||||||
|
if config[ds_id]["dataset"]._nlines != config[ds_id]["lines"]:
|
||||||
|
# dataset size changed
|
||||||
|
for c in config:
|
||||||
|
c["lines"] = c["dataset"]._nlines
|
||||||
|
weights = _build_sample_weights(config)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
batch = packer.add_data(config[ds_id])
|
||||||
|
if batch is not None:
|
||||||
|
# new batch comming
|
||||||
|
q_data.put(batch)
|
||||||
|
|
||||||
|
# clean queue
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
q_data.get_nowait()
|
||||||
|
except Empty:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
class MixedDataset:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config_path: str,
|
||||||
|
batch_size: int,
|
||||||
|
max_length: int,
|
||||||
|
tokenizer: CPM9GTokenizer,
|
||||||
|
unpad: bool = False,
|
||||||
|
max_repeat_times: int = None,
|
||||||
|
) -> None:
|
||||||
|
self._q_cmd = multiprocessing.Queue()
|
||||||
|
self._q_cmd_out = multiprocessing.Queue()
|
||||||
|
self._q_data = multiprocessing.Queue(maxsize=2)
|
||||||
|
self._packer = _MixedDatasetBatchPacker(batch_size, max_length, tokenizer, unpad)
|
||||||
|
self._p = multiprocessing.Process(
|
||||||
|
target=_mixed_dataset_process,
|
||||||
|
args=(
|
||||||
|
config_path,
|
||||||
|
self._q_cmd,
|
||||||
|
self._q_cmd_out,
|
||||||
|
self._q_data,
|
||||||
|
bmt.rank(),
|
||||||
|
bmt.world_size(),
|
||||||
|
self._packer,
|
||||||
|
max_repeat_times,
|
||||||
|
random.getstate(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self._p.start()
|
||||||
|
self._closed = False
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if not self._closed:
|
||||||
|
self._closed = True
|
||||||
|
self._q_cmd.put("stop")
|
||||||
|
assert self._q_cmd_out.get(), "Failed to stop process"
|
||||||
|
self._p.join()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def closed(self):
|
||||||
|
return self._closed
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
self._q_cmd.put("start")
|
||||||
|
return self._q_cmd_out.get()
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
self._q_cmd.put("state_dict")
|
||||||
|
states = self._q_cmd_out.get()
|
||||||
|
if not isinstance(states, OrderedDict):
|
||||||
|
raise RuntimeError("Invalid state dict {}".format(states))
|
||||||
|
if bmt.world_size() == 1:
|
||||||
|
for val in states.values():
|
||||||
|
val["states"].unsqueeze_(0)
|
||||||
|
val["block"].unsqueeze_(0)
|
||||||
|
return states
|
||||||
|
|
||||||
|
ret = OrderedDict()
|
||||||
|
for k, v in states.items():
|
||||||
|
num_unused_block = v["states"].size(0)
|
||||||
|
gpu_num_unused_block = torch.tensor([num_unused_block], dtype=torch.long).cuda()
|
||||||
|
max_unused_blocks = bmt.distributed.all_reduce(gpu_num_unused_block, op="max").cpu().item()
|
||||||
|
if max_unused_blocks == 0:
|
||||||
|
max_unused_blocks = 1
|
||||||
|
gpu_states = torch.full((max_unused_blocks,), -1, dtype=torch.long).cuda()
|
||||||
|
gpu_states[:num_unused_block] = v["states"].cuda()
|
||||||
|
|
||||||
|
gpu_block = v["block"].cuda()
|
||||||
|
global_states = bmt.distributed.all_gather(gpu_states).cpu() # (world_size, max_unused_blocks)
|
||||||
|
global_block = bmt.distributed.all_gather(gpu_block).cpu() # (world_size, 4)
|
||||||
|
ret[k] = {"states": global_states, "block": global_block}
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def load_state_dict(self, data: OrderedDict, strict: bool = False):
|
||||||
|
self._q_cmd.put("load_state_dict")
|
||||||
|
self._q_cmd.put(data)
|
||||||
|
missing = self._q_cmd_out.get()
|
||||||
|
if strict:
|
||||||
|
if len(missing) > 0:
|
||||||
|
raise RuntimeError("Missing dataset state: {}".format(missing))
|
||||||
|
return missing
|
||||||
|
|
||||||
|
def get(self) -> CPM9GBatch:
|
||||||
|
ret: CPM9GBatch = self._q_data.get(timeout=300) # type: ignore
|
||||||
|
if not isinstance(ret, dict):
|
||||||
|
raise RuntimeError("Invalid data {}".format(ret))
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
while True:
|
||||||
|
yield self.get()
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if not self.closed:
|
||||||
|
try:
|
||||||
|
self.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
File diff suppressed because it is too large
Load Diff
|
@ -4,8 +4,7 @@ from .distributed_dataset import SimpleDataset
|
||||||
from .indexed_dataset import IndexedDataset
|
from .indexed_dataset import IndexedDataset
|
||||||
from .indexed_dataset import IndexedDatasetBuilder
|
from .indexed_dataset import IndexedDatasetBuilder
|
||||||
from .indexed_dataset import PrefetchDecodeDataset
|
from .indexed_dataset import PrefetchDecodeDataset
|
||||||
|
from .list_dataset import ListDataset
|
||||||
# from .list_dataset import ListDataset
|
|
||||||
from .utils import compact_dataset
|
from .utils import compact_dataset
|
||||||
from .utils import CudaPrefetcher
|
from .utils import CudaPrefetcher
|
||||||
from .utils import mask_dataset
|
from .utils import mask_dataset
|
|
@ -1,18 +1,3 @@
|
||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The OpenBMB team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import bisect
|
import bisect
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
|
@ -296,6 +281,7 @@ class DistributedDataset:
|
||||||
info: List[FileInfo] = []
|
info: List[FileInfo] = []
|
||||||
if os.path.exists(meta_path):
|
if os.path.exists(meta_path):
|
||||||
info = _read_info_list(meta_path)
|
info = _read_info_list(meta_path)
|
||||||
|
|
||||||
old_len = len(self._file_info)
|
old_len = len(self._file_info)
|
||||||
if old_len > len(info):
|
if old_len > len(info):
|
||||||
raise RuntimeError("Dataset meta file: changed unexpectly")
|
raise RuntimeError("Dataset meta file: changed unexpectly")
|
||||||
|
@ -457,11 +443,7 @@ class DistributedDataset:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if self._world_size > 1:
|
if self._world_size > 1:
|
||||||
gpu_num_unused_block = torch.tensor([num_unused_block], dtype=torch.long).cuda()
|
gpu_num_unused_block = torch.tensor([num_unused_block], dtype=torch.long).cuda()
|
||||||
max_unused_blocks = (
|
max_unused_blocks = bmt.distributed.all_reduce(gpu_num_unused_block, op="max").cpu().item()
|
||||||
bmt.distributed.all_reduce(gpu_num_unused_block, op="max", comm=bmt.config["tp_zero_comm"])
|
|
||||||
.cpu()
|
|
||||||
.item()
|
|
||||||
)
|
|
||||||
gpu_states = torch.full((max_unused_blocks,), -1, dtype=torch.long).cuda()
|
gpu_states = torch.full((max_unused_blocks,), -1, dtype=torch.long).cuda()
|
||||||
gpu_states[:num_unused_block] = torch.tensor(self._unused_block, dtype=torch.long).cuda()
|
gpu_states[:num_unused_block] = torch.tensor(self._unused_block, dtype=torch.long).cuda()
|
||||||
gpu_offset = torch.full((max_unused_blocks,), 0, dtype=torch.long).cuda()
|
gpu_offset = torch.full((max_unused_blocks,), 0, dtype=torch.long).cuda()
|
||||||
|
@ -470,15 +452,9 @@ class DistributedDataset:
|
||||||
[curr_block, inblock_offset, num_unused_block, self._repeat_times],
|
[curr_block, inblock_offset, num_unused_block, self._repeat_times],
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
).cuda()
|
).cuda()
|
||||||
global_states = bmt.distributed.all_gather(
|
global_states = bmt.distributed.all_gather(gpu_states).cpu() # (world_size, max_unused_blocks)
|
||||||
gpu_states, comm=bmt.config["tp_zero_comm"]
|
global_offset = bmt.distributed.all_gather(gpu_offset).cpu() # (world_size, max_unused_blocks)
|
||||||
).cpu() # (world_size, max_unused_blocks)
|
global_block = bmt.distributed.all_gather(gpu_block).cpu() # (world_size, 4)
|
||||||
global_offset = bmt.distributed.all_gather(
|
|
||||||
gpu_offset, comm=bmt.config["tp_zero_comm"]
|
|
||||||
).cpu() # (world_size, max_unused_blocks)
|
|
||||||
global_block = bmt.distributed.all_gather(
|
|
||||||
gpu_block, comm=bmt.config["tp_zero_comm"]
|
|
||||||
).cpu() # (world_size, 4)
|
|
||||||
return {"states": global_states, "offset": global_offset, "block": global_block}
|
return {"states": global_states, "offset": global_offset, "block": global_block}
|
||||||
else:
|
else:
|
||||||
return {
|
return {
|
|
@ -1,41 +1,13 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
#
|
|
||||||
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
|
||||||
#
|
|
||||||
# @author: ouzebin <ouzebin@zhihu.com>
|
|
||||||
# @date: 2023/09/27
|
|
||||||
|
|
||||||
"""
|
|
||||||
使用 IndexedDataset 前需按指定格式构建或者转换已有数据集
|
|
||||||
数据集文件结构:
|
|
||||||
- <dataset name>
|
|
||||||
- data.jsonl # jsonl 格式的数据,每一行一条样本
|
|
||||||
- index # 记录每一行 json 数据的起始 byte-offset
|
|
||||||
|
|
||||||
从头构建:直接使用 IndexedDatasetBuilder 这个 context manager
|
|
||||||
>>> with IndexedDatasetBuilder("swear", overwrite=True) as builder:
|
|
||||||
>>> for data in [{"input": f"screw it {i}", "output": f"for god's sake {i}"} for i in range(100)]:
|
|
||||||
>>> builder.put(data)
|
|
||||||
转换:
|
|
||||||
从 fm9g distributed_dataset 转换:使用 `fm9g.dataset.tools.distributed_to_indexed`
|
|
||||||
$ python -m fm9g.dataset.tools.distributed_to_indexed -i <原数据集文件夹> -o <新数据集文件夹>
|
|
||||||
已有 jsonl 数据:使用 `fm9g.dataset.tools.jsonl_to_index` 构建 index 文件。需提前先把 jsonl 文件命名为
|
|
||||||
$ python -m fm9g.dataset.tools.jsonl_to_index -p <数据集文件夹路径>
|
|
||||||
"""
|
|
||||||
import itertools
|
import itertools
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
import queue
|
import queue
|
||||||
import random
|
import random
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import bmtrain as bmt
|
import bmtrain as bmt
|
||||||
import h5py
|
|
||||||
import numpy
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import msgspec
|
import msgspec
|
||||||
|
@ -50,83 +22,13 @@ except ModuleNotFoundError:
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from fm9g.utils.bitset import BitSet
|
from .utils import Range
|
||||||
from fm9g.utils.bitset import bitset_diff
|
|
||||||
|
|
||||||
print_lock = threading.Lock()
|
print_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
def random_range(start, stop=None, step=None):
|
|
||||||
"""
|
|
||||||
Generator of non-repeated random permutation with the same inteface of python
|
|
||||||
`range`. Obtained from https://stackoverflow.com/a/53551417
|
|
||||||
The random.shuffle(list) and random.sample(list, len(list)) require
|
|
||||||
materialize the lists, which result in a long initalization period.
|
|
||||||
"""
|
|
||||||
if stop is None:
|
|
||||||
start, stop = 0, start
|
|
||||||
if step is None:
|
|
||||||
step = 1
|
|
||||||
# Use a mapping to convert a standard range into the desired range.
|
|
||||||
mapping = lambda i: (i * step) + start
|
|
||||||
# Compute the number of numbers in this range.
|
|
||||||
maximum = int(math.ceil((stop - start) / step))
|
|
||||||
if maximum == 0:
|
|
||||||
# early return with empty range
|
|
||||||
yield from ()
|
|
||||||
return
|
|
||||||
# Seed range with a random integer.
|
|
||||||
value = random.randint(0, maximum)
|
|
||||||
# Construct an offset, multiplier, and modulus for a linear
|
|
||||||
# congruential generator. These generators are cyclic and
|
|
||||||
# non-repeating when they maintain the properties:
|
|
||||||
#
|
|
||||||
# 1) "modulus" and "offset" are relatively prime.
|
|
||||||
# 2) ["multiplier" - 1] is divisible by all prime factors of "modulus".
|
|
||||||
# 3) ["multiplier" - 1] is divisible by 4 if "modulus" is divisible by 4.
|
|
||||||
|
|
||||||
# Pick a random odd-valued offset.
|
|
||||||
offset = random.randint(0, maximum) * 2 + 1
|
|
||||||
# Pick a multiplier 1 greater than a multiple of 4.
|
|
||||||
multiplier = 4 * (maximum // 4) + 1
|
|
||||||
# Pick a modulus just big enough to generate all numbers (power of 2).
|
|
||||||
modulus = int(2 ** math.ceil(math.log2(maximum)))
|
|
||||||
# Track how many random numbers have been returned.
|
|
||||||
found = 0
|
|
||||||
while found < maximum:
|
|
||||||
# If this is a valid value, yield it in generator fashion.
|
|
||||||
if value < maximum:
|
|
||||||
found += 1
|
|
||||||
yield mapping(value)
|
|
||||||
# Calculate the next value in the sequence.
|
|
||||||
value = (value * multiplier + offset) % modulus
|
|
||||||
|
|
||||||
|
|
||||||
class Range(object):
|
|
||||||
def __init__(self, start, stop, step):
|
|
||||||
self.start = start
|
|
||||||
self.stop = stop
|
|
||||||
self.step = step
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"Range({self.start}, {self.stop}, {self.step})"
|
|
||||||
|
|
||||||
def iterate(self):
|
|
||||||
yield from range(self.start, self.stop, self.step)
|
|
||||||
|
|
||||||
def list(self):
|
|
||||||
return list(range(self.start, self.stop, self.step))
|
|
||||||
|
|
||||||
def subrange(self, split, nsplits):
|
|
||||||
# strided spliting range params
|
|
||||||
# e.g., [0, 3, 5, 7, 9] can be split into [0, 5, 9] and [3, 7]
|
|
||||||
return Range(self.start + self.step * split, self.stop, self.step * nsplits)
|
|
||||||
|
|
||||||
def random_iterate(self):
|
|
||||||
yield from random_range(self.start, self.stop, self.step)
|
|
||||||
|
|
||||||
|
|
||||||
def safe_print(*args, **kargs):
|
def safe_print(*args, **kargs):
|
||||||
if "flush" in kargs:
|
if "flush" in kargs:
|
||||||
flush = kargs["flush"]
|
flush = kargs["flush"]
|
||||||
|
@ -138,15 +40,12 @@ def safe_print(*args, **kargs):
|
||||||
|
|
||||||
|
|
||||||
def concurrent_info():
|
def concurrent_info():
|
||||||
# world_size, rank = bmt.world_size(), bmt.rank()
|
world_size, rank = bmt.world_size(), bmt.rank()
|
||||||
world_size = bmt.config["world_size"] // bmt.config["tp_size"]
|
|
||||||
rank = bmt.config["topology"].tp_idx
|
|
||||||
worker_info = torch.utils.data.get_worker_info()
|
worker_info = torch.utils.data.get_worker_info()
|
||||||
if worker_info is None:
|
if worker_info is None:
|
||||||
nworkers, worker_id = 1, 1
|
nworkers, worker_id = 1, 1
|
||||||
else:
|
else:
|
||||||
nworkers, worker_id = worker_info.num_workers, worker_info.id
|
nworkers, worker_id = worker_info.num_workers, worker_info.id
|
||||||
# print("concurrent_info: (world_size, rank, nworkers, worker_id): {}".format((world_size, rank, nworkers, worker_id)))
|
|
||||||
return world_size, rank, nworkers, worker_id
|
return world_size, rank, nworkers, worker_id
|
||||||
|
|
||||||
|
|
||||||
|
@ -157,45 +56,14 @@ class IndexedDataset(Dataset):
|
||||||
self.max_retry = max_retry
|
self.max_retry = max_retry
|
||||||
self.retry_sleep = retry_sleep
|
self.retry_sleep = retry_sleep
|
||||||
self.bounds = None
|
self.bounds = None
|
||||||
self.h5file = None
|
|
||||||
self.build_index()
|
self.build_index()
|
||||||
|
|
||||||
def size(self):
|
def size(self):
|
||||||
return self.bounds[-1]
|
return self.bounds[-1]
|
||||||
|
|
||||||
def _build_index_h5(self):
|
|
||||||
index_path = os.path.join(self.path, "index.h5")
|
|
||||||
if os.path.getsize(index_path) > 104857600:
|
|
||||||
self.h5file = h5py.File(os.path.join(self.path, "index.h5"), "r")
|
|
||||||
self.bounds = self.h5file["index"]
|
|
||||||
else:
|
|
||||||
# only load index into memory when it is small (< 100 Mb)
|
|
||||||
# to avoid keeping to many file handlers
|
|
||||||
self.h5file = None
|
|
||||||
with h5py.File(index_path, "r") as hf:
|
|
||||||
self.bounds = np.array(hf["index"])
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
if self.h5file is not None:
|
|
||||||
self.h5file.close()
|
|
||||||
|
|
||||||
def build_index(self):
|
def build_index(self):
|
||||||
s = time.time()
|
|
||||||
|
|
||||||
txt_size = os.path.getsize(os.path.join(self.path, "index"))
|
|
||||||
if txt_size > 0.5 * 1024**3 and os.path.exists(os.path.join(self.path, "index.h5")):
|
|
||||||
source = "h5"
|
|
||||||
self._build_index_h5()
|
|
||||||
else:
|
|
||||||
source = "txt"
|
|
||||||
self._build_index_txt()
|
|
||||||
e = time.time()
|
|
||||||
bmt.print_rank("build_index_{} from {}, using {:.2f}s".format(source, self.path, e - s))
|
|
||||||
|
|
||||||
def _build_index_txt(self):
|
|
||||||
with open(os.path.join(self.path, "index"), "r") as fin:
|
with open(os.path.join(self.path, "index"), "r") as fin:
|
||||||
self.bounds = [int(line) for line in fin]
|
self.bounds = [int(line) for line in fin]
|
||||||
self.nlines = len(self.bounds)
|
|
||||||
|
|
||||||
def safe_read(self, i_or_s, offset, size):
|
def safe_read(self, i_or_s, offset, size):
|
||||||
for retry in itertools.count():
|
for retry in itertools.count():
|
||||||
|
@ -270,10 +138,39 @@ class IndexedDataset(Dataset):
|
||||||
class PrefetchDecodeDataset(IndexedDataset):
|
class PrefetchDecodeDataset(IndexedDataset):
|
||||||
# Add prefetched sampled iterator and state_dict tracking upon the simple IndexedDataset
|
# Add prefetched sampled iterator and state_dict tracking upon the simple IndexedDataset
|
||||||
# Add safe decoding in iterator
|
# Add safe decoding in iterator
|
||||||
def __init__(self, *args, decode=json_decode, allow_repeat=False, **kargs):
|
def __init__(self, *args, decode=json_decode, **kargs):
|
||||||
super().__init__(*args, **kargs)
|
super().__init__(*args, **kargs)
|
||||||
self.decode = decode
|
self.decode = decode
|
||||||
self.allow_repeat = allow_repeat
|
self.lock = threading.Lock()
|
||||||
|
self.prev_used = set() # store previously used index in the checkpoint
|
||||||
|
self.used = set() # track locally used index
|
||||||
|
|
||||||
|
def state_dict(self, gathered=True):
|
||||||
|
if not self.prev_used and not self.used:
|
||||||
|
return {"prev_used": set()}
|
||||||
|
if gathered:
|
||||||
|
used = torch.tensor(list(self.used)).cuda()
|
||||||
|
size = torch.tensor(used.numel()).cuda()
|
||||||
|
max_size = bmt.distributed.all_reduce(size, op="max")
|
||||||
|
# allgather requires tensors having the same size
|
||||||
|
used = torch.cat([used, torch.full((max_size - size,), -100, device=used.device)], dim=-1)
|
||||||
|
all_used = bmt.distributed.all_gather(used).unique()
|
||||||
|
all_used = set(all_used.tolist())
|
||||||
|
if -100 in all_used:
|
||||||
|
all_used.remove(-100) # remove the padding value
|
||||||
|
all_used.union(self.prev_used)
|
||||||
|
return {"prev_used": all_used}
|
||||||
|
else:
|
||||||
|
return {"prev_used": self.prev_used.union(self.used)}
|
||||||
|
|
||||||
|
def load_state_dict(self, state):
|
||||||
|
with self.lock:
|
||||||
|
self.used = state.get("prev_used", set())
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
with self.lock:
|
||||||
|
self.used = set()
|
||||||
|
self.prev_used = set()
|
||||||
|
|
||||||
def safe_decode(self, i, raw):
|
def safe_decode(self, i, raw):
|
||||||
if raw is None:
|
if raw is None:
|
||||||
|
@ -294,23 +191,19 @@ class PrefetchDecodeDataset(IndexedDataset):
|
||||||
else:
|
else:
|
||||||
return self.safe_decode(key, raw)
|
return self.safe_decode(key, raw)
|
||||||
|
|
||||||
def loader(self, q, lid, keys, stop, used=None):
|
def loader(self, q, lid, keys, stop):
|
||||||
# concurrent prefetching worker
|
# concurrent prefetching worker
|
||||||
if used is None:
|
|
||||||
used = BitSet()
|
|
||||||
try:
|
try:
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if stop.is_set():
|
if stop.is_set():
|
||||||
break
|
break
|
||||||
# key is either a slice or an integer index
|
# key is either a slice or an integer index
|
||||||
index = range(key.start, key.stop) if isinstance(key, slice) else [key]
|
index = range(key.start, key.stop) if isinstance(key, slice) else [key]
|
||||||
unused = bitset_diff(set(index), used)
|
with self.lock:
|
||||||
|
unused = set(index) - self.used - self.prev_used
|
||||||
if not unused:
|
if not unused:
|
||||||
# skip used slice / item
|
# skip used slice / item
|
||||||
continue
|
continue
|
||||||
if not q.empty():
|
|
||||||
# avoid breaking the distributed file system with large io load
|
|
||||||
time.sleep(random.random() * 2)
|
|
||||||
# read raw data with IndexedDataset.__getitem__, suspend decoding util we really need it
|
# read raw data with IndexedDataset.__getitem__, suspend decoding util we really need it
|
||||||
raw = super().__getitem__(key)
|
raw = super().__getitem__(key)
|
||||||
if raw is None:
|
if raw is None:
|
||||||
|
@ -324,14 +217,14 @@ class PrefetchDecodeDataset(IndexedDataset):
|
||||||
# signaling the end of iteration to the main thread
|
# signaling the end of iteration to the main thread
|
||||||
q.put(StopIteration(lid))
|
q.put(StopIteration(lid))
|
||||||
|
|
||||||
def _iterate(self, key_groups, nprefetch=1000, used=None):
|
def _iterate(self, key_groups, nprefetch=1000):
|
||||||
# helper function for concurrent prefetching
|
# helper function for concurrent prefetching
|
||||||
q = queue.Queue(maxsize=nprefetch)
|
q = queue.Queue(maxsize=nprefetch)
|
||||||
stop = threading.Event()
|
stop = threading.Event()
|
||||||
alive = set()
|
alive = set()
|
||||||
try:
|
try:
|
||||||
for lid, keys in enumerate(key_groups):
|
for lid, keys in enumerate(key_groups):
|
||||||
loader = threading.Thread(target=self.loader, args=(q, lid, keys, stop, used), daemon=True)
|
loader = threading.Thread(target=self.loader, args=(q, lid, keys, stop), daemon=True)
|
||||||
loader.start()
|
loader.start()
|
||||||
alive.add(lid)
|
alive.add(lid)
|
||||||
while True:
|
while True:
|
||||||
|
@ -343,7 +236,7 @@ class PrefetchDecodeDataset(IndexedDataset):
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# new item will be put later, wait for a while
|
# new item will be put later, wait for a while
|
||||||
time.sleep(0.1)
|
time.sleep(0.3)
|
||||||
continue
|
continue
|
||||||
if isinstance(item, StopIteration):
|
if isinstance(item, StopIteration):
|
||||||
alive.remove(item.value)
|
alive.remove(item.value)
|
||||||
|
@ -352,13 +245,16 @@ class PrefetchDecodeDataset(IndexedDataset):
|
||||||
data = self.safe_decode(i, raw)
|
data = self.safe_decode(i, raw)
|
||||||
if data is None:
|
if data is None:
|
||||||
continue
|
continue
|
||||||
yield i, data
|
self.used.add(i)
|
||||||
|
yield data
|
||||||
|
# automatically reset states with graceful ends.
|
||||||
|
self.reset()
|
||||||
finally:
|
finally:
|
||||||
# ask daemon loaders to stop
|
# ask daemon loaders to stop
|
||||||
stop.set()
|
stop.set()
|
||||||
|
|
||||||
def iterate(self, nthreads=3, prefetch_sample=100, used=None, process_group=None):
|
def iterate(self, nthreads=3, prefetch_sample=100):
|
||||||
world_size, rank, nworkers, worker_id = concurrent_info(process_group)
|
world_size, rank, nworkers, worker_id = concurrent_info()
|
||||||
nloaders = world_size * nworkers * nthreads
|
nloaders = world_size * nworkers * nthreads
|
||||||
if len(self) < nloaders:
|
if len(self) < nloaders:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -373,61 +269,51 @@ class PrefetchDecodeDataset(IndexedDataset):
|
||||||
r = r.subrange(split=worker_id, nsplits=nworkers)
|
r = r.subrange(split=worker_id, nsplits=nworkers)
|
||||||
# split index among multi-threaded loaders
|
# split index among multi-threaded loaders
|
||||||
id_groups = [r.subrange(split=tid, nsplits=nthreads).random_iterate() for tid in range(nthreads)]
|
id_groups = [r.subrange(split=tid, nsplits=nthreads).random_iterate() for tid in range(nthreads)]
|
||||||
return self._iterate(id_groups, nprefetch=prefetch_sample, used=used)
|
for data in self._iterate(id_groups, nprefetch=prefetch_sample):
|
||||||
|
yield data
|
||||||
|
|
||||||
def sliced_iterate(self, nthreads=1, prefetch_slice=3, slice_size=500, used=None):
|
def sliced_iterate(self, nthreads=1, prefetch_slice=3, slice_size=1000):
|
||||||
world_size, rank, nworkers, worker_id = concurrent_info()
|
world_size, rank, nworkers, worker_id = concurrent_info()
|
||||||
nloaders = world_size * nworkers * nthreads
|
nloaders = world_size * nworkers * nthreads
|
||||||
if len(self) < nloaders:
|
if len(self) < nloaders:
|
||||||
if not self.allow_repeat:
|
raise ValueError(
|
||||||
raise ValueError(
|
f"more concurrent loaders ({nloaders}) than data entries ({len(self)}) in '{self.path}', "
|
||||||
f"more concurrent loaders ({nloaders}) than data entries ({len(self)}) in '{self.path}', "
|
f"please constrain either "
|
||||||
f"please constrain either "
|
f"world_size={world_size}, num_workers={nworkers} or num_threads={nthreads}."
|
||||||
f"world_size={world_size}, num_workers={nworkers} or num_threads={nthreads}."
|
)
|
||||||
)
|
nslices = int(math.ceil(len(self) / slice_size))
|
||||||
else:
|
|
||||||
duplicated_factor = math.ceil(nloaders / len(self))
|
|
||||||
# In this case, slice size is 1
|
|
||||||
r = Range(0, len(self), 1)
|
|
||||||
# split index among grouped multi-gpu workers
|
|
||||||
r = r.subrange(split=rank // duplicated_factor, nsplits=math.ceil(world_size / duplicated_factor))
|
|
||||||
# # split index among multi-threaded loaders
|
|
||||||
r = r.subrange(split=worker_id, nsplits=nworkers)
|
|
||||||
else:
|
|
||||||
nslices = int(math.ceil(len(self) / slice_size))
|
|
||||||
|
|
||||||
if nslices < nloaders:
|
if nslices < nloaders:
|
||||||
safe_print(
|
safe_print(
|
||||||
f"fail to distribute {nslices} slices from '{self.path}' to {nloaders} concurrent loaders, "
|
f"fail to distribute {nslices} slices from '{self.path}' to {nloaders} concurrent loaders, "
|
||||||
f"reduce slice_size from {slice_size} to {len(self) // nloaders}."
|
f"reduce slice_size from {slice_size} to {len(self) // nloaders}."
|
||||||
)
|
)
|
||||||
slice_size = len(self) // nloaders
|
slice_size = len(self) // nloaders
|
||||||
|
|
||||||
# we only iteratre through start ids as they uniquely mark each slice
|
# we only iteratre through start ids as they uniquely mark each slice
|
||||||
r = Range(0, len(self), slice_size)
|
r = Range(0, len(self), slice_size)
|
||||||
# split index among multi-gpu workers
|
# split index among multi-gpu workers
|
||||||
r = r.subrange(split=rank, nsplits=world_size)
|
r = r.subrange(split=rank, nsplits=world_size)
|
||||||
# split index among multi-process dataloader workers
|
# split index among multi-process dataloader workers
|
||||||
r = r.subrange(split=worker_id, nsplits=nworkers)
|
r = r.subrange(split=worker_id, nsplits=nworkers)
|
||||||
# split index among multi-threaded loaders
|
# split index among multi-threaded loaders
|
||||||
slice_groups = [
|
slice_groups = [
|
||||||
(slice(s, s + slice_size) for s in r.subrange(tid, nthreads).random_iterate()) for tid in range(nthreads)
|
(slice(s, s + slice_size) for s in r.subrange(tid, nthreads).random_iterate()) for tid in range(nthreads)
|
||||||
]
|
]
|
||||||
return self._iterate(slice_groups, nprefetch=prefetch_slice * slice_size, used=used)
|
for data in self._iterate(slice_groups, nprefetch=prefetch_slice * slice_size):
|
||||||
|
yield data
|
||||||
|
|
||||||
|
|
||||||
class IndexedDatasetBuilder:
|
class IndexedDatasetBuilder:
|
||||||
def __init__(self, path, overwrite=False):
|
def __init__(self, path, overwrite=False):
|
||||||
self.path = path
|
self.path = path
|
||||||
self.index_path = os.path.join(self.path, "index.h5")
|
self.index_path = os.path.join(self.path, "index")
|
||||||
self.index_path_txt = os.path.join(self.path, "index")
|
|
||||||
self.data_path = os.path.join(self.path, "data.jsonl")
|
self.data_path = os.path.join(self.path, "data.jsonl")
|
||||||
if not overwrite:
|
if not overwrite:
|
||||||
assert not os.path.exists(self.data_path)
|
assert not os.path.exists(self.data_path)
|
||||||
assert not os.path.exists(self.index_path)
|
assert not os.path.exists(self.index_path)
|
||||||
assert not os.path.exists(self.index_path_txt)
|
|
||||||
self.fout = None
|
self.fout = None
|
||||||
self.bounds = []
|
self.starts = []
|
||||||
self.offset = 0
|
self.offset = 0
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
|
@ -436,17 +322,15 @@ class IndexedDatasetBuilder:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
self.bounds.append(self.offset)
|
self.starts.append(self.offset)
|
||||||
with h5py.File(os.path.join(self.index_path), "w") as hf:
|
with open(self.index_path, "w") as fout:
|
||||||
hf.create_dataset("index", data=self.bounds)
|
for s in self.starts:
|
||||||
with open(self.index_path_txt, "w") as fout_txt:
|
fout.write(f"{s}\n")
|
||||||
for s in self.bounds:
|
|
||||||
fout_txt.write(f"{s}\n")
|
|
||||||
self.fout.close()
|
self.fout.close()
|
||||||
|
|
||||||
def put(self, data: dict):
|
def put(self, data: dict):
|
||||||
s = json_encode(data) + b"\n"
|
s = json_encode(data) + b"\n"
|
||||||
self.bounds.append(self.offset)
|
self.starts.append(self.offset)
|
||||||
self.offset += len(s)
|
self.offset += len(s)
|
||||||
self.fout.write(s)
|
self.fout.write(s)
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
import itertools
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import bmtrain as bmt
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class ListDataset(torch.utils.data.Dataset):
|
||||||
|
"""
|
||||||
|
同时支持 map-style 和 iterable-style
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, data_list: List, distributed: bool = False, shuffle: bool = True, infinite: bool = False
|
||||||
|
) -> None:
|
||||||
|
super(ListDataset, self).__init__()
|
||||||
|
if distributed:
|
||||||
|
rank = bmt.rank()
|
||||||
|
world_size = bmt.world_size()
|
||||||
|
self.data_list = list(itertools.islice(data_list, rank, None, world_size))
|
||||||
|
else:
|
||||||
|
self.data_list = data_list
|
||||||
|
self.shuffle = shuffle
|
||||||
|
self.infinite = infinite
|
||||||
|
self.idx = 0
|
||||||
|
|
||||||
|
if shuffle:
|
||||||
|
self._shuffle()
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if self.idx >= len(self):
|
||||||
|
if self.infinite:
|
||||||
|
if self.shuffle:
|
||||||
|
self._shuffle()
|
||||||
|
self.idx = 0
|
||||||
|
else:
|
||||||
|
raise StopIteration
|
||||||
|
data = self.data_list[self.idx]
|
||||||
|
self.idx += 1
|
||||||
|
return data
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return self.data_list[idx]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data_list)
|
||||||
|
|
||||||
|
def _shuffle(self):
|
||||||
|
random.shuffle(self.data_list)
|
||||||
|
|
||||||
|
def read(self):
|
||||||
|
return self.__next__()
|
|
@ -1,18 +1,3 @@
|
||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The OpenBMB team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import pickle
|
import pickle
|
||||||
|
|
|
@ -1,23 +1,14 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
#
|
|
||||||
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
|
||||||
#
|
|
||||||
# @author: ouzebin <ouzebin@zhihu.com>
|
|
||||||
# @date: 2023/07/27
|
|
||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from fm9g.dataset import SimpleDataset
|
from cpm.dataset import SimpleDataset
|
||||||
from fm9g.dataset.indexed_dataset import IndexedDatasetBuilder
|
from cpm.dataset.indexed_dataset import IndexedDatasetBuilder
|
||||||
|
|
||||||
|
|
||||||
def convert_fm9g_data(fm9g_path, out_path):
|
def convert_cpm_data(cpm_path, out_path):
|
||||||
dataset = SimpleDataset(fm9g_path, shuffle=False)
|
dataset = SimpleDataset(cpm_path, shuffle=False)
|
||||||
with IndexedDatasetBuilder(out_path, overwrite=True) as builder:
|
with IndexedDatasetBuilder(out_path, overwrite=True) as builder:
|
||||||
for _ in tqdm(range(dataset._nlines), total=dataset._nlines):
|
for _ in tqdm(range(dataset._nlines), total=dataset._nlines):
|
||||||
builder.put(dataset.read())
|
builder.put(dataset.read())
|
||||||
|
@ -25,7 +16,7 @@ def convert_fm9g_data(fm9g_path, out_path):
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--input", "-i", required=True, help="Data path in fm9g format.")
|
parser.add_argument("--input", "-i", required=True, help="Data path in CPM format.")
|
||||||
parser.add_argument("--output", "-o", required=True, help="Output data path in indexed jsonline format.")
|
parser.add_argument("--output", "-o", required=True, help="Output data path in indexed jsonline format.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
convert_fm9g_data(args.input, args.output)
|
convert_cpm_data(args.input, args.output)
|
|
@ -1,10 +1,3 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
#
|
|
||||||
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
|
||||||
#
|
|
||||||
# @author: ouzebin <ouzebin@zhihu.com>
|
|
||||||
# @date: 2023/08/07
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
|
|
@ -1,18 +1,3 @@
|
||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The OpenBMB team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
@ -266,7 +251,7 @@ def merge_dataset(dst: str, src: str):
|
||||||
_write_info_list(meta_path_dst, nw_info)
|
_write_info_list(meta_path_dst, nw_info)
|
||||||
|
|
||||||
|
|
||||||
def to_fm9g(src_data, dst_path, dst_name):
|
def to_cpm(src_data, dst_path, dst_name):
|
||||||
if not os.path.exists(dst_path):
|
if not os.path.exists(dst_path):
|
||||||
os.makedirs(dst_path)
|
os.makedirs(dst_path)
|
||||||
|
|
|
@ -0,0 +1,4 @@
|
||||||
|
from .generation_utils import apply_repetition_penalty
|
||||||
|
from .generation_utils import BeamHypotheses
|
||||||
|
from .generation_utils import NBCE
|
||||||
|
from .generation_utils import top_k_top_p_filtering
|
|
@ -0,0 +1,127 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
|
||||||
|
# This function has been mostly taken from huggingface conversational ai code at
|
||||||
|
# https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
|
||||||
|
|
||||||
|
if top_k > 0:
|
||||||
|
# Remove all tokens with a probability less than the last token of the top-k
|
||||||
|
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
||||||
|
logits[indices_to_remove] = filter_value
|
||||||
|
|
||||||
|
batch_size = logits.size()[0]
|
||||||
|
if top_p > 0.0:
|
||||||
|
logits = logits.view(batch_size, -1).contiguous()
|
||||||
|
for index in range(len(logits)):
|
||||||
|
|
||||||
|
sorted_logits, sorted_indices = torch.sort(logits[index].view(-1), descending=True)
|
||||||
|
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||||
|
# Remove tokens with cumulative probability above the threshold
|
||||||
|
sorted_indices_to_remove = cumulative_probs > top_p
|
||||||
|
# Shift the indices to the right to keep also the first token above the threshold
|
||||||
|
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||||
|
sorted_indices_to_remove[..., 0] = 0
|
||||||
|
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
||||||
|
logits[index][indices_to_remove] = filter_value
|
||||||
|
|
||||||
|
logits = logits.view(batch_size, -1).contiguous()
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
def apply_repetition_penalty(
|
||||||
|
logits,
|
||||||
|
batch_size,
|
||||||
|
num_beams,
|
||||||
|
prev_output_tokens,
|
||||||
|
repetition_penalty,
|
||||||
|
start_idx=None,
|
||||||
|
end_idx=None,
|
||||||
|
window_size=None,
|
||||||
|
):
|
||||||
|
# only conduct repetition penalty for the output
|
||||||
|
assert repetition_penalty >= 1, "repetition penalty coefficient should >= 1"
|
||||||
|
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
||||||
|
for i in range(batch_size * num_beams):
|
||||||
|
if start_idx is None or end_idx is None:
|
||||||
|
output_tokens = prev_output_tokens[i].tolist()
|
||||||
|
else:
|
||||||
|
if end_idx >= start_idx:
|
||||||
|
if window_size:
|
||||||
|
output_tokens = prev_output_tokens[i][
|
||||||
|
max(start_idx, end_idx + 1 - window_size) : end_idx + 1
|
||||||
|
].tolist()
|
||||||
|
else:
|
||||||
|
output_tokens = prev_output_tokens[i][start_idx : end_idx + 1].tolist()
|
||||||
|
else:
|
||||||
|
output_tokens = []
|
||||||
|
for previous_token in set(output_tokens):
|
||||||
|
# if score < 0 then repetition penalty has to
|
||||||
|
# multiplied to reduce the previous token probability
|
||||||
|
if logits[i, previous_token] < 0:
|
||||||
|
logits[i, previous_token] *= repetition_penalty
|
||||||
|
else:
|
||||||
|
logits[i, previous_token] /= repetition_penalty
|
||||||
|
|
||||||
|
|
||||||
|
class BeamHypotheses:
|
||||||
|
def __init__(self, n_hyp, max_len, length_penalty, early_stopping):
|
||||||
|
"""
|
||||||
|
Initialize n-best list of hypotheses.
|
||||||
|
"""
|
||||||
|
self.max_len = max_len
|
||||||
|
self.length_penalty = length_penalty
|
||||||
|
self.early_stopping = early_stopping
|
||||||
|
self.n_hyp = n_hyp
|
||||||
|
self.hyp = []
|
||||||
|
self.worst_score = 1e9
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""
|
||||||
|
Number of hypotheses in the list.
|
||||||
|
"""
|
||||||
|
return len(self.hyp)
|
||||||
|
|
||||||
|
def add(self, hyp, sum_logprobs):
|
||||||
|
"""
|
||||||
|
Add a new hypothesis to the list.
|
||||||
|
"""
|
||||||
|
score = sum_logprobs / len(hyp) ** self.length_penalty
|
||||||
|
|
||||||
|
if len(self) < self.n_hyp or score > self.worst_score:
|
||||||
|
self.hyp.append((score, hyp))
|
||||||
|
if len(self) > self.n_hyp:
|
||||||
|
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
|
||||||
|
del self.hyp[sorted_scores[0][1]]
|
||||||
|
self.worst_score = sorted_scores[1][0]
|
||||||
|
else:
|
||||||
|
self.worst_score = min(score, self.worst_score)
|
||||||
|
|
||||||
|
def is_done(self, best_sum_logprobs, cur_len):
|
||||||
|
"""
|
||||||
|
If there are enough hypotheses and that none of the hypotheses being generated
|
||||||
|
can become better than the worst one in the heap, then we are done with this sentence.
|
||||||
|
"""
|
||||||
|
if len(self) < self.n_hyp:
|
||||||
|
return False
|
||||||
|
elif self.early_stopping:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return self.worst_score >= best_sum_logprobs / cur_len**self.length_penalty
|
||||||
|
|
||||||
|
|
||||||
|
def NBCE(logits):
|
||||||
|
"""
|
||||||
|
Naive Bayes-based Context Extension
|
||||||
|
blog: https://www.kexue.fm/archives/9617
|
||||||
|
"""
|
||||||
|
beta = 0.25
|
||||||
|
logits = logits[:, -1] # bsh -> bh
|
||||||
|
logits = logits - logits.logsumexp(dim=-1, keepdims=True)
|
||||||
|
k = (logits.exp() * logits).sum(dim=-1)[1:].argmax() + 1
|
||||||
|
logits_max = logits[k]
|
||||||
|
logits_uncond = logits[0]
|
||||||
|
logits = (1 + beta) * logits_max - beta * logits_uncond
|
||||||
|
return logits
|
|
@ -12,4 +12,3 @@ from .position_embedding import RotaryEmbedding
|
||||||
from .position_embedding import RotaryEmbeddingESM
|
from .position_embedding import RotaryEmbeddingESM
|
||||||
from .position_embedding import SegmentPositionEmbedding
|
from .position_embedding import SegmentPositionEmbedding
|
||||||
from .transformer import Encoder
|
from .transformer import Encoder
|
||||||
#from _attention_pp_sp import OpAttnPipeSP
|
|
|
@ -13,13 +13,13 @@ from einops import rearrange
|
||||||
from .linear import ColumnParallelLinear
|
from .linear import ColumnParallelLinear
|
||||||
from .linear import Linear
|
from .linear import Linear
|
||||||
from .position_embedding import apply_chatglm_rotary_pos_emb
|
from .position_embedding import apply_chatglm_rotary_pos_emb
|
||||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
|
||||||
#try:
|
try:
|
||||||
# from flash_attn.flash_attn_interface import _flash_attn_varlen_backward
|
from flash_attn.flash_attn_interface import _flash_attn_varlen_backward
|
||||||
# from flash_attn.flash_attn_interface import _flash_attn_varlen_forward
|
from flash_attn.flash_attn_interface import _flash_attn_varlen_forward
|
||||||
# from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||||
#except:
|
except:
|
||||||
# flash_attn_varlen_func = None
|
flash_attn_varlen_func = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from flash_attn.bert_padding import pad_input
|
from flash_attn.bert_padding import pad_input
|
||||||
|
@ -54,8 +54,6 @@ class OpFlash(torch.autograd.Function):
|
||||||
dropout_p,
|
dropout_p,
|
||||||
ctx.softmax_scale,
|
ctx.softmax_scale,
|
||||||
causal=causal,
|
causal=causal,
|
||||||
window_size=(-1, -1),
|
|
||||||
alibi_slopes=None,
|
|
||||||
return_softmax=False,
|
return_softmax=False,
|
||||||
)
|
)
|
||||||
if record:
|
if record:
|
||||||
|
@ -87,9 +85,6 @@ class OpFlash(torch.autograd.Function):
|
||||||
ctx.dropout_p,
|
ctx.dropout_p,
|
||||||
ctx.softmax_scale,
|
ctx.softmax_scale,
|
||||||
ctx.causal,
|
ctx.causal,
|
||||||
(-1,-1),
|
|
||||||
None,
|
|
||||||
False,
|
|
||||||
rng_state=rng_state,
|
rng_state=rng_state,
|
||||||
)
|
)
|
||||||
return None, None, dq, dk, dv, None, None, None, None
|
return None, None, dq, dk, dv, None, None, None, None
|
||||||
|
@ -299,28 +294,12 @@ class Attention(bmt.DistributedModule):
|
||||||
h_q, h_k = position_bias(
|
h_q, h_k = position_bias(
|
||||||
h_q, h_k, -3, cu_seqlens=cu_seqlens, max_length=max_seqlen, position_ids=position_ids
|
h_q, h_k, -3, cu_seqlens=cu_seqlens, max_length=max_seqlen, position_ids=position_ids
|
||||||
)
|
)
|
||||||
print(type(h_q), type(cu_seqlens), type(max_seqlen), type(self.dropout_p))
|
# score = flash_attn_varlen_func(
|
||||||
print("h_q: ", h_q)
|
# h_q, h_k, h_v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, self.dropout_p, causal=True
|
||||||
print("cu_seqlens: ", cu_seqlens)
|
# )
|
||||||
print("max_seqlen: ", max_seqlen)
|
score = OpFlash.apply(
|
||||||
score = flash_attn_varlen_func(
|
self, not torch.is_grad_enabled(), h_q, h_k, h_v, cu_seqlens, max_seqlen, self.dropout_p, True
|
||||||
h_q,
|
|
||||||
h_k,
|
|
||||||
h_v,
|
|
||||||
cu_seqlens,
|
|
||||||
cu_seqlens,
|
|
||||||
max_seqlen,
|
|
||||||
max_seqlen,
|
|
||||||
self.dropout_p,
|
|
||||||
causal=True,
|
|
||||||
deterministic=True,
|
|
||||||
)
|
)
|
||||||
print(type(h_q), type(cu_seqlens), type(max_seqlen), type(self.dropout_p))
|
|
||||||
# Rongqiao change
|
|
||||||
#score = OpFlash.apply(
|
|
||||||
# self, not torch.is_grad_enabled(), h_q, h_k, h_v, cu_seqlens, max_seqlen, self.dropout_p, True
|
|
||||||
#)
|
|
||||||
|
|
||||||
|
|
||||||
score = score.view(batch_size, len_q, -1)
|
score = score.view(batch_size, len_q, -1)
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
from .attention import Attention
|
||||||
|
from .blocks import TransformerBlock
|
||||||
|
from .embedding import Embedding
|
||||||
|
from .embedding import EmbeddingExt
|
||||||
|
from .feedforward import FeedForward
|
||||||
|
from .layernorm import LayerNorm
|
||||||
|
from .linear import Linear
|
||||||
|
from .position_embedding import BucketPositionBias
|
||||||
|
from .position_embedding import RotaryEmbedding
|
||||||
|
from .position_embedding import RotaryEmbeddingESM
|
||||||
|
from .position_embedding import SegmentPositionEmbedding
|
||||||
|
from .transformer import Encoder
|
|
@ -0,0 +1,134 @@
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
from .linear import Linear
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim_model: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
dim_head: int,
|
||||||
|
dtype: torch.dtype = torch.half,
|
||||||
|
dropout_p: Optional[float] = None,
|
||||||
|
use_flash_attn: bool = False,
|
||||||
|
scale: bool = True,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim_model = dim_model
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.dim_head = dim_head
|
||||||
|
self.num_kv_heads = num_kv_heads
|
||||||
|
self.head_groups = num_heads // num_kv_heads
|
||||||
|
|
||||||
|
self.project_q = Linear(self.dim_model, self.num_heads * self.dim_head, dtype=dtype, scale=scale)
|
||||||
|
self.project_k = Linear(self.dim_model, self.num_kv_heads * self.dim_head, dtype=dtype, scale=scale)
|
||||||
|
self.project_v = Linear(self.dim_model, self.num_kv_heads * self.dim_head, dtype=dtype, scale=scale)
|
||||||
|
|
||||||
|
self.attention_out = Linear(self.num_heads * self.dim_head, self.dim_model, dtype=dtype, scale=scale)
|
||||||
|
|
||||||
|
self.softmax = torch.nn.Softmax(dim=-1)
|
||||||
|
|
||||||
|
if dropout_p is not None:
|
||||||
|
self.dropout = torch.nn.Dropout(p=dropout_p)
|
||||||
|
else:
|
||||||
|
self.dropout = None
|
||||||
|
|
||||||
|
# if use_flash_attn:
|
||||||
|
# self.core_attention_flash = FlashSelfAttention(causal=False, attention_dropout=0.0)
|
||||||
|
# self.use_flash_attn = use_flash_attn
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_q: torch.Tensor,
|
||||||
|
hidden_kv: torch.Tensor,
|
||||||
|
attention_mask: torch.BoolTensor,
|
||||||
|
position_bias: torch.Tensor,
|
||||||
|
use_cache: bool = False,
|
||||||
|
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
|
pos_bias_type: Optional[str] = "relative",
|
||||||
|
length_mask: Optional[torch.Tensor] = None,
|
||||||
|
context_mask: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hidden_q (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): Indices of input sequence tokens. It will be embedded by model's internal embedding lookup matrix.
|
||||||
|
hidden_kv (:obj:`torch.Tensor` of shape ``(batch, len_k, dim_model)``): Length of input sequence before padding.
|
||||||
|
attention_mask (:obj:`torch.Tensor` of shape ``(batch, len_q, len_k)``): Used to avoid performing attention on padding token indices.
|
||||||
|
position_bias(:obj:`torch.Tensor` of shape ``(num_heads, len_q, len_k)`` or ``(1, num_heads, len_k, len_q)``): Provide positional information about tensor `key_value` and `query`.
|
||||||
|
Return:
|
||||||
|
out (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): The attention output.
|
||||||
|
""" # noqa: E501
|
||||||
|
|
||||||
|
batch_size = hidden_q.size(0)
|
||||||
|
len_q = hidden_q.size(1)
|
||||||
|
len_k = hidden_kv.size(1)
|
||||||
|
|
||||||
|
h_q = self.project_q(hidden_q)
|
||||||
|
h_k = self.project_k(hidden_kv)
|
||||||
|
h_v = self.project_v(hidden_kv)
|
||||||
|
|
||||||
|
h_q = h_q / math.sqrt(math.sqrt(self.dim_head))
|
||||||
|
h_k = h_k / math.sqrt(math.sqrt(self.dim_head))
|
||||||
|
|
||||||
|
h_q = h_q.view(batch_size, len_q, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
|
||||||
|
h_k = h_k.view(batch_size, len_k, self.num_kv_heads, self.dim_head).permute(0, 2, 1, 3)
|
||||||
|
h_v = h_v.view(batch_size, len_k, self.num_kv_heads, self.dim_head).permute(0, 2, 1, 3)
|
||||||
|
|
||||||
|
if pos_bias_type == "rotary":
|
||||||
|
# b h s d
|
||||||
|
h_q, h_k = position_bias(h_q, h_k, -2, offset=past_kv[0].size(-2) if past_kv is not None else 0)
|
||||||
|
|
||||||
|
if past_kv is not None:
|
||||||
|
h_k = torch.cat([past_kv[0], h_k], dim=-2)
|
||||||
|
h_v = torch.cat([past_kv[1], h_v], dim=-2)
|
||||||
|
len_k = h_k.size(-2)
|
||||||
|
|
||||||
|
# (b, n_h, len_q, d_h) @ (b, n_h, d_h, len_k) -> (b, n_h, len_q, len_k)
|
||||||
|
if self.head_groups == 1:
|
||||||
|
score = torch.matmul(h_q, h_k.transpose(-1, -2)) # / math.sqrt(self.dim_head) moved to line 75~76
|
||||||
|
else:
|
||||||
|
score = torch.matmul(
|
||||||
|
h_q.reshape(batch_size, self.num_kv_heads, self.head_groups * len_q, self.dim_head),
|
||||||
|
h_k.transpose(-1, -2),
|
||||||
|
).view(batch_size, self.num_heads, len_q, len_k)
|
||||||
|
|
||||||
|
if pos_bias_type == "relative":
|
||||||
|
score = score + position_bias
|
||||||
|
score = torch.masked_fill(
|
||||||
|
score,
|
||||||
|
attention_mask.view(batch_size, 1, len_q, len_k) == False,
|
||||||
|
torch.scalar_tensor(float("-inf"), device=score.device, dtype=score.dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
score = self.softmax(score)
|
||||||
|
|
||||||
|
score = torch.masked_fill(
|
||||||
|
score,
|
||||||
|
attention_mask.view(batch_size, 1, len_q, len_k) == False,
|
||||||
|
torch.scalar_tensor(0, device=score.device, dtype=score.dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.dropout is not None:
|
||||||
|
score = self.dropout(score)
|
||||||
|
|
||||||
|
# (b, n_kv_h, n_h_groups*len_q, len_k) @ (b, n_kv_h, len_k, d_h) -> (b, n_kv_h, n_h_groups*len_q, d_h) -> (b, n_h, len_q, d_h)
|
||||||
|
score = torch.matmul(score.view(batch_size, self.num_kv_heads, self.head_groups * len_q, len_k), h_v).view(
|
||||||
|
batch_size, self.num_heads, len_q, self.dim_head
|
||||||
|
)
|
||||||
|
|
||||||
|
score = score.view(batch_size, self.num_heads, len_q, self.dim_head).permute(0, 2, 1, 3)
|
||||||
|
score = score.contiguous().view(batch_size, len_q, self.num_heads * self.dim_head)
|
||||||
|
|
||||||
|
score = self.attention_out(score)
|
||||||
|
if use_cache:
|
||||||
|
return score, (h_k, h_v)
|
||||||
|
else:
|
||||||
|
return score
|
|
@ -0,0 +1,279 @@
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .attention import Attention
|
||||||
|
from .feedforward import FeedForward
|
||||||
|
from .layernorm import LayerNorm
|
||||||
|
from .position_embedding import RotaryEmbedding
|
||||||
|
from .position_embedding import RotaryEmbeddingESM
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttentionBlock(torch.nn.Module):
    """The whole self-attention block. A sequence of operations: layernorm, self-attention and a residual connection.

    Args:
        dim_model (int): main dimension of modules in transformer blocks.
        num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
        dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
        num_kv_heads (int): number of key/value heads passed to Attention (grouped-query attention when smaller than num_heads — confirm in Attention).
        dtype (optional): Defaults to torch.half.
        eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
        dropout_p (float, optional): Defaults to None (no dropout).
        scale (bool, optional): forwarded to Attention; semantics defined there.
        use_flash_attn (bool, optional): forwarded to Attention to select the flash-attention path.
    """  # noqa: E501

    def __init__(
        self,
        dim_model: int,
        num_heads: int,
        dim_head: int,
        num_kv_heads: int,
        dtype=torch.half,
        eps: float = 1e-6,
        dropout_p: Optional[float] = None,
        scale: bool = True,
        use_flash_attn: bool = False,
    ):
        super().__init__()

        # Pre-norm architecture: normalization is applied before attention.
        self.layernorm_before_attention = LayerNorm(
            dim_model,
            dtype=dtype,
            eps=eps,
        )

        self.self_attention = Attention(
            dim_model=dim_model,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            dim_head=dim_head,
            dtype=dtype,
            dropout_p=dropout_p,
            scale=scale,
            use_flash_attn=use_flash_attn,
        )

        # Note: `if dropout_p:` — both None and 0 disable residual dropout.
        if dropout_p:
            self.dropout = torch.nn.Dropout(dropout_p)
        else:
            self.dropout = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_bias: Union[torch.Tensor, RotaryEmbedding, RotaryEmbeddingESM] = None,
        use_cache: bool = False,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        pos_bias_type: Optional[str] = "relative",
        length_mask: Optional[torch.Tensor] = None,
        context_mask: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Input of self-attention block. It can be the embedding of a batch of sequences.
            attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_self, seq_self)``): Avoid invalid areas to participate in the calculation.
            position_bias (:obj:`torch.Tensor` of shape ``(num_heads, seq_self, seq_self)``): Provide positional information to self-attention block.
            use_cache (bool): when True the attention layer also returns its (key, value) cache.
            past_key_value: previously cached (key, value) tensors for incremental decoding.

        Return:
            :obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of attention block.
            When ``use_cache`` is True, returns ``(hidden_states, current_key_value)`` instead.
        """  # noqa: E501
        x = self.layernorm_before_attention(hidden_states)
        # Query and key/value both come from the normalized input (self-attention).
        x = self.self_attention(
            x,
            x,
            attention_mask,
            position_bias,
            use_cache,
            past_key_value,
            pos_bias_type=pos_bias_type,
            length_mask=length_mask,
            context_mask=context_mask,
        )
        if use_cache:
            # Attention returned an (output, cache) pair.
            x, current_key_value = x
        else:
            current_key_value = None

        if self.dropout is not None:
            x = self.dropout(x)

        # Residual connection around the attention sub-layer.
        hidden_states = hidden_states + x

        if use_cache:
            return hidden_states, current_key_value
        else:
            return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class FFNBlock(torch.nn.Module):
    """Pre-norm feed-forward sub-layer: layernorm -> feed-forward -> (dropout) -> residual add.

    Args:
        dim_model (int): main dimension of modules in transformer blocks.
        dim_ff (int): inner dimension of the feed-forward network.
        activate_fn (str): activation used inside :py:class:`FeedForward`.
        dtype (optional): parameter dtype. Defaults to torch.half.
        eps (float, optional): layernorm epsilon. Defaults to 1e-6.
        dropout_p (float, optional): dropout probability; falsy values (None or 0) disable residual dropout. Defaults to 0.
        scale (bool, optional): forwarded to FeedForward.
    """  # noqa: E501

    def __init__(
        self,
        dim_model: int,
        dim_ff: int,
        activate_fn: str,
        dtype=torch.half,
        eps: float = 1e-6,
        dropout_p: Optional[float] = 0,
        scale: bool = True,
    ):
        super().__init__()

        self.layernorm_before_ffn = LayerNorm(
            dim_model,
            dtype=dtype,
            eps=eps,
        )

        self.ffn = FeedForward(
            dim_model,
            dim_ff,
            activate_fn=activate_fn,
            dtype=dtype,
            dropout_p=dropout_p,
            scale=scale,
        )

        # A falsy dropout_p means no dropout module at all.
        self.dropout = torch.nn.Dropout(dropout_p) if dropout_p else None

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """
        Args:
            hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Hidden states before the feed-forward layer.

        Return:
            :obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: hidden states after the residual update.
        """  # noqa: E501
        delta = self.ffn(self.layernorm_before_ffn(hidden_states))
        if self.dropout is not None:
            delta = self.dropout(delta)
        return hidden_states + delta
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerBlock(torch.nn.Module):
    """The whole transformer block. A sequence of operations. Consists of a self-attention block and a feed-forward block, either of which can be masked out (skipped).

    Args:
        dim_model (int): main dimension of modules in transformer blocks.
        dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
        num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
        num_kv_heads (int): number of key/value heads forwarded to the attention block.
        dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
        activate_fn (str, optional): feed-forward activation. Defaults to "gelu".
        dtype (optional): Defaults to torch.half.
        eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
        dropout_p (float, optional): Defaults to None.
        mask_att (bool, optional): when True the attention sub-layer is omitted entirely.
        mask_ffn (bool, optional): when True the feed-forward sub-layer is omitted entirely.
    """  # noqa: E501

    def __init__(
        self,
        dim_model: int,
        dim_ff: int,
        num_heads: int,
        num_kv_heads: int,
        dim_head: int,
        activate_fn: str = "gelu",
        dtype=torch.half,
        eps: float = 1e-6,
        dropout_p: Optional[float] = None,
        scale: bool = True,
        mask_att: bool = False,
        mask_ffn: bool = False,
        use_flash_attn: bool = False,
    ):
        super().__init__()
        # mask_att/mask_ffn support layer pruning: a masked sub-layer is never
        # constructed, so it contributes no parameters.
        self.mask_att = mask_att
        self.mask_ffn = mask_ffn

        if not self.mask_att:
            self.self_att = SelfAttentionBlock(
                dim_model=dim_model,
                num_heads=num_heads,
                num_kv_heads=num_kv_heads,
                dim_head=dim_head,
                dtype=dtype,
                eps=eps,
                dropout_p=dropout_p,
                scale=scale,
                use_flash_attn=use_flash_attn,
            )

        if not self.mask_ffn:
            self.ffn = FFNBlock(
                dim_model=dim_model,
                dim_ff=dim_ff,
                activate_fn=activate_fn,
                dtype=dtype,
                eps=eps,
                dropout_p=dropout_p,
                scale=scale,
            )

    def forward(
        self,
        self_hidden_states: torch.Tensor,
        self_attention_mask: torch.Tensor,
        self_position_bias: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        pos_bias_type: Optional[str] = "relative",
        length_mask: Optional[torch.Tensor] = None,
        context_mask: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            self_hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Input of transformer block(self-attention block). It can be the raw embedding of a batch of sequences.
            self_attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_self, seq_self)``): Avoid invalid areas to participate in the calculation of self-attention.
            self_position_bias (:obj:`torch.Tensor` of shape ``(num_heads, seq_self, seq_self)``): Provide positional information to self-attention block.
            use_cache (bool): when True also return the attention (key, value) cache.
            past_key_value: cached (key, value) tensors for incremental decoding.

        Return:
            :obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of transformer block.
            When ``use_cache`` is True, returns ``(hidden_states, current_key_value)``.
        """  # noqa: E501
        # (batch, dim_model, seq_self)
        current_key_value = None
        if not self.mask_att:
            hidden_states = self.self_att(
                self_hidden_states,
                attention_mask=self_attention_mask,
                position_bias=self_position_bias,
                use_cache=use_cache,
                past_key_value=past_key_value,
                pos_bias_type=pos_bias_type,
                length_mask=length_mask,
                context_mask=context_mask,
            )
            if use_cache:
                hidden_states, current_key_value = hidden_states
        else:
            # Attention skipped: pass the input straight through.
            hidden_states = self_hidden_states

        # (batch, dim_model, seq_self)
        if not self.mask_ffn:
            hidden_states = self.ffn(hidden_states)

        if use_cache:
            return hidden_states, current_key_value
        else:
            return hidden_states
|
|
@ -0,0 +1,100 @@
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from .position_embedding import RotaryEmbedding
|
||||||
|
|
||||||
|
|
||||||
|
class Embedding(torch.nn.Module):
    """Token embedding with an optionally scaled, weight-tied output projection."""

    def __init__(
        self,
        vocab_size: int,
        embedding_size: int,
        dtype: torch.dtype = torch.half,
        scale: bool = True,
        init_mean: float = 0.0,
        init_std: float = 1,
    ):
        super().__init__()

        self.dim_model = embedding_size
        # Weight is allocated uninitialized (torch.empty); initialization is
        # expected to happen externally, e.g. when loading a checkpoint.
        self.weight = torch.nn.parameter.Parameter(torch.empty(vocab_size, embedding_size, dtype=dtype))
        self.scale = scale

    def forward(self, ids: torch.Tensor):
        """Look up embeddings for token ids.

        Args:
            ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens.
        Return:
            :obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output.
        """  # noqa: E501
        embeds = F.embedding(ids, self.weight)
        if self.scale:
            # Divide by sqrt(dim) so embeddings match the model's scale convention.
            embeds = embeds / math.sqrt(self.dim_model)
        return embeds.clone()

    def projection(self, x: torch.Tensor):
        """Project hidden states back to vocabulary logits with the tied embedding weight.

        Args:
            x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection
        Returns:
            :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output.
        """  # noqa: E501
        hidden = x / math.sqrt(self.dim_model) if self.scale else x
        return F.linear(hidden, self.weight)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingExt(torch.nn.Module):
    # Token embedding that applies rotary position mixing directly to the
    # embedding output (forward), plus a tied-weight projection that can be
    # extended with an external vocabulary table (projection).
    def __init__(
        self,
        vocab_size: int,
        embedding_size: int,
        dtype: torch.dtype = torch.half,
        init_mean: float = 0.0,
        init_std: float = 1,
        distance_scale: int = 16,
    ):
        super().__init__()

        self.dim_model = embedding_size
        # NOTE(review): RotaryEmbedding builds its frequency table on "cuda"
        # internally, so this module presumably requires a GPU — confirm.
        self.rotary_emb = RotaryEmbedding(dim=embedding_size, distance_scale=distance_scale, dtype=dtype)

        # init_mean/init_std are accepted but unused here: the weight is left
        # uninitialized (torch.empty); initialization presumably happens externally.
        self.weight = torch.nn.parameter.Parameter(
            torch.empty(vocab_size, embedding_size, dtype=dtype),
        )

    def forward(self, ids: torch.Tensor, ids_sub: torch.Tensor):
        """
        Args:
            ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens.
            ids_sub (:obj:`torch.Tensor` of shape ``(batch_size)``): Subscript (position) of input sequence tokens, fed to the rotary embedding.
        Return:
            :obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output.
        """  # noqa: E501

        # Embeddings are always scaled by 1/sqrt(dim) here (no `scale` flag).
        embeds = F.embedding(ids, self.weight) / math.sqrt(self.dim_model)
        return self.rotary_emb(embeds, ids_sub)

    def projection(self, x: torch.Tensor, ext_table: Optional[torch.Tensor] = None):
        """
        Projection based on embedding's weight. For example, embedding maps vocab_size to embed_size; projection maps embed_size back to vocab_size.
        Args:
            x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection
            ext_table (:obj:`torch.Tensor` of shape ``(ext_table_size, dim_model)``): Ext vocab table.
        Returns:
            :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_size + ext_table_size)``: The projection output.
        """  # noqa: E501
        logits = F.linear(x / math.sqrt(self.dim_model), self.weight)
        if ext_table is not None:
            # Extended-vocab logits are computed without the 1/sqrt(dim) scaling
            # and appended after the base vocabulary.
            logits_ext = F.linear(x, ext_table)
            logits = torch.cat([logits, logits_ext], dim=-1)
        return logits
|
|
@ -0,0 +1,120 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .linear import Linear
|
||||||
|
|
||||||
|
|
||||||
|
class DenseGatedACT(torch.nn.Module):
    """Gated activation unit: ``act(w_0(x)) * w_1(x)``, projecting dim_in -> dim_ff."""

    def __init__(
        self,
        dim_in: int,
        dim_ff: int,
        dtype=torch.half,
        activate_fn: str = "gelu",
        scale: bool = True,
    ):
        super().__init__()

        # Two parallel projections: w_0 feeds the gate, w_1 carries the value.
        self.w_0 = Linear(
            dim_in=dim_in,
            dim_out=dim_ff,
            dtype=dtype,
            scale=scale,
            scale_before=False,
        )
        self.w_1 = Linear(
            dim_in=dim_in,
            dim_out=dim_ff,
            dtype=dtype,
            scale=scale,
            scale_before=False,
        )

        if activate_fn == "gelu":
            self.act = torch.nn.GELU()
        elif activate_fn == "silu":
            self.act = torch.nn.functional.silu
        else:
            raise NotImplementedError(f"{activate_fn} is not supported")

    def forward(self, x: torch.Tensor):
        """Transform an input tensor from one feature space to another via a nonlinear operation.

        Args:
            x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): Tensor that will be subject to nonlinear operations.

        Return:
            out (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_ff)``)
        """  # noqa: E501
        return self.act(self.w_0(x)) * self.w_1(x)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(torch.nn.Module):
    r"""FeedForward module: gated input projection, optional dropout, output projection.

    Args:
        dim_model (int): input/output dimension.
        dim_ff (int): middle (inner) dimension.
        activate_fn (str, optional): activation for the gate. Defaults to "gelu".
        dtype (optional): Defaults to torch.half.
        dropout_p (float, optional): Defaults to None (no dropout).
        scale (bool, optional): forwarded to the inner Linear layers.
    """  # noqa: E501

    def __init__(
        self,
        dim_model: int,
        dim_ff: int,
        activate_fn: str = "gelu",
        dtype=torch.half,
        dropout_p: Optional[float] = None,
        scale: bool = True,
    ):
        super().__init__()

        # Gated input projection: dim_model -> dim_ff with activation gating.
        self.w_in = DenseGatedACT(
            dim_in=dim_model,
            dim_ff=dim_ff,
            activate_fn=activate_fn,
            dtype=dtype,
            scale=scale,
        )

        # Note: dropout_p=0 still creates Dropout(0) here (behaviorally
        # identical to no dropout); only None disables the module.
        if dropout_p is not None:
            self.dropout = torch.nn.Dropout(dropout_p)
        else:
            self.dropout = None

        # Output projection back to dim_model.
        self.w_out = Linear(
            dim_in=dim_ff,
            dim_out=dim_model,
            dtype=dtype,
            scale=scale,
            scale_before=False,
        )

    def forward(self, x: torch.Tensor):
        """
        Args:
            x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of feed-forward module.

        Return:
            :obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of feed-forward module.
        """  # noqa: E501
        x = self.w_in(x)

        if self.dropout is not None:
            x = self.dropout(x)

        x = self.w_out(x)

        return x
|
|
@ -0,0 +1,37 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
    """Root-mean-square layer normalization over the last dimension.

    Normalizes ``hidden`` by its RMS (computed in float32 for numerical
    stability), casts back to the input dtype, then applies the per-dimension
    gain ``weight``.
    """
    input_dtype = hidden.dtype
    # Mean of squares along the feature axis, in float32.
    mean_sq = torch.mean(hidden.to(torch.float32) ** 2, dim=-1, keepdim=True)
    normed = hidden * torch.rsqrt(mean_sq + eps)
    return normed.to(input_dtype) * weight
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(torch.nn.Module):
    """RMS LayerNorm: scales inputs by their root-mean-square with a learnable gain."""

    def __init__(
        self,
        dim_norm: int,
        dtype: torch.dtype = torch.half,
        eps: float = 1e-6,
        init_var: float = 1.0,
    ):

        super().__init__()

        self.eps = eps
        self.dim_norm = dim_norm
        # Learnable per-dimension gain, initialized to init_var (1.0 = identity gain).
        self.weight = torch.nn.parameter.Parameter(torch.full((dim_norm,), init_var, dtype=dtype))

    def forward(self, x: torch.Tensor):
        """
        Args:
            x (:obj:`torch.Tensor` of shape ``(batch_size, seq_len, dim_norm)``): Input tensor that need to be normalized.
        Return:
            :obj:`torch.Tensor` of shape ``(batch_size, seq_len, dim_norm)``: The layernorm output.
        """
        # Guard against feature-dimension mismatch before the fused kernel runs.
        assert x.size(-1) == self.dim_norm
        return rms_layernorm(x, self.weight, self.eps)
|
|
@ -0,0 +1,44 @@
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class Linear(torch.nn.Module):
    """Bias-free linear layer with optional ``1/sqrt(dim_in)`` scaling.

    The scaling is applied either to the input before the matmul
    (``scale_before=True``) or to the output after it (``scale_before=False``).
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        dtype: torch.dtype = torch.half,
        init_mean: float = 0.0,
        init_std: float = 1,
        scale: bool = True,
        scale_before: bool = False,
    ):
        super().__init__()
        # Mirror nn.Linear's in_features/out_features so external code can introspect.
        self.dim_in = self.in_features = dim_in
        self.dim_out = self.out_features = dim_out
        self.scale = scale
        self.scale_before = scale_before
        self.weight = torch.nn.parameter.Parameter(torch.empty((dim_out, dim_in), dtype=dtype))
        torch.nn.init.normal_(self.weight, mean=init_mean, std=init_std)

    def forward(self, x: torch.Tensor):
        """Apply ``y = x @ W.T`` with optional scaling.

        Args:
            x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer
        Returns:
            :obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y.
        """  # noqa: E501
        if not self.scale:
            return F.linear(x, self.weight)
        denom = math.sqrt(self.dim_in)
        if self.scale_before:
            return F.linear(x / denom, self.weight)
        return F.linear(x, self.weight) / denom
|
|
@ -0,0 +1,286 @@
|
||||||
|
import math
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentPositionEmbedding(torch.nn.Module):
    # Relative attention bias combining T5-style relative-position buckets
    # (for same-segment token pairs) with num_segments^2 segment-pair buckets
    # (for cross-segment pairs).
    def __init__(
        self,
        num_heads: int,
        num_segments: int = 1,
        num_buckets: int = 32,
        max_distance: int = 128,
        bidirectional: bool = False,
        dtype: torch.dtype = torch.half,
        init_mean: float = 0.0,
        init_std: float = 1,
    ):
        super().__init__()

        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.bidirectional = bidirectional
        self.num_segments = num_segments

        # One bias row per bucket (segment-pair buckets + position buckets), per head.
        # Note: left uninitialized (torch.empty); init_mean/init_std are unused here.
        self.relative_attention_bias = torch.nn.parameter.Parameter(
            torch.empty(num_segments * num_segments + num_buckets, num_heads, dtype=dtype)
        )

    def forward(
        self,
        key_pos: torch.Tensor,
        query_pos: torch.Tensor,
        key_segment: torch.Tensor,
        query_segment: torch.Tensor,
    ):
        # Bucketing is pure index arithmetic — no gradients needed.
        with torch.no_grad():
            batch = key_pos.size(0)
            keylen = key_pos.size(1)
            querylen = query_pos.size(1)

            assert key_pos.size(0) == query_pos.size(0)
            assert keylen == key_segment.size(1) and querylen == query_segment.size(1)

            # Reshape so the pairwise grids broadcast to (batch, querylen, keylen).
            key_pos = key_pos.view(batch, -1, keylen)
            query_pos = query_pos.view(batch, querylen, -1)
            key_segment = key_segment.view(batch, -1, keylen)
            query_segment = query_segment.view(batch, querylen, -1)

            relative_position_bucket = self._segment_relative_position_bucket(query_segment, key_segment)
            # Shift segment buckets past the relative-position bucket range so
            # the two bucket families do not overlap.
            relative_position_bucket = relative_position_bucket + self.num_buckets

            # b*q*k
            absolute_position_bucket = self._position_bucket(
                torch.arange(keylen, dtype=torch.int32, device=relative_position_bucket.device)[None, :]
                - torch.arange(querylen, dtype=torch.int32, device=relative_position_bucket.device)[:, None],
                bidirectional=self.bidirectional,
                num_buckets=self.num_buckets,
                max_distance=self.max_distance,
            )
            # Same-segment pairs use the relative-position bucket; cross-segment
            # pairs keep the segment-pair bucket.
            relative_position_bucket = torch.where(
                (key_segment == query_segment),
                absolute_position_bucket[None, :, :],
                relative_position_bucket,
            )
            # (batch, len_q, len_k)

        # (batch, len_q, len_k, num_heads)
        embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
        # (batch, num_heads, len_q, len_k)
        embeds = embeds.permute(0, 3, 1, 2).contiguous()
        return embeds

    def _segment_relative_position_bucket(self, query_segment, key_segment):
        # Flatten the (query_segment, key_segment) pair into a single bucket id.
        return query_segment * self.num_segments + key_segment

    def _position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        # T5-style bucketing: exact buckets for small distances, logarithmically
        # spaced buckets up to max_distance beyond that.
        relative_buckets = 0
        if bidirectional:
            # Half the buckets for each sign of the relative distance.
            num_buckets //= 2
            relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            # Causal: only non-positive distances are bucketed.
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        relative_postion_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.int32)
        # Clamp distances beyond max_distance into the last bucket.
        relative_postion_if_large = torch.min(
            relative_postion_if_large,
            torch.full_like(relative_postion_if_large, num_buckets - 1),
        )
        relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_postion_if_large)
        return relative_buckets
|
||||||
|
|
||||||
|
|
||||||
|
class BucketPositionBias(torch.nn.Module):
    # Relative attention bias where cross-segment pairs carry precomputed
    # bucket ids (rel_buckets) and same-segment pairs (rel_buckets == 0)
    # fall back to T5-style relative-position buckets.
    def __init__(
        self,
        num_heads: int,
        num_buckets: int = 32,
        num_segment_bucket: int = 32,
        max_distance: int = 128,
        dtype: torch.dtype = torch.half,
        init_mean: float = 0.0,
        init_std: float = 1,
    ) -> None:
        super().__init__()

        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.num_segment_bucket = num_segment_bucket
        self.max_distance = max_distance

        # One bias row per (position bucket + segment bucket), per head.
        # Note: left uninitialized (torch.empty); init_mean/init_std are unused here.
        self.relative_attention_bias = torch.nn.parameter.Parameter(
            torch.empty(num_buckets + num_segment_bucket, num_heads, dtype=dtype)
        )

    def forward(
        self,
        query_pos: torch.Tensor,  # (batch, len_q)
        key_pos: torch.Tensor,  # (batch, len_k)
        rel_buckets: torch.Tensor,  # (batch, len_q, len_k)
    ):
        # Bucketing is pure index arithmetic — no gradients needed.
        with torch.no_grad():
            batch = key_pos.size(0)
            keylen = key_pos.size(1)
            querylen = query_pos.size(1)

            assert key_pos.size(0) == query_pos.size(0)
            assert rel_buckets.size(0) == batch and rel_buckets.size(1) == querylen and rel_buckets.size(2) == keylen

            # Shift segment buckets past the relative-position range so the two
            # bucket families do not overlap (rel_buckets is 1-based for segments).
            relative_position_bucket = rel_buckets - 1 + self.num_buckets

            # b*q*k
            inner_segment_bucket = self._position_bucket(
                key_pos[..., None, :] - query_pos[..., :, None],
                num_buckets=self.num_buckets,
                max_distance=self.max_distance,
            )
            # rel_buckets == 0 marks same-segment pairs: use position buckets there.
            relative_position_bucket = torch.where(
                rel_buckets == 0,
                inner_segment_bucket,
                relative_position_bucket,
            )
            # (batch, len_q, len_k)

        # (batch, len_q, len_k, num_heads)
        embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
        # (batch, num_heads, len_q, len_k)
        embeds = embeds.permute(0, 3, 1, 2).contiguous()
        return embeds

    def _position_bucket(self, relative_position, num_buckets=32, max_distance=128):
        # Bidirectional T5-style bucketing: half the buckets per sign, exact for
        # small distances, log-spaced up to max_distance beyond that.
        relative_buckets = 0
        num_buckets //= 2
        relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
        relative_position = torch.abs(relative_position)

        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        relative_postion_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.int32)
        # Clamp distances beyond max_distance into the last bucket.
        relative_postion_if_large = torch.min(
            relative_postion_if_large,
            torch.full_like(relative_postion_if_large, num_buckets - 1),
        )
        relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_postion_if_large)
        return relative_buckets
|
||||||
|
|
||||||
|
|
||||||
|
class RotaryEmbedding(torch.nn.Module):
    # Rotary position embedding (RoPE) applied directly to an input tensor
    # given explicit per-token positions.
    def __init__(
        self,
        dim,
        base=10000,
        distance_scale: Union[int, float] = 1,
        dtype: torch.dtype = torch.half,
    ):
        super().__init__()
        # NOTE(review): the frequency table is created on "cuda" unconditionally,
        # so this module cannot run CPU-only — confirm whether that is intended.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device="cuda", dtype=torch.float32) / dim))
        inv_freq = inv_freq.to(dtype)
        self.distance_scale = distance_scale
        self.dtype = dtype
        # Plain attribute (not a registered buffer): it will not appear in
        # state_dict and will not follow .to()/.cuda() moves automatically.
        self.inv_freq = inv_freq

    def forward(self, x: torch.Tensor, x_pos: torch.Tensor):
        """
        Args:
            x (:obj:`torch.Tensor` of shape ``(..., dim)``): Inputs.
            x_pos (:obj:`torch.Tensor` of shape ``(...)``): Positions of inputs.
        """
        x_pos = x_pos * self.distance_scale
        freqs = x_pos[..., None].to(self.dtype) * self.inv_freq[None, :]  # (..., dim/2)

        # the same implementation as sat
        emb = torch.cat((freqs, freqs), dim=-1)  # (..., dim)
        emb_cos = emb.cos()  # (..., dim)
        emb_sin = emb.sin()  # (..., dim)

        # Rotate-half: second half negated and moved to the front.
        rotate_x = torch.cat([-x[..., x.size(-1) // 2 :], x[..., : x.size(-1) // 2]], dim=-1)  # (..., dim)

        return x * emb_cos + rotate_x * emb_sin
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_half(x):
    """Split the last dimension in two halves and rotate: ``(a, b) -> (-b, a)``."""
    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb(x, cos, sin, seq_dim, offset):
    """Apply rotary position embedding to *x* using precomputed cos/sin tables.

    When the cached tables are longer than *x* along ``seq_dim``, they are
    narrowed to the window ``[offset, offset + seq_len)`` first.
    """
    seq_len = x.size(seq_dim)
    if seq_len < cos.size(seq_dim):
        cos = cos.narrow(seq_dim, offset, seq_len)
        sin = sin.narrow(seq_dim, offset, seq_len)
    return x * cos + rotate_half(x) * sin
|
||||||
|
|
||||||
|
|
||||||
|
class RotaryEmbeddingESM(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    """

    def __init__(
        self,
        dim: int,
        base: Union[int, float] = 10000,
        distance_scale: Union[int, float] = 1,
        dtype=torch.half,
        persistent=True,
        mixed_precision=False,
    ):
        super().__init__()
        self.base = base
        self.distance_scale = distance_scale
        self.dtype = dtype

        # Generate and save the inverse frequency buffer (non trainable)
        # NOTE(review): created on "cuda" unconditionally — CPU-only execution
        # will fail here; confirm whether that is intended.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device="cuda", dtype=torch.float32) / dim))
        if mixed_precision:
            # Keep float32 frequencies; cos/sin are cast to self.dtype later.
            self.register_buffer("inv_freq", inv_freq, persistent=persistent)
        else:
            self.register_buffer("inv_freq", inv_freq.to(self.dtype), persistent=persistent)

        # Lazily-built cos/sin caches, grown on demand in _update_cos_sin_tables.
        self._seq_len_cached = -1
        self._cos_cached = None
        self._sin_cached = None
        self.mixed_precision = mixed_precision

        self.apply_rotary_pos_emb = apply_rotary_pos_emb

    def _update_cos_sin_tables(self, x, seq_dim, offset):
        # Rebuild the cache when a longer sequence arrives or the device changed.
        # (On the first call _seq_len_cached is -1, so the first condition is
        # always taken before _cos_cached.device is ever read.)
        seq_len = x.size(seq_dim) + offset
        if seq_len > self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t * self.distance_scale, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            # Insert singleton dims so emb broadcasts against x on every
            # non-sequence axis.
            for i in range(x.dim() - 1):
                if i != seq_dim:
                    emb = emb.unsqueeze_(i)
            if self.mixed_precision:
                self._cos_cached = emb.cos().to(self.dtype)
                self._sin_cached = emb.sin().to(self.dtype)
            else:
                self._cos_cached = emb.cos()
                self._sin_cached = emb.sin()
        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_dim, offset=0) -> Tuple[torch.Tensor, torch.Tensor]:
        # Normalize a possibly negative seq_dim to a positive axis index.
        seq_dim = (seq_dim + k.dim()) % k.dim()
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dim, offset)
        return (
            self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, seq_dim, offset),
            self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, seq_dim, offset),
        )
|
|
@ -0,0 +1,132 @@
|
||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .blocks import TransformerBlock
|
||||||
|
from .layernorm import LayerNorm
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(torch.nn.Module):
|
||||||
|
"""Layers of encoder transformer blocks plus an final layernorm.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_layers (int): number of layers.
|
||||||
|
dim_model (int): main dimension of modules in transformer blocks.
|
||||||
|
dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
|
||||||
|
num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
|
||||||
|
dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
|
||||||
|
dtype (optional): Defaults to torch.half.
|
||||||
|
eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
|
||||||
|
dropout_p (float, optional): Defaults to 0.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_layers: int,
|
||||||
|
dim_model: int,
|
||||||
|
dim_ff: int,
|
||||||
|
num_heads: int,
|
||||||
|
dim_head: int,
|
||||||
|
num_kv_heads: int = -1,
|
||||||
|
activate_fn: str = "gelu",
|
||||||
|
dtype: torch.dtype = torch.half,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
dropout_p: Optional[float] = None,
|
||||||
|
scale: bool = True,
|
||||||
|
mask_modules: Optional[List[Tuple[bool, bool]]] = None,
|
||||||
|
use_flash_attn: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if num_kv_heads == -1:
|
||||||
|
num_kv_heads = num_heads
|
||||||
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
if mask_modules is not None:
|
||||||
|
assert len(mask_modules) == num_layers, "The total number of masks should equal to num_layers"
|
||||||
|
for mask_module in mask_modules:
|
||||||
|
assert len(mask_module) == 2, "For encoder, each mask should be (mask_att, mask_ffn)"
|
||||||
|
else:
|
||||||
|
mask_modules = [(False, False)] * num_layers
|
||||||
|
|
||||||
|
self.layers = torch.nn.ModuleList(
|
||||||
|
[
|
||||||
|
TransformerBlock(
|
||||||
|
dim_model=dim_model,
|
||||||
|
dim_ff=dim_ff,
|
||||||
|
num_heads=num_heads,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
dim_head=dim_head,
|
||||||
|
activate_fn=activate_fn,
|
||||||
|
dtype=dtype,
|
||||||
|
eps=eps,
|
||||||
|
dropout_p=dropout_p,
|
||||||
|
scale=scale,
|
||||||
|
mask_att=mask_modules[ith][0],
|
||||||
|
mask_ffn=mask_modules[ith][1],
|
||||||
|
use_flash_attn=use_flash_attn,
|
||||||
|
)
|
||||||
|
for ith in range(num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.output_layernorm = LayerNorm(dim_norm=dim_model, dtype=dtype, eps=eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
position_bias: torch.Tensor,
|
||||||
|
use_cache: bool = False,
|
||||||
|
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
||||||
|
pos_bias_type: Optional[str] = "relative",
|
||||||
|
length_mask: Optional[torch.Tensor] = None,
|
||||||
|
context_mask: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hidden-states (:obj:`torch.Tensor` of shape ``(batch, seq_enc, dim_model)``): Input of encoder, might be the embedding of a batch of sequences.
|
||||||
|
attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_enc, seq_enc)``): Avoid invalid areas to participate in the calculation
|
||||||
|
position_bias(:obj:`torch.Tensor` of shape ``(num_heads, seq_enc, seq_enc)``) Provides position information to attention mechanism.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
:obj:`torch.Tensor` of shape ``(batch, seq_enc, dim_model)``: The encoder output.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not use_cache:
|
||||||
|
for layer in self.layers:
|
||||||
|
hidden_states = layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
position_bias,
|
||||||
|
pos_bias_type=pos_bias_type,
|
||||||
|
length_mask=length_mask,
|
||||||
|
context_mask=context_mask,
|
||||||
|
)
|
||||||
|
hidden_states = self.output_layernorm(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
current_key_values = []
|
||||||
|
current_hidden_states = []
|
||||||
|
for i, module in enumerate(self.layers):
|
||||||
|
hidden_states = module(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
position_bias,
|
||||||
|
past_key_value=past_key_values[i] if past_key_values else None,
|
||||||
|
use_cache=use_cache,
|
||||||
|
pos_bias_type=pos_bias_type,
|
||||||
|
length_mask=length_mask,
|
||||||
|
context_mask=context_mask,
|
||||||
|
)
|
||||||
|
if use_cache:
|
||||||
|
current_key_values.append(hidden_states[1])
|
||||||
|
current_hidden_states.append(hidden_states[0])
|
||||||
|
hidden_states = hidden_states[0]
|
||||||
|
hidden_states = self.output_layernorm(hidden_states)
|
||||||
|
if use_cache:
|
||||||
|
return hidden_states, current_key_values, current_hidden_states
|
||||||
|
else:
|
||||||
|
return hidden_states
|
|
@ -1,5 +1,7 @@
|
||||||
|
from .config import Config
|
||||||
|
from .data_utils import pad
|
||||||
|
from .data_utils import pad_raw
|
||||||
|
from .gradient_shrink import gradient_shrink
|
||||||
from .log import logger
|
from .log import logger
|
||||||
from .log import LogManager
|
from .log import LogManager
|
||||||
from .object import allgather_objects
|
from .object import allgather_objects
|
||||||
from .config import Config
|
|
||||||
from .gradient_shrink import gradient_shrink
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def pad(orig_items, key, padding_value=0, padding_side="left"):
|
||||||
|
items = []
|
||||||
|
if isinstance(orig_items[0][key], list):
|
||||||
|
assert isinstance(orig_items[0][key][0], torch.Tensor)
|
||||||
|
for it in orig_items:
|
||||||
|
for tr in it[key]:
|
||||||
|
items.append({key: tr})
|
||||||
|
else:
|
||||||
|
assert isinstance(orig_items[0][key], torch.Tensor)
|
||||||
|
items = orig_items
|
||||||
|
|
||||||
|
batch_size = len(items)
|
||||||
|
shape = items[0][key].shape
|
||||||
|
dim = len(shape)
|
||||||
|
assert dim <= 3
|
||||||
|
max_length = max(item[key].shape[-1] for item in items)
|
||||||
|
min_length = min(item[key].shape[-1] for item in items)
|
||||||
|
dtype = items[0][key].dtype
|
||||||
|
|
||||||
|
if dim == 1:
|
||||||
|
return torch.cat([item[key] for item in items], dim=0)
|
||||||
|
elif dim == 2:
|
||||||
|
if max_length == min_length:
|
||||||
|
return torch.cat([item[key] for item in items], dim=0)
|
||||||
|
tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
|
||||||
|
else:
|
||||||
|
tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
|
||||||
|
|
||||||
|
for i, item in enumerate(items):
|
||||||
|
if dim == 2:
|
||||||
|
if padding_side == "left":
|
||||||
|
tensor[i, -len(item[key][0]) :] = item[key][0].clone()
|
||||||
|
else:
|
||||||
|
tensor[i, : len(item[key][0])] = item[key][0].clone()
|
||||||
|
elif dim == 3:
|
||||||
|
if padding_side == "left":
|
||||||
|
tensor[i, -len(item[key][0]) :, :] = item[key][0].clone()
|
||||||
|
else:
|
||||||
|
tensor[i, : len(item[key][0]), :] = item[key][0].clone()
|
||||||
|
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def pad_raw(orig_items, max_length=1024, padding_value=0, padding_side="left"):
|
||||||
|
max_cols = max(tensor.size(1) for tensor in orig_items)
|
||||||
|
|
||||||
|
padded_arrays = []
|
||||||
|
for tensor in orig_items:
|
||||||
|
pad_cols = max_cols - tensor.size(1)
|
||||||
|
if padding_side == "left":
|
||||||
|
padded_tensor = torch.cat([torch.zeros(tensor.size(0), pad_cols), tensor], dim=1)
|
||||||
|
elif padding_side == "right":
|
||||||
|
padded_tensor = torch.cat([tensor, torch.zeros(tensor.size(0), pad_cols)], dim=1)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid 'side' parameter. Must be 'left' or 'right'.")
|
||||||
|
padded_arrays.append(padded_tensor)
|
||||||
|
|
||||||
|
padded_tensor = torch.cat(padded_arrays, dim=0).to(dtype=torch.int32)
|
||||||
|
return padded_tensor
|
|
@ -0,0 +1,130 @@
|
||||||
|
import functools
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import time
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import bmtrain as bmt
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .log import logger
|
||||||
|
|
||||||
|
|
||||||
|
def rename_if_exists(file_path):
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
return
|
||||||
|
timestamp = time.strftime("%Y%m%d%H%M%S")
|
||||||
|
file_dir, file_name = os.path.split(file_path)
|
||||||
|
file_root, file_ext = os.path.splitext(file_name)
|
||||||
|
new_file_name = f"{file_root}_bak_{timestamp}{file_ext}"
|
||||||
|
new_file_path = os.path.join(file_dir, new_file_name)
|
||||||
|
try:
|
||||||
|
os.rename(file_path, new_file_path)
|
||||||
|
logger.info(f"File '{file_name}' already exists. Renamed to '{new_file_name}'")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warn(
|
||||||
|
"rename file failed,file_path={file_path}, new_file_path={new_file_path},err={err}".format(
|
||||||
|
file_path=file_path, new_file_path=new_file_path, err=str(e)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def rename_if_exists_decorator(func):
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(file_path, *args, **kwargs):
|
||||||
|
rename_if_exists(file_path)
|
||||||
|
return func(file_path, *args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
@rename_if_exists_decorator
|
||||||
|
def bmt_save(file_path: str, model: torch.nn.Module, export_files: List[str] = None):
|
||||||
|
bmt.save(model, file_path)
|
||||||
|
if export_files is not None:
|
||||||
|
export_files.append(file_path)
|
||||||
|
|
||||||
|
|
||||||
|
@rename_if_exists_decorator
|
||||||
|
def torch_save(file_path: str, obj: object, export_files: List[str] = None):
|
||||||
|
torch.save(obj, file_path)
|
||||||
|
if export_files is not None:
|
||||||
|
export_files.append(file_path)
|
||||||
|
|
||||||
|
|
||||||
|
@rename_if_exists_decorator
|
||||||
|
def json_save(file_path: str, obj: object, export_files: List[str] = None):
|
||||||
|
with open(file_path, "w") as data_f:
|
||||||
|
json.dump(obj, data_f)
|
||||||
|
if export_files is not None:
|
||||||
|
export_files.append(file_path)
|
||||||
|
|
||||||
|
|
||||||
|
def export(
|
||||||
|
model: torch.nn.Module, dataloader, optimizer: bmt.optim.AdamOffloadOptimizer, global_step, args, final_save=False
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
一次 ckpt 保存:
|
||||||
|
/{args.save}/
|
||||||
|
├── {save_name}-{global_step}.rank-0.opt
|
||||||
|
├── {save_name}-{global_step}.rank-n.opt
|
||||||
|
├── job_{job_id}_ckpt_{global_step}/ # checkpoint 导出为模型版本时,job_{job_id}_ckpt_{global_step}/ 路径下文件会一起导出,创建一个模型组版本
|
||||||
|
├── config.json
|
||||||
|
├── vocabs.txt
|
||||||
|
├── {args.save_name}-{global_step}.pt
|
||||||
|
├── {args.save_name}-{global_step}.data
|
||||||
|
├── {args.save_name}-{global_step}.data.json
|
||||||
|
└── {args.save_name}-{global_step}.success
|
||||||
|
|
||||||
|
"""
|
||||||
|
export_model_dir = os.path.join(args.save, f"l_{global_step}")
|
||||||
|
os.makedirs(export_model_dir, exist_ok=True)
|
||||||
|
base_file_name = f"{args.save_name}-{global_step}" if global_step > -1 else args.save_name
|
||||||
|
logger.info(f"start to export ckpt, save_dir={export_model_dir}, file prefix={base_file_name}")
|
||||||
|
export_files = []
|
||||||
|
|
||||||
|
# model checkpoint
|
||||||
|
bmt_save(
|
||||||
|
file_path=os.path.join(export_model_dir, base_file_name + ".pt"),
|
||||||
|
model=model,
|
||||||
|
export_files=export_files,
|
||||||
|
)
|
||||||
|
|
||||||
|
# opt is only used for continual pre-training, not the final opt
|
||||||
|
if not final_save:
|
||||||
|
grad_path = os.path.join(
|
||||||
|
args.save,
|
||||||
|
args.save_name + ("-%d.rank-%d.opt" % (global_step % (args.save_iters * 5), bmt.rank())),
|
||||||
|
)
|
||||||
|
torch.save(optimizer.state_dict(), grad_path)
|
||||||
|
logger.info(f"Successfully save grad file: {grad_path}")
|
||||||
|
|
||||||
|
all_states = dataloader.state_dict()
|
||||||
|
if bmt.rank() == 0:
|
||||||
|
# data checkpoint
|
||||||
|
# rank 0 writes the dataloader state
|
||||||
|
torch_save(
|
||||||
|
file_path=os.path.join(export_model_dir, base_file_name + ".data"),
|
||||||
|
obj=all_states,
|
||||||
|
export_files=export_files,
|
||||||
|
)
|
||||||
|
# data checkpoint json
|
||||||
|
# rank 0 writes the dataloader state into the json file
|
||||||
|
data_p_json = {k: v for k, v in all_states.items()}
|
||||||
|
for k in data_p_json:
|
||||||
|
data_p_json[k] = {k_of_v: data_p_json[k][k_of_v].tolist() for k_of_v in data_p_json[k]}
|
||||||
|
json_save(
|
||||||
|
file_path=os.path.join(export_model_dir, base_file_name + ".data.json"),
|
||||||
|
obj=data_p_json,
|
||||||
|
export_files=export_files,
|
||||||
|
)
|
||||||
|
# config 和 vocabs 和模型文件一起存储
|
||||||
|
model_cfg_path = os.path.join(export_model_dir, "config.json")
|
||||||
|
model_vocab_path = os.path.join(export_model_dir, "vocabs.txt")
|
||||||
|
export_files.extend([model_cfg_path, model_vocab_path])
|
||||||
|
shutil.copy(args.model_config, model_cfg_path)
|
||||||
|
shutil.copy(args.vocab, model_vocab_path)
|
||||||
|
logger.info(f"Successfully save model files! {export_files}")
|
||||||
|
del all_states
|
||||||
|
return export_model_dir
|
|
@ -18,7 +18,7 @@ def _get_logger():
|
||||||
log.setLevel(log_level)
|
log.setLevel(log_level)
|
||||||
log.propagate = False
|
log.propagate = False
|
||||||
|
|
||||||
node_name = os.getenv("NODE_NAME", "jeeves-hpc-gpu00")
|
node_name = os.getenv("NODE_NAME", "hpc-gpu00")
|
||||||
|
|
||||||
fmt = f"[%(levelname)s][%(asctime)s][{node_name}][%(filename)s:%(lineno)d:%(process)d] - %(message)s"
|
fmt = f"[%(levelname)s][%(asctime)s][{node_name}][%(filename)s:%(lineno)d:%(process)d] - %(message)s"
|
||||||
formatter = logging.Formatter(fmt, datefmt="%Y-%m-%d %H:%M:%S")
|
formatter = logging.Formatter(fmt, datefmt="%Y-%m-%d %H:%M:%S")
|
||||||
|
@ -32,7 +32,6 @@ def _get_logger():
|
||||||
return log
|
return log
|
||||||
|
|
||||||
|
|
||||||
# 日志句柄
|
|
||||||
logger = _get_logger()
|
logger = _get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@ -108,10 +107,8 @@ class LogManager:
|
||||||
}
|
}
|
||||||
if model_inspect is not None:
|
if model_inspect is not None:
|
||||||
ret["model_inspect"] = model_inspect
|
ret["model_inspect"] = model_inspect
|
||||||
print(ret)
|
fp.write(json.dumps(ret, ensure_ascii=False) + "\n")
|
||||||
fp.write(json.dumps(ret) + "\n")
|
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print(e)
|
|
||||||
print("Error: writing info list!")
|
print("Error: writing info list!")
|
||||||
time_.sleep(10)
|
time_.sleep(10)
|
|
@ -1,4 +0,0 @@
|
||||||
# !/usr/bin/python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
#
|
|
||||||
# Copyright @2024, QiYuan Inc
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -1,20 +0,0 @@
|
||||||
|
|
||||||
import random
|
|
||||||
|
|
||||||
|
|
||||||
def rand(n: int, r: random.Random):
|
|
||||||
return int(r.random() * n)
|
|
||||||
|
|
||||||
def transform(data, num_sample: int, r: random.Random):
|
|
||||||
if 'input' in data:
|
|
||||||
_input = "<用户>"+data['input']+"<AI>"
|
|
||||||
else:
|
|
||||||
_input = ""
|
|
||||||
|
|
||||||
if 'output' in data:
|
|
||||||
_output = data['output']
|
|
||||||
else:
|
|
||||||
_output = ""
|
|
||||||
return {"input": _input,
|
|
||||||
"output": _output,
|
|
||||||
}
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -1,20 +0,0 @@
|
||||||
|
|
||||||
import random
|
|
||||||
|
|
||||||
|
|
||||||
def rand(n: int, r: random.Random):
|
|
||||||
return int(r.random() * n)
|
|
||||||
|
|
||||||
def transform(data, num_sample: int, r: random.Random):
|
|
||||||
if 'input' in data:
|
|
||||||
_input = data['input']
|
|
||||||
else:
|
|
||||||
_input = ""
|
|
||||||
|
|
||||||
if 'output' in data:
|
|
||||||
_output = data['output']
|
|
||||||
else:
|
|
||||||
_output = ""
|
|
||||||
return {"input": _input,
|
|
||||||
"output": _output,
|
|
||||||
}
|
|
|
@ -1,134 +0,0 @@
|
||||||
[
|
|
||||||
{
|
|
||||||
"dataset_name": "humanevallike_clean_dedup",
|
|
||||||
"task_name": "humanevallike_clean_dedup",
|
|
||||||
"abs_weight": 0.2,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/humanevallike_clean_dedup",
|
|
||||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 995339,
|
|
||||||
"ave_tokens_per_line": 100,
|
|
||||||
"total_tokens": 0.1
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dataset_name": "leetcode_pass_code_0125",
|
|
||||||
"task_name": "leetcode_pass_code_0125",
|
|
||||||
"abs_weight": 0.006,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/leetcode_pass_code_0125",
|
|
||||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 10724,
|
|
||||||
"ave_tokens_per_line": 200,
|
|
||||||
"total_tokens": 0.002
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dataset_name": "logiv2Annotate",
|
|
||||||
"task_name": "logiv2Annotate",
|
|
||||||
"abs_weight": 0.004,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logiv2Annotate",
|
|
||||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 12566,
|
|
||||||
"ave_tokens_per_line": 512,
|
|
||||||
"total_tokens": 0.006
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dataset_name": "mmlu_enhance",
|
|
||||||
"task_name": "mmlu_enhance",
|
|
||||||
"abs_weight": 0.1,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mmlu_enhance",
|
|
||||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 169771,
|
|
||||||
"ave_tokens_per_line": 300,
|
|
||||||
"total_tokens": 0.05
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dataset_name": "mtbench_like",
|
|
||||||
"task_name": "mtbench_like",
|
|
||||||
"abs_weight": 0.2,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mtbench_like",
|
|
||||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 319080,
|
|
||||||
"ave_tokens_per_line": 500,
|
|
||||||
"total_tokens": 0.15
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dataset_name": "ultra_dataset_new",
|
|
||||||
"task_name": "ultra_dataset_new",
|
|
||||||
"abs_weight": 2.0,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/ultra_dataset_new",
|
|
||||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 385045,
|
|
||||||
"ave_tokens_per_line": 200.296266559615,
|
|
||||||
"total_tokens": 2.0
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dataset_name": "sft_data_zh_wowru",
|
|
||||||
"task_name": "sft_data_zh_wowru",
|
|
||||||
"abs_weight": 1.0,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/sft_data_zh_wowru",
|
|
||||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 2963260,
|
|
||||||
"ave_tokens_per_line": 200.296266559615,
|
|
||||||
"total_tokens": 1
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dataset_name": "math_data",
|
|
||||||
"task_name": "math_data",
|
|
||||||
"abs_weight": 0.003,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/math_data",
|
|
||||||
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 2963260,
|
|
||||||
"ave_tokens_per_line": 200.296266559615,
|
|
||||||
"total_tokens": 0.005
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dataset_name": "t0",
|
|
||||||
"task_name": "t0",
|
|
||||||
"abs_weight": 0.1,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/t0",
|
|
||||||
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 1650309,
|
|
||||||
"ave_tokens_per_line": 500.296266559615,
|
|
||||||
"total_tokens": 0.82
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dataset_name": "wikihow",
|
|
||||||
"task_name": "wikihow",
|
|
||||||
"abs_weight": 0.1,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/wikihow",
|
|
||||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 180128,
|
|
||||||
"ave_tokens_per_line": 900.296266559615,
|
|
||||||
"total_tokens": 0.16
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dataset_name": "reclor",
|
|
||||||
"task_name": "reclor",
|
|
||||||
"abs_weight": 0.002,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/reclor",
|
|
||||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 4174,
|
|
||||||
"ave_tokens_per_line": 700.296266559615,
|
|
||||||
"total_tokens": 0.003
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dataset_name": "logic_test_lx_0127",
|
|
||||||
"task_name": "logic_test_lx_0127",
|
|
||||||
"abs_weight": 0.001,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logic_test_lx_0127",
|
|
||||||
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 2800,
|
|
||||||
"ave_tokens_per_line": 200.96266559615,
|
|
||||||
"total_tokens": 0.0004
|
|
||||||
}
|
|
||||||
]
|
|
|
@ -1,28 +0,0 @@
|
||||||
{
|
|
||||||
"vocab_size": 122753,
|
|
||||||
"dropout_p": 0.0,
|
|
||||||
"eps": 1e-05,
|
|
||||||
"half": true,
|
|
||||||
"half_type": "bf16",
|
|
||||||
"use_flash_attn": true,
|
|
||||||
"flash_attn_mask_shape": "2d",
|
|
||||||
"dim_model": 2304,
|
|
||||||
"dim_ff": 5760,
|
|
||||||
"dim_head": 64,
|
|
||||||
"num_heads": 36,
|
|
||||||
"num_kv_heads": 36,
|
|
||||||
"num_layers": 40,
|
|
||||||
"activate_fn": "silu",
|
|
||||||
"init_std": 0.10,
|
|
||||||
"scale": true,
|
|
||||||
"scale_emb": 12,
|
|
||||||
"scale_depth": 1.4,
|
|
||||||
"dim_model_base": 256,
|
|
||||||
"model_type": "fm9g",
|
|
||||||
"architectures": [
|
|
||||||
"FM9GForCausalLM"
|
|
||||||
],
|
|
||||||
"qk_norm": false,
|
|
||||||
"tie_lm_head": true,
|
|
||||||
"ffn_gated": true
|
|
||||||
}
|
|
|
@ -1,548 +0,0 @@
|
||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 QiYuan Inc.
|
|
||||||
import inspect
|
|
||||||
import json
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from collections import defaultdict
|
|
||||||
from itertools import chain
|
|
||||||
from typing import Any
|
|
||||||
from typing import Dict
|
|
||||||
from typing import List
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import bmtrain as bmt
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from bmtrain import nccl
|
|
||||||
from bmtrain.global_var import config as bmt_config
|
|
||||||
|
|
||||||
sys.path.append("../../")
|
|
||||||
from fm9g.arguments import get_args
|
|
||||||
from fm9g.dragonfly.modeling_dragonfly import Dragonfly
|
|
||||||
from fm9g.dragonfly.modeling_dragonfly import DragonflyConfig
|
|
||||||
from fm9g.dragonfly.training_tasks.pretrain_indexed import CudaPrefetcher
|
|
||||||
from fm9g.dragonfly.training_tasks.pretrain_indexed import MixedIndexedDataset
|
|
||||||
from fm9g.dragonfly.training_tasks.pretrain_indexed import UnpadBatchedMixedDataset
|
|
||||||
from fm9g.utils import exporter
|
|
||||||
from fm9g.utils import logger
|
|
||||||
from fm9g.utils.exporter import save_every_step_stats
|
|
||||||
from fm9g.utils.training_stats import num_non_embedding_parameters
|
|
||||||
from fm9g.utils.training_stats import num_parameters
|
|
||||||
|
|
||||||
|
|
||||||
def get_tokenizer(args):
|
|
||||||
from transformers import LlamaTokenizerFast
|
|
||||||
|
|
||||||
tokenizer = LlamaTokenizerFast(vocab_file=args.tokenizer_path)
|
|
||||||
return tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def get_model(args):
|
|
||||||
config = DragonflyConfig.from_json_file(args.model_config)
|
|
||||||
config.tp = 1 if args.tp_size != 1 else 0 # TODO
|
|
||||||
config.pose_prob = args.pose_prob
|
|
||||||
config.pose_scaling_factor = args.pose_scaling_factor
|
|
||||||
config.rope_scaling_type = args.rope_scaling_type
|
|
||||||
config.rope_scaling_factor = args.rope_scaling_factor
|
|
||||||
config.orig_max_length = args.orig_max_length
|
|
||||||
|
|
||||||
bmt.print_rank("model config: {}".format(config))
|
|
||||||
bmt.print_rank("bmt config: {}".format(bmt.config))
|
|
||||||
|
|
||||||
model = Dragonfly(config)
|
|
||||||
if args.load is not None:
|
|
||||||
bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
|
|
||||||
exporter.load_model_ckpt(args, model)
|
|
||||||
else:
|
|
||||||
bmt.print_rank("args.load is None, start to initialize parameters")
|
|
||||||
bmt.init_parameters(model)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def get_optimizer(args, model):
|
|
||||||
scale_lr_group = []
|
|
||||||
normal_group = []
|
|
||||||
scale_lr_group_name, normal_group_name = [], []
|
|
||||||
for n, p in model.named_parameters():
|
|
||||||
if n.endswith(".weight") and "layernorm" not in n and "embedding" not in n and "lm_head" not in n:
|
|
||||||
scale_lr_group.append(p)
|
|
||||||
scale_lr_group_name.append(n)
|
|
||||||
else:
|
|
||||||
normal_group.append(p)
|
|
||||||
normal_group_name.append(n)
|
|
||||||
bmt.print_rank(scale_lr_group_name, normal_group_name)
|
|
||||||
param_groups = [
|
|
||||||
{"params": scale_lr_group, "lr": args.lr / model.config.scale_width},
|
|
||||||
{"params": normal_group, "lr": args.lr},
|
|
||||||
]
|
|
||||||
|
|
||||||
if args.offload:
|
|
||||||
optimizer = bmt.optim.AdamOffloadOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
|
||||||
else:
|
|
||||||
optimizer = bmt.optim.AdamOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
|
||||||
if args.load is not None and args.load_grad:
|
|
||||||
exporter.load_optimizer_ckpt(args, optimizer)
|
|
||||||
bmt.print_rank("optimizer is loaded!")
|
|
||||||
return optimizer
|
|
||||||
|
|
||||||
|
|
||||||
def get_learning_rate_scheduler(args, optimizer):
|
|
||||||
from fm9g.training_utils.lr_scheduler import Cosine
|
|
||||||
from fm9g.training_utils.lr_scheduler import WarmupStableDrop
|
|
||||||
|
|
||||||
end_iter = args.train_iters
|
|
||||||
if 0 < args.warmup_iters < 1: # 需要支持按固定比例step用来做warmup的
|
|
||||||
warmup_iters = int(end_iter * args.warmup_iters)
|
|
||||||
else:
|
|
||||||
warmup_iters = int(args.warmup_iters)
|
|
||||||
|
|
||||||
if 0 < args.drop_iters < 1: # 需要支持按固定比例step用来做drop的
|
|
||||||
drop_iters = int(end_iter * args.drop_iters)
|
|
||||||
else:
|
|
||||||
drop_iters = int(args.drop_iters)
|
|
||||||
|
|
||||||
if args.lr_scheduler == "cosine":
|
|
||||||
lr_scheduler = Cosine(
|
|
||||||
optimizer,
|
|
||||||
start_lr=args.lr,
|
|
||||||
warmup_iter=warmup_iters,
|
|
||||||
end_iter=end_iter, # 原来是lr_decay_iter
|
|
||||||
num_iter=args.start_step,
|
|
||||||
#lr_end_restart=args.lr_end_restart,
|
|
||||||
#resume_no_optimze=args.resume_no_optimze,
|
|
||||||
)
|
|
||||||
elif args.lr_scheduler == "warmupstabledrop":
|
|
||||||
lr_scheduler = WarmupStableDrop(
|
|
||||||
optimizer,
|
|
||||||
start_lr=args.lr,
|
|
||||||
warmup_iter=warmup_iters,
|
|
||||||
end_iter=end_iter, # 原来是lr_decay_iter
|
|
||||||
drop_iter=drop_iters,
|
|
||||||
num_iter=args.start_step,
|
|
||||||
resume_no_optimze=args.resume_no_optimze,
|
|
||||||
)
|
|
||||||
return lr_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
def setup_model_and_optimizer(args):
|
|
||||||
start = time.time()
|
|
||||||
tokenizer = get_tokenizer(args)
|
|
||||||
bmt.synchronize()
|
|
||||||
logger.info("load tokenizer in {:.2f}s".format(time.time() - start))
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
model = get_model(args)
|
|
||||||
logger.info("load model in {:.2f}s".format(time.time() - start))
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
optimizer = get_optimizer(args, model)
|
|
||||||
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
|
|
||||||
bmt.synchronize()
|
|
||||||
logger.info("load lr_scheduler in {:.2f}s".format(time.time() - start))
|
|
||||||
|
|
||||||
return tokenizer, model, optimizer, lr_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
def resume_training(args):
|
|
||||||
ckpts = sorted(
|
|
||||||
[z for z in chain(*[[os.path.join(x[0], y) for y in x[2]] for x in os.walk(args.save)]) if z.endswith(".pt")],
|
|
||||||
reverse=True,
|
|
||||||
key=lambda x: (int)(re.search("(\d+).pt", x)[1]),
|
|
||||||
)
|
|
||||||
# find newest job
|
|
||||||
ckpts = sorted(
|
|
||||||
ckpts,
|
|
||||||
reverse=True,
|
|
||||||
key=lambda x: (int)(re.search("job_(\d+)_ckpt", x)[1]),
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(ckpts) > 0:
|
|
||||||
bmt.print_rank(f"resuming with last checkpoint: {ckpts[0]}")
|
|
||||||
args.load = ckpts[0]
|
|
||||||
# by default, do not load grad file
|
|
||||||
args.load_grad = False
|
|
||||||
args.start_step = 0
|
|
||||||
else:
|
|
||||||
# no ckpts, nothing we can do
|
|
||||||
os._exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def initialize():
|
|
||||||
args = get_args(pretrain=True)
|
|
||||||
bmt.init_distributed(seed=args.seed, tp_size=args.tp_size)
|
|
||||||
|
|
||||||
if args.save is not None:
|
|
||||||
os.makedirs(args.save, exist_ok=True)
|
|
||||||
if args.load is not None:
|
|
||||||
if args.only_load_model == 0:
|
|
||||||
if args.start_step == 0:
|
|
||||||
log_ckpt = exporter.load_log_ckpt(args)
|
|
||||||
if "iteration" in log_ckpt:
|
|
||||||
args.start_step = log_ckpt["iteration"]
|
|
||||||
else:
|
|
||||||
args.start_step = (int)(re.findall("(\d+)", args.load)[-1])
|
|
||||||
logger.info("Start from step {}".format(args.start_step))
|
|
||||||
elif args.only_load_model == 1:
|
|
||||||
logger.info("You load model ckpt, and choose to completely start the 0 step.")
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
else:
|
|
||||||
logger.info("You do not load model")
|
|
||||||
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
def see_memory(detail=False):
|
|
||||||
if detail:
|
|
||||||
res = torch.cuda.memory_summary()
|
|
||||||
else:
|
|
||||||
res = (
|
|
||||||
round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2),
|
|
||||||
round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024), 2),
|
|
||||||
round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024), 2),
|
|
||||||
)
|
|
||||||
torch.cuda.reset_peak_memory_stats()
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def add_mem_time(info, mem_usage, tim_usage):
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
bmt.synchronize()
|
|
||||||
mem_usage[info] = see_memory()
|
|
||||||
tim_usage[info] = time.time()
|
|
||||||
return mem_usage, tim_usage
|
|
||||||
|
|
||||||
|
|
||||||
def get_task_loss_and_token(loss, task_ids, task_num, targets):
|
|
||||||
# task_ids 可能有-1 来代表无效token
|
|
||||||
_task_num = task_num + 1
|
|
||||||
_task_ids = (task_ids.clone() + 1).to(torch.int64) # [batch_size, seq_len]
|
|
||||||
# gen masks
|
|
||||||
_task_mask = torch.zeros((_task_num, *_task_ids.shape), device=_task_ids.device)
|
|
||||||
_task_mask.scatter_(0, _task_ids.unsqueeze(0), 1) # [task_num, batch_size, seq_len]
|
|
||||||
_loss_mask = torch.ne(targets, -100).to(torch.int32)
|
|
||||||
_mask = _task_mask * _loss_mask.unsqueeze(0) # [task_num, batch_size, seq_len]
|
|
||||||
# calc loss and tokens
|
|
||||||
_task_losses = (loss.unsqueeze(0) * _mask).view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
|
|
||||||
_task_tokens = _mask.view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
|
|
||||||
# return token-wise avg losses and tokens
|
|
||||||
return torch.nan_to_num(_task_losses / _task_tokens, nan=0.0), _task_tokens
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkAve:
    """Sliding-window mean over the most recent ``chunk_size`` samples."""

    def __init__(self, chunk_size=100):
        # Window of recorded values, oldest first.
        self.ave_list = []
        self.chunk_size = chunk_size

    def record(self, time):
        """Append one sample, discarding anything older than the window."""
        self.ave_list.append(time)
        if len(self.ave_list) > self.chunk_size:
            self.ave_list = self.ave_list[-self.chunk_size :]

    def get(self):
        """Mean of the current window (raises ZeroDivisionError when empty)."""
        return sum(self.ave_list) / len(self.ave_list)
|
|
||||||
|
|
||||||
|
|
||||||
def pretrain(
    args,
    tokenizer,
    model: Dragonfly,
    optimizer,
    lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
):
    """Run the main distributed pretraining loop.

    Builds the (tensor-parallel aware) data pipeline, then iterates:
    forward -> loss -> backward -> (every ``args.grad_accum`` steps) optimizer
    step, while logging per-task losses, throughput, and memory, and
    periodically checkpointing via ``exporter.export``.
    """
    # Sliding averages for per-optim-step and per-iteration wall time.
    ave_model_time = ChunkAve(chunk_size=100)
    ave_iter_time = ChunkAve(chunk_size=100)

    loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, reduction="none")
    # loss_scale=None: dynamic loss scaling managed by bmtrain.
    optim_manager = bmt.optim.OptimManager(
        loss_scale=None,
        loss_scale_steps=args.loss_scale_steps,
        loss_scale_factor=2,
        max_loss_scale=args.max_loss_scale,
        min_loss_scale=args.min_loss_scale,
    )
    optim_manager.add_optimizer(optimizer, lr_scheduler)

    start_step = args.start_step

    # Tensorboard writer only on rank 0 (and only when a log dir was given).
    if args.tensorboard is not None and bmt.rank() == 0:
        import distutils.version  # noqa: F401

        from tensorboardX import SummaryWriter

        if not os.path.exists(args.tensorboard):
            os.makedirs(args.tensorboard)
        writer = SummaryWriter(log_dir=args.tensorboard)

    # Restore logging counters (token totals) from a previous run, if any.
    if args.load is not None:
        log_ckpt = exporter.load_log_ckpt(args)
    else:
        log_ckpt = {}
    global_token_pass = log_ckpt.get("global_token_pass", 0.0)
    global_total_task_token = defaultdict(int, log_ckpt.get("global_total_task_token", {}))  # token by task

    global_world_size = bmt.world_size()
    bmt.print_rank("Begin preparing dataset")
    # Only tp_rank 0 reads data; other tensor-parallel ranks receive it via
    # the CudaPrefetcher below and use a dummy generator locally.
    if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
        mixed_indexed_dataset = MixedIndexedDataset(
            cfg_path=args.dataset,
            cfg_json_str=None,
            tokenizer=tokenizer,
            max_length=args.max_length,
            nthreads=args.dataloader_num_threads,
            prefetch_slice=args.dataloader_prefetch,
            weight_by_size=True,
        )

        # Resume dataloader position unless we only loaded model weights.
        if args.load is not None and args.only_load_model == 0 and args.load_dataloader_ckpt == 1:
            exporter.load_dataloader_ckpt(args, mixed_indexed_dataset)

        batched_dataset = UnpadBatchedMixedDataset(mixed_indexed_dataset, args.batch_size, args.max_length)
        dataloader = torch.utils.data.DataLoader(
            batched_dataset,
            batch_size=None,
            collate_fn=lambda x: x,
            num_workers=args.dataloader_num_workers,
            prefetch_factor=args.dataloader_prefetch_factor,
        )
    else:

        def dummy_generator():
            # Placeholder iterator for non-zero tp ranks; yields are ignored.
            while True:
                yield None

        mixed_indexed_dataset = dummy_generator()
        dataloader = mixed_indexed_dataset

    DataIterator = CudaPrefetcher(dataloader, tp_size=args.tp_size, tp_rank=bmt.config["tp_rank"])

    bmt.print_rank("Preparing dataset done.")

    # inspect at init
    model_inspect = bmt.inspect.inspect_model(model, "*")
    bmt.print_rank(bmt.inspect.format_summary(model_inspect))

    try:
        mem_usage, tim_usage = {}, {}
        mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)

        for iteration, data in enumerate(DataIterator, start=start_step + 1):
            # Record consumed samples so the dataloader can be checkpointed.
            if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
                mixed_indexed_dataset.update_states(data["task_ids"], data["indexes"])

            mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)

            logits = model(
                input=data["inputs"],
                cu_seqlens=data["cu_seqlens"],
                max_seqlen=data["max_seqlen"],
                position_ids=data["position_ids"],
            )

            # chunk targets and task_ids: each tp rank keeps only its shard.
            data["targets"] = (
                data["targets"]
                .view(-1)
                .chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
                .view(data["targets"].shape[0], -1)
            )
            data["task_ids"] = (
                data["task_ids"]
                .view(-1)
                .chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
                .view(data["task_ids"].shape[0], -1)
            )

            _target = data["targets"].view(-1)
            non_reduced_loss = loss_func(logits.view(-1, logits.size(-1)), _target)
            # Mean over valid (non-ignored) tokens only.
            _w = (_target != -100).int()
            loss = non_reduced_loss.sum() / _w.sum().float()

            global_loss = bmt.sum_loss(loss).item()
            mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)

            optim_manager.backward(loss)
            mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)

            # Optimizer step only on gradient-accumulation boundaries (or at
            # the very last training iteration).
            if iteration % args.grad_accum == 0 or iteration == args.train_iters:
                grad_accum_init_time = tim_usage["init"]

                grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
                optim_manager.step()
                optim_manager.zero_grad()
                mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
                model_time = tim_usage["optim"] - grad_accum_init_time
                ave_model_time.record(model_time)
            else:
                # dummy optim step
                grad_norm = torch.Tensor([0.0]).cuda()
                tim_usage["optim"] = tim_usage["backward"]
                mem_usage["optim"] = mem_usage["backward"]

            # Per-task loss bookkeeping, aggregated across all ranks.
            with torch.no_grad():
                task_num = len(data["task_names"])
                task_loss, task_token = get_task_loss_and_token(
                    non_reduced_loss, data["task_ids"], task_num, data["targets"]
                )
                task_loss_map: Dict[str, float] = {}
                gatherd_task_loss_map = bmt.distributed.all_gather(task_loss)
                gatherd_task_token_map = bmt.distributed.all_gather(task_token)
                # Token-weighted average of per-rank task losses.
                gatherd_task_loss_token_map = gatherd_task_loss_map * gatherd_task_token_map
                sum_task_loss = gatherd_task_loss_token_map.sum(dim=0)
                tot_task_token = gatherd_task_token_map.sum(dim=0)
                ave_task_loss = sum_task_loss / tot_task_token
                for i in range(task_num):
                    task_loss_map[data["task_names"][i]] = ave_task_loss[i].item()
                    global_total_task_token[data["task_names"][i]] += tot_task_token[i].item()

            # Fraction of max_length actually filled, averaged over ranks.
            local_total_rate = torch.Tensor([data["lengths"].float().mean() / args.max_length]).cuda()
            local_total_rate = bmt.sum_loss(local_total_rate).item()
            global_token_pass += (
                (global_world_size // args.tp_size) * local_total_rate * args.max_length * args.batch_size
            )

            bmt.print_rank(
                "=========================================" + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            )
            last_before_log_time = tim_usage["before_log"]
            mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)

            iter_time = tim_usage["before_log"] - last_before_log_time

            ave_iter_time.record(iter_time)

            train_info = {
                "time": iter_time,
                "iteration": iteration,
                "loss": global_loss,
                "lr": lr_scheduler.current_lr,
                "token_max": local_total_rate,
                "token_pass": global_token_pass,
                "throughout": args.max_length * args.batch_size * local_total_rate / ave_iter_time.get() / args.tp_size,
                "grad_norm": grad_norm.item(),
                "mask_max": ((data["targets"] >= 0).sum(-1).float().mean() / args.max_length).item(),
                "task_loss": task_loss_map,
                "total_task_token": global_total_task_token,
            }
            global_token_pass_str = convert_to_k_and_b(global_token_pass)

            bmt.print_rank(
                (
                    "| Iter: {iteration:6d} | loss: {loss:.4f} | lr: {lr:.4e} | model_time: {model_time:.2f} | iter_time: {iter_time:.2f}| chunk_ave_time: {chunk_ave_time:.2f}"
                    + " token/max: {tokenrate:.4f} | mask/max: {maskrate:.4f} | grad_norm: {grad_norm:.4f} | global_token_pass (B):"
                    + "{global_token_pass} | mem_usage {mem_usage} | "
                ).format(
                    iteration=iteration,
                    loss=global_loss,
                    lr=lr_scheduler.current_lr,
                    model_time=model_time,
                    iter_time=iter_time,
                    chunk_ave_time=ave_iter_time.get(),
                    tokenrate=data["lengths"].float().mean() / args.max_length / args.batch_size,
                    maskrate=(data["targets"] >= 0).sum(-1).float().mean() / args.max_length / args.batch_size,
                    grad_norm=grad_norm.item(),
                    global_token_pass=global_token_pass_str,
                    mem_usage=max([value for key, value in mem_usage.items()]),
                )
            )

            bmt.print_rank(
                "task_loss:\t| "
                + " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
                + " |"
            )

            if iteration % 10 == 0:
                bmt.print_rank(
                    "task_tokens (B):\t| "
                    + " | ".join(
                        [
                            "{}: {:.4f}".format(task_name, task_token / 10**9)
                            for task_name, task_token in global_total_task_token.items()
                        ]
                    )
                    + " |"
                )

            if iteration % args.inspect_iters == 0:
                model_inspect = bmt.inspect.inspect_model(model, "*")
                bmt.print_rank(bmt.inspect.format_summary(model_inspect))

            if args.log_dir is not None and bmt.rank() == 0:
                if args.save is not None:
                    save_every_step_stats(train_info, args.save)

            if args.tensorboard is not None and bmt.rank() == 0:
                writer.add_scalar("Loss/train", global_loss, iteration)
                writer.add_scalar("Optimizer/lr", lr_scheduler.current_lr, iteration)
                writer.add_scalar("Optimizer/scale", optim_manager.loss_scale, iteration)
                writer.add_scalar("Optimizer/grad_norm", grad_norm.item(), iteration)
                for task_name, loss in task_loss_map.items():
                    if not math.isnan(loss):
                        writer.add_scalar("Loss/train/{}".format(task_name), loss, iteration)

            # -------- save file. If need to backup by Klara platform, use export.xx_save --------
            log_ckpt = {
                "global_total_task_token": global_total_task_token,
                "global_token_pass": global_token_pass,
                "iteration": iteration,
            }

            if args.save is not None and iteration % args.save_iters == 0:
                exporter.export(
                    model,
                    mixed_indexed_dataset,
                    tokenizer,
                    optimizer,
                    iteration,
                    args,
                    log_ckpt=log_ckpt,
                    final_save=False,
                )

            if iteration == args.train_iters and args.stop_when_end == 1:
                break

    except Exception as e:
        print(f"train loop err: {e}")
        raise e
    finally:
        pass

    # Final checkpoint after the loop (iteration sentinel -1).
    exporter.export(model, mixed_indexed_dataset, tokenizer, optimizer, -1, args, final_save=False)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_to_k_and_b(number):
    """Format a count with a B/M/K suffix (two decimals), or plain ``str``.

    >>> convert_to_k_and_b(1500)
    '1.50K'
    """
    # Thresholds in descending order so the largest matching unit wins.
    for threshold, suffix in ((1e9, "B"), (1e6, "M"), (1e3, "K")):
        if number >= threshold:
            return f"{number / threshold:.2f}{suffix}"
    return str(number)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Entry point: parse args, build model/optimizer, run pretraining."""
    args = initialize()
    bmt.synchronize()
    tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
    bmt.print_rank("finish loading")
    total_params = num_parameters(model)
    non_emb_params = num_non_embedding_parameters(model)
    bmt.print_rank(
        "Number of parameter {}, Number of non-e parameter {}".format(total_params, non_emb_params)
    )
    bmt.print_rank("args: {}".format(args))

    pretrain(args, tokenizer, model, optimizer, lr_scheduler)
|
|
|
@ -1,234 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
#export OMP_NUM_THREADS=16
|
|
||||||
|
|
||||||
declare -A args # Declare an associative array to store arguments and values
|
|
||||||
|
|
||||||
args["model_unique"]="2b_0701"
|
|
||||||
args["resume_ckpt"]=""
|
|
||||||
args["config"]="2.4b"
|
|
||||||
args["flash"]="cuda"
|
|
||||||
args["batch_size"]="1"
|
|
||||||
args["max_length"]="4096"
|
|
||||||
args["save_iters"]="500"
|
|
||||||
args["train_iters"]="10"
|
|
||||||
args["dataset_config"]="fm9g_sft"
|
|
||||||
args["local"]="False"
|
|
||||||
args["dataloader"]="indexed"
|
|
||||||
args["save"]="True"
|
|
||||||
args["dataloader_num_threads"]=1
|
|
||||||
args["dataloader_prefetch"]=1
|
|
||||||
args["dataloader_prefetch_factor"]=1
|
|
||||||
args["dataloader_num_workers"]=1
|
|
||||||
args["lr"]="1e-5"
|
|
||||||
args["warmup_iters"]="20"
|
|
||||||
args["drop_iters"]="0.1"
|
|
||||||
args["tokenizer_path"]="./tokenizer/tokenizer.model" # /user/tc_agi/klara/baichuan2/baichuan2.tokenizer.model
|
|
||||||
args["load_grad"]="False"
|
|
||||||
args["grad_ckpt_num"]="160"
|
|
||||||
args["exp_group"]=""
|
|
||||||
args["ignore_cuda_oom"]="1"
|
|
||||||
args["tensorboard_all_tasks"]="0"
|
|
||||||
args["stop_when_end"]="0"
|
|
||||||
args["only_run_dataloader"]="0"
|
|
||||||
args["eps"]="1e-6"
|
|
||||||
args["inspect_iters"]="100"
|
|
||||||
args["strict_state_dict"]="1"
|
|
||||||
args["only_load_model"]="1"
|
|
||||||
args["lr_scheduler"]="cosine"
|
|
||||||
args["resume_no_optimze"]="0"
|
|
||||||
args["tp_size"]="1"
|
|
||||||
args["parallel_load_datastate"]="8"
|
|
||||||
args["async_save"]="False"
|
|
||||||
args["load_dataloader_ckpt"]="0"
|
|
||||||
args["drop_begin"]="-1"
|
|
||||||
args["drop_rate"]="0.5"
|
|
||||||
args["use_checkpoint"]="0"
|
|
||||||
|
|
||||||
|
|
||||||
# Loop through the arguments
|
|
||||||
for ((i=1; i<=$#; i++)); do
|
|
||||||
arg="${!i}"
|
|
||||||
# Check if the argument starts with "--"
|
|
||||||
if [[ "$arg" == --* ]]; then
|
|
||||||
arg_name="${arg:2}" # Remove leading "--"
|
|
||||||
valueid=$((i+1))
|
|
||||||
# Get the value of the argument if it exists
|
|
||||||
if ((i+1 <= $#)); then
|
|
||||||
args["$arg_name"]="${!valueid}"
|
|
||||||
i=$((i+1)) # Skip the next argument (its value)
|
|
||||||
else
|
|
||||||
args["$arg_name"]="" # Set empty value if no value provided
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
|
|
||||||
# 使用 Python 读取 JSON 文件并更新 Bash 字典
|
|
||||||
while read -r key value; do
|
|
||||||
args["$key"]="$value"
|
|
||||||
done < <(python -c 'import json, sys; obj = json.load(open("train_configs/'${args['config']}'.json"))["pretrain"]; print("\n".join(["{} {}".format(k, v) for k, v in obj.items()]))')
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 用cmd arg 再更新一次
|
|
||||||
# Loop through the arguments
|
|
||||||
for ((i=1; i<=$#; i++)); do
|
|
||||||
arg="${!i}"
|
|
||||||
# Check if the argument starts with "--"
|
|
||||||
if [[ "$arg" == --* ]]; then
|
|
||||||
arg_name="${arg:2}" # Remove leading "--"
|
|
||||||
valueid=$((i+1))
|
|
||||||
|
|
||||||
# Get the value of the argument if it exists
|
|
||||||
if ((i+1 <= $#)); then
|
|
||||||
args["$arg_name"]="${!valueid}"
|
|
||||||
i=$((i+1)) # Skip the next argument (its value)
|
|
||||||
else
|
|
||||||
args["$arg_name"]="" # Set empty value if no value provided
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
|
|
||||||
# Print the values of the arguments
|
|
||||||
echo "----------- CMD args ----------"
|
|
||||||
for key in "${!args[@]}"; do
|
|
||||||
echo "$key: ${args[$key]}"
|
|
||||||
done
|
|
||||||
echo "--------- END CMD args --------"
|
|
||||||
|
|
||||||
|
|
||||||
if [[ ${args["flash"]} == "triton" ]]; then
|
|
||||||
sudo cp /usr/local/cuda-11.6/compat/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03
|
|
||||||
sudo ln /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so
|
|
||||||
echo "triton flash"
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
GPUS_PER_NODE=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader | wc -l)
|
|
||||||
# GPUS_PER_NODE=1
|
|
||||||
echo "Using ${GPUS_PER_NODE} GPU each machine"
|
|
||||||
|
|
||||||
|
|
||||||
if [[ ${args["model_unique"]} == "" ]]; then
|
|
||||||
MODEL_UNIQUE=${JEEVES_JOB_ID} # 写入的位置,没传的话自动构造
|
|
||||||
# JOBID+CreateTime, 本次run的唯一标识符。在白箱里可以通过/projects/${PROJECTID}-${PROJECTNAME}/checkpoints/${MODEL_UNIQUE} 拿到 checkpoint
|
|
||||||
# 通过/projects/${PROJECTID}-${PROJECTNAME}/tensorboard/${MODEL_UNIQUE} 拿到 tensorboard
|
|
||||||
else
|
|
||||||
MODEL_UNIQUE=${args["model_unique"]} # 给了写入的位置
|
|
||||||
fi
|
|
||||||
echo "model_unique: "$MODEL_UNIQUE
|
|
||||||
|
|
||||||
# --------------- 运行参数 ---------------
|
|
||||||
|
|
||||||
OPTS+=" --model-config model_configs/"${args['config']}".json" # [CHANGE]
|
|
||||||
OPTS+=" --batch-size ${args["batch_size"]}"
|
|
||||||
OPTS+=" --train-iters ${args["train_iters"]}"
|
|
||||||
OPTS+=" --save-iters ${args["save_iters"]}"
|
|
||||||
OPTS+=" --save-name fm9g_live_checkpoint"
|
|
||||||
OPTS+=" --max-length ${args["max_length"]}"
|
|
||||||
OPTS+=" --lr ${args["lr"]}"
|
|
||||||
OPTS+=" --inspect-iters ${args["inspect_iters"]}"
|
|
||||||
OPTS+=" --warmup-iters ${args["warmup_iters"]}"
|
|
||||||
OPTS+=" --drop-iters ${args["drop_iters"]}"
|
|
||||||
OPTS+=" --lr_scheduler ${args["lr_scheduler"]}"
|
|
||||||
OPTS+=" --offload"
|
|
||||||
#OPTS+=" --vocab ./tokenizer/vocab.txt"
|
|
||||||
OPTS+=" --flash ${args["flash"]}"
|
|
||||||
OPTS+=" --tensorboard_all_tasks ${args["tensorboard_all_tasks"]}"
|
|
||||||
OPTS+=" --ignore_cuda_oom ${args["ignore_cuda_oom"]}"
|
|
||||||
OPTS+=" --stop_when_end ${args["stop_when_end"]}"
|
|
||||||
OPTS+=" --only_run_dataloader ${args["only_run_dataloader"]}"
|
|
||||||
OPTS+=" --eps ${args["eps"]}"
|
|
||||||
OPTS+=" --strict_state_dict ${args["strict_state_dict"]}"
|
|
||||||
OPTS+=" --only_load_model ${args["only_load_model"]}"
|
|
||||||
OPTS+=" --resume_no_optimze ${args["resume_no_optimze"]}"
|
|
||||||
OPTS+=" --tokenizer_path ${args["tokenizer_path"]}"
|
|
||||||
OPTS+=" --weight-decay 0.1"
|
|
||||||
OPTS+=" --tp-size ${args["tp_size"]}"
|
|
||||||
OPTS+=" --parallel_load_datastate ${args["parallel_load_datastate"]}"
|
|
||||||
OPTS+=" --load_dataloader_ckpt ${args["load_dataloader_ckpt"]}"
|
|
||||||
OPTS+=" --drop_begin ${args["drop_begin"]}"
|
|
||||||
OPTS+=" --drop_rate ${args["drop_rate"]}"
|
|
||||||
OPTS+=" --use_checkpoint ${args["use_checkpoint"]}"
|
|
||||||
|
|
||||||
if [[ ${args["load_grad"]} == "True" ]]; then
|
|
||||||
OPTS+=" --load-grad"
|
|
||||||
OPTS+=" --grad-ckpt-num ${args["grad_ckpt_num"]}"
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
if [[ ${args["async_save"]} == "True" ]]; then
|
|
||||||
OPTS+=" --async_save"
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
if [[ ${args["dataloader"]} == "indexed" ]]; then
|
|
||||||
OPTS+=" --dataloader_num_threads ${args["dataloader_num_threads"]}"
|
|
||||||
OPTS+=" --dataloader_prefetch ${args["dataloader_prefetch"]}"
|
|
||||||
OPTS+=" --dataloader_num_workers ${args["dataloader_num_workers"]}"
|
|
||||||
OPTS+=" --dataloader_prefetch_factor ${args["dataloader_prefetch_factor"]}"
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
# --------------- 写文件路径 ---------------
|
|
||||||
## checkpoint
|
|
||||||
if [[ ${args["save"]} == "True" ]]; then
|
|
||||||
|
|
||||||
OPTS+=" --save ./data/checkpoints/${MODEL_UNIQUE}/"
|
|
||||||
OPTS+=" --save-model ./not_exist/${MODEL_UNIQUE}/"
|
|
||||||
else
|
|
||||||
echo "won't save model"
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
## logs,/local/logs 等价于 ./datalogs(软链)
|
|
||||||
mkdir -p ./data/checkpoints/logs/${MODEL_UNIQUE}
|
|
||||||
OPTS+=" --log-dir ./data/checkpoints/logs/${MODEL_UNIQUE}"
|
|
||||||
OPTS+=" --tensorboard ./data/tensorboard/${args["exp_group"]}${MODEL_UNIQUE}/"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if [[ ${args["local"]} == "True" ]]; then
|
|
||||||
current_dir=$(pwd)
|
|
||||||
OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
|
|
||||||
else
|
|
||||||
current_dir=$(pwd)
|
|
||||||
OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
|
|
||||||
echo "Platform config:"${PLATFORM_CONFIG_PATH}
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
## checkpoint,兼容 CHECKPOINT 和 LATEST_CHECKPOINT。debug 时建议不加载 checkpoint,启动会比较快
|
|
||||||
if [ "${args["resume_ckpt"]}" != "" ]; then
|
|
||||||
OPTS+=" --load ./data/checkpoints/${MODEL_UNIQUE}/${args["resume_ckpt"]}"
|
|
||||||
else
|
|
||||||
echo "No checkpoint to load"
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
filename="pretrain_dragonfly"
|
|
||||||
|
|
||||||
if [[ ${args["local"]} == "True" ]]; then
|
|
||||||
PRETRAIN_ENTRY="$filename.py"
|
|
||||||
else
|
|
||||||
PRETRAIN_ENTRY="$filename.py"
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
GPUS_PER_NODE=1
|
|
||||||
NNODES=1
|
|
||||||
RANK=0
|
|
||||||
MASTER_ENDPOINT=g3006
|
|
||||||
MASTER_PORT=23456
|
|
||||||
#CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --master_addr=${MASTER_ENDPOINT} --master_port=${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
|
|
||||||
CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ENDPOINT}:${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
|
|
||||||
|
|
||||||
echo "-------final CMD is------"
|
|
||||||
echo "${CMD}"
|
|
||||||
echo "-------final CMD end------"
|
|
||||||
|
|
||||||
$CMD
|
|
File diff suppressed because it is too large
Load Diff
Binary file not shown.
|
@ -1,9 +0,0 @@
|
||||||
{
|
|
||||||
"pretrain": {
|
|
||||||
"train_iters": 1000000000,
|
|
||||||
"batch_size": 1,
|
|
||||||
"max_length": 4096,
|
|
||||||
"n_gpus": 8,
|
|
||||||
"lr": 0.01
|
|
||||||
}
|
|
||||||
}
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -1,20 +0,0 @@
|
||||||
|
|
||||||
import random
|
|
||||||
|
|
||||||
|
|
||||||
def rand(n: int, r: random.Random):
    """Return a uniform integer in ``[0, n)`` drawn from generator *r*."""
    return int(n * r.random())
|
|
||||||
|
|
||||||
def transform(data, num_sample: int, r: random.Random):
    """Convert a raw sample into chat format.

    Wraps ``data['input']`` in the <用户>/<AI> role markers when present;
    missing fields default to the empty string. ``num_sample`` and ``r`` are
    part of the transform interface but unused here.
    """
    prompt = ""
    if "input" in data:
        prompt = "<用户>" + data["input"] + "<AI>"
    reply = data["output"] if "output" in data else ""
    return {"input": prompt, "output": reply}
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -1,20 +0,0 @@
|
||||||
|
|
||||||
import random
|
|
||||||
|
|
||||||
|
|
||||||
def rand(n: int, r: random.Random):
    """Return a uniform integer in ``[0, n)`` drawn from generator *r*."""
    return int(n * r.random())
|
|
||||||
|
|
||||||
def transform(data, num_sample: int, r: random.Random):
    """Pass through the sample's input/output fields.

    Missing fields default to the empty string. ``num_sample`` and ``r`` are
    part of the transform interface but unused here.
    """
    result = {}
    for key in ("input", "output"):
        result[key] = data[key] if key in data else ""
    return result
|
|
|
@ -1,134 +0,0 @@
|
||||||
[
|
|
||||||
{
|
|
||||||
"dataset_name": "humanevallike_clean_dedup",
|
|
||||||
"task_name": "humanevallike_clean_dedup",
|
|
||||||
"abs_weight": 0.2,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/humanevallike_clean_dedup",
|
|
||||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 995339,
|
|
||||||
"ave_tokens_per_line": 100,
|
|
||||||
"total_tokens": 0.1
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dataset_name": "leetcode_pass_code_0125",
|
|
||||||
"task_name": "leetcode_pass_code_0125",
|
|
||||||
"abs_weight": 0.006,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/leetcode_pass_code_0125",
|
|
||||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 10724,
|
|
||||||
"ave_tokens_per_line": 200,
|
|
||||||
"total_tokens": 0.002
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dataset_name": "logiv2Annotate",
|
|
||||||
"task_name": "logiv2Annotate",
|
|
||||||
"abs_weight": 0.004,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logiv2Annotate",
|
|
||||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 12566,
|
|
||||||
"ave_tokens_per_line": 512,
|
|
||||||
"total_tokens": 0.006
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dataset_name": "mmlu_enhance",
|
|
||||||
"task_name": "mmlu_enhance",
|
|
||||||
"abs_weight": 0.1,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mmlu_enhance",
|
|
||||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 169771,
|
|
||||||
"ave_tokens_per_line": 300,
|
|
||||||
"total_tokens": 0.05
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dataset_name": "mtbench_like",
|
|
||||||
"task_name": "mtbench_like",
|
|
||||||
"abs_weight": 0.2,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mtbench_like",
|
|
||||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 319080,
|
|
||||||
"ave_tokens_per_line": 500,
|
|
||||||
"total_tokens": 0.15
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dataset_name": "ultra_dataset_new",
|
|
||||||
"task_name": "ultra_dataset_new",
|
|
||||||
"abs_weight": 2.0,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/ultra_dataset_new",
|
|
||||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 385045,
|
|
||||||
"ave_tokens_per_line": 200.296266559615,
|
|
||||||
"total_tokens": 2.0
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dataset_name": "sft_data_zh_wowru",
|
|
||||||
"task_name": "sft_data_zh_wowru",
|
|
||||||
"abs_weight": 1.0,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/sft_data_zh_wowru",
|
|
||||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 2963260,
|
|
||||||
"ave_tokens_per_line": 200.296266559615,
|
|
||||||
"total_tokens": 1
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dataset_name": "math_data",
|
|
||||||
"task_name": "math_data",
|
|
||||||
"abs_weight": 0.003,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/math_data",
|
|
||||||
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 2963260,
|
|
||||||
"ave_tokens_per_line": 200.296266559615,
|
|
||||||
"total_tokens": 0.005
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dataset_name": "t0",
|
|
||||||
"task_name": "t0",
|
|
||||||
"abs_weight": 0.1,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/t0",
|
|
||||||
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 1650309,
|
|
||||||
"ave_tokens_per_line": 500.296266559615,
|
|
||||||
"total_tokens": 0.82
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dataset_name": "wikihow",
|
|
||||||
"task_name": "wikihow",
|
|
||||||
"abs_weight": 0.1,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/wikihow",
|
|
||||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 180128,
|
|
||||||
"ave_tokens_per_line": 900.296266559615,
|
|
||||||
"total_tokens": 0.16
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dataset_name": "reclor",
|
|
||||||
"task_name": "reclor",
|
|
||||||
"abs_weight": 0.002,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/reclor",
|
|
||||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 4174,
|
|
||||||
"ave_tokens_per_line": 700.296266559615,
|
|
||||||
"total_tokens": 0.003
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dataset_name": "logic_test_lx_0127",
|
|
||||||
"task_name": "logic_test_lx_0127",
|
|
||||||
"abs_weight": 0.001,
|
|
||||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logic_test_lx_0127",
|
|
||||||
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
|
||||||
"allow_repeat": true,
|
|
||||||
"nlines": 2800,
|
|
||||||
"ave_tokens_per_line": 200.96266559615,
|
|
||||||
"total_tokens": 0.0004
|
|
||||||
}
|
|
||||||
]
|
|
|
@ -1,27 +0,0 @@
|
||||||
{
|
|
||||||
"vocab_size": 119696,
|
|
||||||
"dropout_p": 0.0,
|
|
||||||
"eps": 1e-05,
|
|
||||||
"half": true,
|
|
||||||
"half_type": "bf16",
|
|
||||||
"use_flash_attn": true,
|
|
||||||
"flash_attn_mask_shape": "2d",
|
|
||||||
"dim_model": 4096,
|
|
||||||
"dim_ff": 14336,
|
|
||||||
"dim_head": 128,
|
|
||||||
"num_heads": 32,
|
|
||||||
"num_kv_heads": 32,
|
|
||||||
"num_layers": 32,
|
|
||||||
"activate_fn": "silu",
|
|
||||||
"init_std": 0.10,
|
|
||||||
"scale": false,
|
|
||||||
"scale_emb": 12,
|
|
||||||
"scale_depth": -1,
|
|
||||||
"model_type": "fm9g",
|
|
||||||
"architectures": [
|
|
||||||
"FM9GForCausalLM"
|
|
||||||
],
|
|
||||||
"qk_norm": false,
|
|
||||||
"tie_lm_head": false,
|
|
||||||
"ffn_gated": true
|
|
||||||
}
|
|
|
@ -1,568 +0,0 @@
|
||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 ModelBest Inc.
|
|
||||||
import inspect
|
|
||||||
import json
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from collections import defaultdict
|
|
||||||
from itertools import chain
|
|
||||||
from typing import Any
|
|
||||||
from typing import Dict
|
|
||||||
from typing import List
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import bmtrain as bmt
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from bmtrain import nccl
|
|
||||||
from bmtrain.global_var import config as bmt_config
|
|
||||||
|
|
||||||
sys.path.append("../../")
|
|
||||||
from fm9g.arguments import get_args
|
|
||||||
from fm9g.dragonfly.modeling_dragonfly import Dragonfly
|
|
||||||
from fm9g.dragonfly.modeling_dragonfly import DragonflyConfig
|
|
||||||
from fm9g.dragonfly.training_tasks.pretrain_indexed_9g import CudaPrefetcher
|
|
||||||
from fm9g.dragonfly.training_tasks.pretrain_indexed_9g import MixedIndexedDataset
|
|
||||||
from fm9g.dragonfly.training_tasks.pretrain_indexed_9g import UnpadBatchedMixedDataset
|
|
||||||
from fm9g.utils import exporter
|
|
||||||
from fm9g.utils import logger
|
|
||||||
from fm9g.utils.exporter import save_every_step_stats
|
|
||||||
from fm9g.utils.training_stats import num_non_embedding_parameters
|
|
||||||
from fm9g.utils.training_stats import num_parameters
|
|
||||||
|
|
||||||
|
|
||||||
def get_tokenizer(args):
    """Build the FM9G tokenizer from the vocab path in ``args.vocab``."""
    from fm9g.tokenizer import FM9GTokenizer

    return FM9GTokenizer(path=args.vocab)
|
|
||||||
|
|
||||||
|
|
||||||
def get_model(args):
|
|
||||||
config = DragonflyConfig.from_json_file(args.model_config)
|
|
||||||
config.tp = 1 if args.tp_size != 1 else 0 # TODO
|
|
||||||
config.pose_prob = args.pose_prob
|
|
||||||
config.pose_scaling_factor = args.pose_scaling_factor
|
|
||||||
config.rope_scaling_type = args.rope_scaling_type
|
|
||||||
config.rope_scaling_factor = args.rope_scaling_factor
|
|
||||||
config.orig_max_length = args.orig_max_length
|
|
||||||
config.use_checkpoint = True if args.use_checkpoint == 1 else False
|
|
||||||
|
|
||||||
bmt.print_rank("model config: {}".format(config))
|
|
||||||
bmt.print_rank("bmt config: {}".format(bmt.config))
|
|
||||||
|
|
||||||
model = Dragonfly(config)
|
|
||||||
if args.load is not None:
|
|
||||||
bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
|
|
||||||
exporter.load_model_ckpt(args, model)
|
|
||||||
else:
|
|
||||||
bmt.print_rank("args.load is None, start to initialize parameters")
|
|
||||||
bmt.init_parameters(model)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def get_optimizer(args, model):
|
|
||||||
scale_lr_group = []
|
|
||||||
normal_group = []
|
|
||||||
scale_lr_group_name, normal_group_name = [], []
|
|
||||||
for n, p in model.named_parameters():
|
|
||||||
if n.endswith(".weight") and "layernorm" not in n and "embedding" not in n and "lm_head" not in n:
|
|
||||||
scale_lr_group.append(p)
|
|
||||||
scale_lr_group_name.append(n)
|
|
||||||
else:
|
|
||||||
normal_group.append(p)
|
|
||||||
normal_group_name.append(n)
|
|
||||||
bmt.print_rank(scale_lr_group_name, normal_group_name)
|
|
||||||
param_groups = [
|
|
||||||
{"params": scale_lr_group, "lr": args.lr / model.config.scale_width},
|
|
||||||
{"params": normal_group, "lr": args.lr},
|
|
||||||
]
|
|
||||||
|
|
||||||
if args.offload:
|
|
||||||
optimizer = bmt.optim.AdamOffloadOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
|
||||||
else:
|
|
||||||
optimizer = bmt.optim.AdamOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
|
||||||
if args.load is not None and args.load_grad:
|
|
||||||
exporter.load_optimizer_ckpt(args, optimizer)
|
|
||||||
bmt.print_rank("optimizer is loaded!")
|
|
||||||
return optimizer
|
|
||||||
|
|
||||||
|
|
||||||
def get_learning_rate_scheduler(args, optimizer):
|
|
||||||
from fm9g.training_utils.lr_scheduler import Cosine
|
|
||||||
from fm9g.training_utils.lr_scheduler import WarmupStableDrop
|
|
||||||
from fm9g.training_utils.lr_scheduler import WarmupStableExp
|
|
||||||
|
|
||||||
end_iter = args.train_iters
|
|
||||||
if 0 < args.warmup_iters < 1: # 需要支持按固定比例step用来做warmup的
|
|
||||||
warmup_iters = int(end_iter * args.warmup_iters)
|
|
||||||
else:
|
|
||||||
warmup_iters = int(args.warmup_iters)
|
|
||||||
|
|
||||||
if 0 < args.drop_iters < 1: # 需要支持按固定比例step用来做drop的
|
|
||||||
drop_iters = int(end_iter * args.drop_iters)
|
|
||||||
else:
|
|
||||||
drop_iters = int(args.drop_iters)
|
|
||||||
|
|
||||||
if args.lr_scheduler == "cosine":
|
|
||||||
lr_scheduler = Cosine(
|
|
||||||
optimizer,
|
|
||||||
start_lr=args.lr,
|
|
||||||
warmup_iter=warmup_iters,
|
|
||||||
end_iter=end_iter, # 原来是lr_decay_iter
|
|
||||||
num_iter=args.start_step,
|
|
||||||
)
|
|
||||||
# lr_end_restart=args.lr_end_restart,
|
|
||||||
# resume_no_optimze=args.resume_no_optimze,
|
|
||||||
#)
|
|
||||||
elif args.lr_scheduler == "warmupstabledrop":
|
|
||||||
lr_scheduler = WarmupStableDrop(
|
|
||||||
optimizer,
|
|
||||||
start_lr=args.lr,
|
|
||||||
warmup_iter=warmup_iters,
|
|
||||||
end_iter=end_iter, # 原来是lr_decay_iter
|
|
||||||
drop_iter=drop_iters,
|
|
||||||
num_iter=args.start_step,
|
|
||||||
resume_no_optimze=args.resume_no_optimze,
|
|
||||||
)
|
|
||||||
elif args.lr_scheduler == "warmupstableexp":
|
|
||||||
lr_scheduler = WarmupStableExp(
|
|
||||||
optimizer,
|
|
||||||
start_lr=args.lr,
|
|
||||||
warmup_iter=warmup_iters,
|
|
||||||
drop_begin=args.drop_begin, # 原来是lr_decay_iter
|
|
||||||
drop_iter=drop_iters,
|
|
||||||
drop_rate=args.drop_rate,
|
|
||||||
num_iter=args.start_step,
|
|
||||||
resume_no_optimze=args.resume_no_optimze,
|
|
||||||
)
|
|
||||||
return lr_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
def setup_model_and_optimizer(args):
|
|
||||||
start = time.time()
|
|
||||||
tokenizer = get_tokenizer(args)
|
|
||||||
bmt.synchronize()
|
|
||||||
logger.info("load tokenizer in {:.2f}s".format(time.time() - start))
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
model = get_model(args)
|
|
||||||
logger.info("load model in {:.2f}s".format(time.time() - start))
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
optimizer = get_optimizer(args, model)
|
|
||||||
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
|
|
||||||
bmt.synchronize()
|
|
||||||
logger.info("load lr_scheduler in {:.2f}s".format(time.time() - start))
|
|
||||||
|
|
||||||
return tokenizer, model, optimizer, lr_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
def resume_training(args):
|
|
||||||
ckpts = sorted(
|
|
||||||
[z for z in chain(*[[os.path.join(x[0], y) for y in x[2]] for x in os.walk(args.save)]) if z.endswith(".pt")],
|
|
||||||
reverse=True,
|
|
||||||
key=lambda x: (int)(re.search("(\d+).pt", x)[1]),
|
|
||||||
)
|
|
||||||
# find newest job
|
|
||||||
ckpts = sorted(
|
|
||||||
ckpts,
|
|
||||||
reverse=True,
|
|
||||||
key=lambda x: (int)(re.search("job_(\d+)_ckpt", x)[1]),
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(ckpts) > 0:
|
|
||||||
bmt.print_rank(f"resuming with last checkpoint: {ckpts[0]}")
|
|
||||||
args.load = ckpts[0]
|
|
||||||
# by default, do not load grad file
|
|
||||||
args.load_grad = False
|
|
||||||
args.start_step = 0
|
|
||||||
else:
|
|
||||||
# no ckpts, nothing we can do
|
|
||||||
os._exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def initialize():
|
|
||||||
args = get_args(pretrain=True)
|
|
||||||
bmt.init_distributed(seed=args.seed, tp_size=args.tp_size)
|
|
||||||
|
|
||||||
if args.save is not None:
|
|
||||||
os.makedirs(args.save, exist_ok=True)
|
|
||||||
if args.load is not None:
|
|
||||||
if args.only_load_model == 0:
|
|
||||||
if args.start_step == 0:
|
|
||||||
log_ckpt = exporter.load_log_ckpt(args)
|
|
||||||
if "iteration" in log_ckpt:
|
|
||||||
args.start_step = log_ckpt["iteration"]
|
|
||||||
else:
|
|
||||||
args.start_step = (int)(re.findall("(\d+)", args.load)[-1])
|
|
||||||
logger.info("Start from step {}".format(args.start_step))
|
|
||||||
elif args.only_load_model == 1:
|
|
||||||
logger.info("You load model ckpt, and choose to completely start the 0 step.")
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
else:
|
|
||||||
logger.info("You do not load model")
|
|
||||||
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
def see_memory(detail=False):
|
|
||||||
if detail:
|
|
||||||
res = torch.cuda.memory_summary()
|
|
||||||
else:
|
|
||||||
res = (
|
|
||||||
round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2),
|
|
||||||
round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024), 2),
|
|
||||||
round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024), 2),
|
|
||||||
)
|
|
||||||
torch.cuda.reset_peak_memory_stats()
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def add_mem_time(info, mem_usage, tim_usage):
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
bmt.synchronize()
|
|
||||||
mem_usage[info] = see_memory()
|
|
||||||
tim_usage[info] = time.time()
|
|
||||||
return mem_usage, tim_usage
|
|
||||||
|
|
||||||
|
|
||||||
def get_task_loss_and_token(loss, task_ids, task_num, targets):
|
|
||||||
# task_ids 可能有-1 来代表无效token
|
|
||||||
_task_num = task_num + 1
|
|
||||||
_task_ids = (task_ids.clone() + 1).to(torch.int64) # [batch_size, seq_len]
|
|
||||||
# gen masks
|
|
||||||
_task_mask = torch.zeros((_task_num, *_task_ids.shape), device=_task_ids.device)
|
|
||||||
_task_mask.scatter_(0, _task_ids.unsqueeze(0), 1) # [task_num, batch_size, seq_len]
|
|
||||||
_loss_mask = torch.ne(targets, -100).to(torch.int32)
|
|
||||||
_mask = _task_mask * _loss_mask.unsqueeze(0) # [task_num, batch_size, seq_len]
|
|
||||||
# calc loss and tokens
|
|
||||||
_task_losses = (loss.unsqueeze(0) * _mask).view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
|
|
||||||
_task_tokens = _mask.view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
|
|
||||||
# return token-wise avg losses and tokens
|
|
||||||
return torch.nan_to_num(_task_losses / _task_tokens, nan=0.0), _task_tokens
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkAve:
|
|
||||||
def __init__(self, chunk_size=100):
|
|
||||||
self.ave_list = []
|
|
||||||
self.chunk_size = chunk_size
|
|
||||||
|
|
||||||
def record(self, time):
|
|
||||||
self.ave_list.append(time)
|
|
||||||
self.ave_list = self.ave_list[-self.chunk_size :]
|
|
||||||
|
|
||||||
def get(self):
|
|
||||||
return sum(self.ave_list) / len(self.ave_list)
|
|
||||||
|
|
||||||
|
|
||||||
def pretrain(
|
|
||||||
args,
|
|
||||||
tokenizer,
|
|
||||||
model: Dragonfly,
|
|
||||||
optimizer,
|
|
||||||
lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
|
|
||||||
):
|
|
||||||
ave_model_time = ChunkAve(chunk_size=100)
|
|
||||||
ave_iter_time = ChunkAve(chunk_size=100)
|
|
||||||
|
|
||||||
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, reduction="none")
|
|
||||||
optim_manager = bmt.optim.OptimManager(
|
|
||||||
loss_scale=bmt.world_size(),
|
|
||||||
loss_scale_steps=args.loss_scale_steps,
|
|
||||||
loss_scale_factor=2,
|
|
||||||
max_loss_scale=bmt.world_size(),
|
|
||||||
min_loss_scale=bmt.world_size(),
|
|
||||||
)
|
|
||||||
optim_manager.add_optimizer(optimizer, lr_scheduler)
|
|
||||||
|
|
||||||
start_step = args.start_step
|
|
||||||
|
|
||||||
if args.tensorboard is not None and bmt.rank() == 0:
|
|
||||||
import distutils.version # noqa: F401
|
|
||||||
|
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
|
|
||||||
if not os.path.exists(args.tensorboard):
|
|
||||||
os.makedirs(args.tensorboard)
|
|
||||||
writer = SummaryWriter(log_dir=args.tensorboard)
|
|
||||||
|
|
||||||
if args.load is not None:
|
|
||||||
log_ckpt = exporter.load_log_ckpt(args)
|
|
||||||
else:
|
|
||||||
log_ckpt = {}
|
|
||||||
|
|
||||||
global_token_pass = log_ckpt.get("global_token_pass", 0.0)
|
|
||||||
global_total_task_token = defaultdict(int, log_ckpt.get("global_total_task_token", {})) # token by task
|
|
||||||
|
|
||||||
global_world_size = bmt.world_size()
|
|
||||||
if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
|
|
||||||
mixed_indexed_dataset = MixedIndexedDataset(
|
|
||||||
cfg_path=args.dataset,
|
|
||||||
cfg_json_str=None,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
max_length=args.max_length,
|
|
||||||
nthreads=args.dataloader_num_threads,
|
|
||||||
prefetch_slice=args.dataloader_prefetch,
|
|
||||||
weight_by_size=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.load is not None and args.only_load_model == 0 and args.load_dataloader_ckpt == 1:
|
|
||||||
exporter.load_dataloader_ckpt(args, mixed_indexed_dataset)
|
|
||||||
|
|
||||||
batched_dataset = UnpadBatchedMixedDataset(mixed_indexed_dataset, args.batch_size, args.max_length)
|
|
||||||
dataloader = torch.utils.data.DataLoader(
|
|
||||||
batched_dataset,
|
|
||||||
batch_size=None,
|
|
||||||
collate_fn=lambda x: x,
|
|
||||||
num_workers=args.dataloader_num_workers,
|
|
||||||
prefetch_factor=args.dataloader_prefetch_factor,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
|
|
||||||
def dummy_generator():
|
|
||||||
while True:
|
|
||||||
yield None
|
|
||||||
|
|
||||||
mixed_indexed_dataset = dummy_generator()
|
|
||||||
dataloader = mixed_indexed_dataset
|
|
||||||
|
|
||||||
DataIterator = CudaPrefetcher(dataloader, tp_size=args.tp_size, tp_rank=bmt.config["tp_rank"])
|
|
||||||
|
|
||||||
bmt.print_rank("Preparing dataset done.")
|
|
||||||
|
|
||||||
# inspect at init
|
|
||||||
model_inspect = bmt.inspect.inspect_model(model, "*")
|
|
||||||
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
|
||||||
|
|
||||||
try:
|
|
||||||
mem_usage, tim_usage = {}, {}
|
|
||||||
mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)
|
|
||||||
|
|
||||||
for iteration, data in enumerate(DataIterator, start=start_step + 1):
|
|
||||||
if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
|
|
||||||
mixed_indexed_dataset.update_states(data["task_ids"], data["indexes"])
|
|
||||||
|
|
||||||
mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)
|
|
||||||
|
|
||||||
logits = model(
|
|
||||||
input=data["inputs"],
|
|
||||||
cu_seqlens=data["cu_seqlens"],
|
|
||||||
max_seqlen=data["max_seqlen"],
|
|
||||||
position_ids=data["position_ids"],
|
|
||||||
)
|
|
||||||
#print("logits: ", logits)
|
|
||||||
|
|
||||||
# chunk targets and task_ids
|
|
||||||
data["targets"] = (
|
|
||||||
data["targets"]
|
|
||||||
.view(-1)
|
|
||||||
.chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
|
|
||||||
.view(data["targets"].shape[0], -1)
|
|
||||||
)
|
|
||||||
data["task_ids"] = (
|
|
||||||
data["task_ids"]
|
|
||||||
.view(-1)
|
|
||||||
.chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
|
|
||||||
.view(data["task_ids"].shape[0], -1)
|
|
||||||
)
|
|
||||||
|
|
||||||
_target = data["targets"].view(-1)
|
|
||||||
non_reduced_loss = loss_func(logits.view(-1, logits.size(-1)), _target)
|
|
||||||
_w = (_target != -100).int()
|
|
||||||
loss = non_reduced_loss.sum() / _w.sum().float()
|
|
||||||
|
|
||||||
global_loss = bmt.sum_loss(loss).item()
|
|
||||||
mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)
|
|
||||||
|
|
||||||
optim_manager.backward(loss)
|
|
||||||
mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)
|
|
||||||
|
|
||||||
if iteration % args.grad_accum == 0 or iteration == args.train_iters:
|
|
||||||
grad_accum_init_time = tim_usage["init"]
|
|
||||||
|
|
||||||
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
|
|
||||||
optim_manager.step()
|
|
||||||
optim_manager.zero_grad()
|
|
||||||
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
|
|
||||||
model_time = tim_usage["optim"] - grad_accum_init_time
|
|
||||||
ave_model_time.record(model_time)
|
|
||||||
else:
|
|
||||||
# dummy optim step
|
|
||||||
grad_norm = torch.Tensor([0.0]).cuda()
|
|
||||||
tim_usage["optim"] = tim_usage["backward"]
|
|
||||||
mem_usage["optim"] = mem_usage["backward"]
|
|
||||||
model_time = tim_usage["optim"] - tim_usage['init']
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
task_num = len(data["task_names"])
|
|
||||||
task_loss, task_token = get_task_loss_and_token(
|
|
||||||
non_reduced_loss, data["task_ids"], task_num, data["targets"]
|
|
||||||
)
|
|
||||||
task_loss_map: Dict[str, float] = {}
|
|
||||||
gatherd_task_loss_map = bmt.distributed.all_gather(task_loss)
|
|
||||||
gatherd_task_token_map = bmt.distributed.all_gather(task_token)
|
|
||||||
gatherd_task_loss_token_map = gatherd_task_loss_map * gatherd_task_token_map
|
|
||||||
sum_task_loss = gatherd_task_loss_token_map.sum(dim=0)
|
|
||||||
tot_task_token = gatherd_task_token_map.sum(dim=0)
|
|
||||||
ave_task_loss = sum_task_loss / tot_task_token
|
|
||||||
for i in range(task_num):
|
|
||||||
task_loss_map[data["task_names"][i]] = ave_task_loss[i].item()
|
|
||||||
global_total_task_token[data["task_names"][i]] += tot_task_token[i].item()
|
|
||||||
|
|
||||||
local_total_rate = torch.Tensor(
|
|
||||||
[data["lengths"].float().mean() / (args.max_length * args.batch_size)]
|
|
||||||
).cuda()
|
|
||||||
local_total_rate = bmt.sum_loss(local_total_rate).item()
|
|
||||||
global_token_pass += (
|
|
||||||
(global_world_size // args.tp_size) * local_total_rate * args.max_length * args.batch_size
|
|
||||||
)
|
|
||||||
|
|
||||||
bmt.print_rank(
|
|
||||||
"=========================================" + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
|
||||||
)
|
|
||||||
last_before_log_time = tim_usage["before_log"]
|
|
||||||
mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)
|
|
||||||
|
|
||||||
iter_time = tim_usage["before_log"] - last_before_log_time
|
|
||||||
|
|
||||||
ave_iter_time.record(iter_time)
|
|
||||||
|
|
||||||
train_info = {
|
|
||||||
"time": iter_time,
|
|
||||||
"iteration": iteration,
|
|
||||||
"loss": global_loss,
|
|
||||||
"lr": lr_scheduler.current_lr,
|
|
||||||
"token_max": local_total_rate,
|
|
||||||
"token_pass": global_token_pass,
|
|
||||||
"throughout": args.max_length * args.batch_size * local_total_rate / ave_iter_time.get() / args.tp_size,
|
|
||||||
"grad_norm": grad_norm.item(),
|
|
||||||
"mask_max": ((data["targets"] >= 0).sum(-1).float().mean() / args.max_length).item(),
|
|
||||||
"task_loss": task_loss_map,
|
|
||||||
"total_task_token": global_total_task_token,
|
|
||||||
}
|
|
||||||
global_token_pass_str = convert_to_k_and_b(global_token_pass)
|
|
||||||
|
|
||||||
time_report_str = "{model_time:.2f}={forward_time:.2f}+{backward_time:.2f}+{optim_time:.2f}".format(model_time=model_time, forward_time=tim_usage['forward']-tim_usage['init'], backward_time=tim_usage['backward']-tim_usage['forward'], optim_time=tim_usage['optim'] - tim_usage['backward'])
|
|
||||||
bmt.print_rank(
|
|
||||||
(
|
|
||||||
"| Iter: {iteration:6d} | loss: {loss:.4f} | lr: {lr:.4e} | model_time: {model_time} | iter_time: {iter_time:.2f}| chunk_ave_time: {chunk_ave_time:.2f}"
|
|
||||||
+ " token/max: {tokenrate:.4f} | mask/max: {maskrate:.4f} | grad_norm: {grad_norm:.4f} | global_token_pass (B):"
|
|
||||||
+ "{global_token_pass} | mem_usage {mem_usage} | "
|
|
||||||
).format(
|
|
||||||
iteration=iteration,
|
|
||||||
loss=global_loss,
|
|
||||||
lr=lr_scheduler.current_lr,
|
|
||||||
model_time=time_report_str,
|
|
||||||
iter_time=iter_time,
|
|
||||||
chunk_ave_time=ave_iter_time.get(),
|
|
||||||
tokenrate=data["lengths"].float().mean() / args.max_length / args.batch_size,
|
|
||||||
maskrate=(data["targets"] >= 0).sum(-1).float().mean() / args.max_length / args.batch_size,
|
|
||||||
grad_norm=grad_norm.item(),
|
|
||||||
global_token_pass=global_token_pass_str,
|
|
||||||
mem_usage=max([value for key, value in mem_usage.items()]),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
bmt.print_rank(
|
|
||||||
"task_loss:\t| "
|
|
||||||
+ " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
|
|
||||||
+ " |"
|
|
||||||
)
|
|
||||||
|
|
||||||
if iteration % 10 == 0:
|
|
||||||
bmt.print_rank(
|
|
||||||
"task_tokens (B):\t| "
|
|
||||||
+ " | ".join(
|
|
||||||
[
|
|
||||||
"{}: {:.4f}".format(task_name, task_token / 10**9)
|
|
||||||
for task_name, task_token in global_total_task_token.items()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
+ " |"
|
|
||||||
)
|
|
||||||
|
|
||||||
if iteration % args.inspect_iters == 0:
|
|
||||||
model_inspect = bmt.inspect.inspect_model(model, "*")
|
|
||||||
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
|
||||||
|
|
||||||
if args.log_dir is not None and bmt.rank() == 0:
|
|
||||||
if args.save is not None:
|
|
||||||
save_every_step_stats(train_info, args.save)
|
|
||||||
|
|
||||||
if args.tensorboard is not None and bmt.rank() == 0:
|
|
||||||
writer.add_scalar("Loss/train", global_loss, iteration)
|
|
||||||
writer.add_scalar("Optimizer/lr", lr_scheduler.current_lr, iteration)
|
|
||||||
writer.add_scalar("Optimizer/scale", optim_manager.loss_scale, iteration)
|
|
||||||
writer.add_scalar("Optimizer/grad_norm", grad_norm.item(), iteration)
|
|
||||||
for task_name, loss in task_loss_map.items():
|
|
||||||
if not math.isnan(loss):
|
|
||||||
writer.add_scalar("Loss/train/{}".format(task_name), loss, iteration)
|
|
||||||
|
|
||||||
# -------- save file. If need to backup by Klara platform, use export.xx_save --------
|
|
||||||
log_ckpt = {
|
|
||||||
"global_total_task_token": global_total_task_token,
|
|
||||||
"global_token_pass": global_token_pass,
|
|
||||||
"iteration": iteration,
|
|
||||||
}
|
|
||||||
|
|
||||||
if args.save is not None and iteration % args.save_iters == 0:
|
|
||||||
exporter.export(
|
|
||||||
model,
|
|
||||||
mixed_indexed_dataset,
|
|
||||||
tokenizer,
|
|
||||||
optimizer,
|
|
||||||
iteration,
|
|
||||||
args,
|
|
||||||
log_ckpt=log_ckpt,
|
|
||||||
final_save=False,
|
|
||||||
async_save=args.async_save,
|
|
||||||
)
|
|
||||||
|
|
||||||
if iteration == args.train_iters and args.stop_when_end == 1:
|
|
||||||
break
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"train loop err: {e}")
|
|
||||||
raise e
|
|
||||||
finally:
|
|
||||||
pass
|
|
||||||
|
|
||||||
exporter.export(model, mixed_indexed_dataset, tokenizer, optimizer, -1, args, final_save=False)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_to_k_and_b(number):
|
|
||||||
if number >= 1e9: # 大于或等于10亿
|
|
||||||
b_number = number / 1e9
|
|
||||||
return f"{b_number:.2f}B"
|
|
||||||
elif number >= 1e6: # 大于或等于1百万
|
|
||||||
k_number = number / 1e6
|
|
||||||
return f"{k_number:.2f}M"
|
|
||||||
elif number >= 1e3:
|
|
||||||
k_number = number / 1e3
|
|
||||||
return f"{k_number:.2f}K"
|
|
||||||
else:
|
|
||||||
return str(number)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
args = initialize()
|
|
||||||
bmt.synchronize()
|
|
||||||
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
|
|
||||||
bmt.print_rank("finish loading")
|
|
||||||
bmt.print_rank(
|
|
||||||
"Number of parameter {}, Number of non-e parameter {}".format(
|
|
||||||
num_parameters(model), num_non_embedding_parameters(model)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
bmt.print_rank("args: {}".format(args))
|
|
||||||
|
|
||||||
print("begining training")
|
|
||||||
pretrain(args, tokenizer, model, optimizer, lr_scheduler)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
|
@ -1,234 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
#export OMP_NUM_THREADS=16
|
|
||||||
|
|
||||||
declare -A args # Declare an associative array to store arguments and values
|
|
||||||
|
|
||||||
args["model_unique"]="8b_0702"
|
|
||||||
args["resume_ckpt"]=""
|
|
||||||
args["config"]="8b"
|
|
||||||
args["flash"]="cuda"
|
|
||||||
args["batch_size"]="1"
|
|
||||||
args["max_length"]="4096"
|
|
||||||
args["save_iters"]="500"
|
|
||||||
args["train_iters"]="10"
|
|
||||||
args["dataset_config"]="fm9g_sft"
|
|
||||||
args["local"]="False"
|
|
||||||
args["dataloader"]="indexed"
|
|
||||||
args["save"]="True"
|
|
||||||
args["dataloader_num_threads"]=1
|
|
||||||
args["dataloader_prefetch"]=2
|
|
||||||
args["dataloader_prefetch_factor"]=32
|
|
||||||
args["dataloader_num_workers"]=2
|
|
||||||
args["lr"]="1e-5"
|
|
||||||
args["warmup_iters"]="20"
|
|
||||||
args["drop_iters"]="0.1"
|
|
||||||
args["tokenizer_path"]="./tokenizer/tokenizer.model" # /user/tc_agi/klara/baichuan2/baichuan2.tokenizer.model
|
|
||||||
args["load_grad"]="False"
|
|
||||||
args["grad_ckpt_num"]="160"
|
|
||||||
args["exp_group"]=""
|
|
||||||
args["ignore_cuda_oom"]="1"
|
|
||||||
args["tensorboard_all_tasks"]="0"
|
|
||||||
args["stop_when_end"]="0"
|
|
||||||
args["only_run_dataloader"]="0"
|
|
||||||
args["eps"]="1e-6"
|
|
||||||
args["inspect_iters"]="100"
|
|
||||||
args["strict_state_dict"]="1"
|
|
||||||
args["only_load_model"]="1"
|
|
||||||
args["lr_scheduler"]="cosine"
|
|
||||||
args["resume_no_optimze"]="0"
|
|
||||||
args["tp_size"]="1"
|
|
||||||
args["parallel_load_datastate"]="16"
|
|
||||||
args["async_save"]="False"
|
|
||||||
args["load_dataloader_ckpt"]="0"
|
|
||||||
args["drop_begin"]="-1"
|
|
||||||
args["drop_rate"]="0.5"
|
|
||||||
args["use_checkpoint"]="1"
|
|
||||||
|
|
||||||
|
|
||||||
# Loop through the arguments
|
|
||||||
for ((i=1; i<=$#; i++)); do
|
|
||||||
arg="${!i}"
|
|
||||||
# Check if the argument starts with "--"
|
|
||||||
if [[ "$arg" == --* ]]; then
|
|
||||||
arg_name="${arg:2}" # Remove leading "--"
|
|
||||||
valueid=$((i+1))
|
|
||||||
# Get the value of the argument if it exists
|
|
||||||
if ((i+1 <= $#)); then
|
|
||||||
args["$arg_name"]="${!valueid}"
|
|
||||||
i=$((i+1)) # Skip the next argument (its value)
|
|
||||||
else
|
|
||||||
args["$arg_name"]="" # Set empty value if no value provided
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
|
|
||||||
# 使用 Python 读取 JSON 文件并更新 Bash 字典
|
|
||||||
while read -r key value; do
|
|
||||||
args["$key"]="$value"
|
|
||||||
done < <(python -c 'import json, sys; obj = json.load(open("train_configs/'${args['config']}'.json"))["pretrain"]; print("\n".join(["{} {}".format(k, v) for k, v in obj.items()]))')
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 用cmd arg 再更新一次
|
|
||||||
# Loop through the arguments
|
|
||||||
for ((i=1; i<=$#; i++)); do
|
|
||||||
arg="${!i}"
|
|
||||||
# Check if the argument starts with "--"
|
|
||||||
if [[ "$arg" == --* ]]; then
|
|
||||||
arg_name="${arg:2}" # Remove leading "--"
|
|
||||||
valueid=$((i+1))
|
|
||||||
|
|
||||||
# Get the value of the argument if it exists
|
|
||||||
if ((i+1 <= $#)); then
|
|
||||||
args["$arg_name"]="${!valueid}"
|
|
||||||
i=$((i+1)) # Skip the next argument (its value)
|
|
||||||
else
|
|
||||||
args["$arg_name"]="" # Set empty value if no value provided
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
|
|
||||||
# Print the values of the arguments
|
|
||||||
echo "----------- CMD args ----------"
|
|
||||||
for key in "${!args[@]}"; do
|
|
||||||
echo "$key: ${args[$key]}"
|
|
||||||
done
|
|
||||||
echo "--------- END CMD args --------"
|
|
||||||
|
|
||||||
|
|
||||||
if [[ ${args["flash"]} == "triton" ]]; then
|
|
||||||
sudo cp /usr/local/cuda-11.6/compat/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03
|
|
||||||
sudo ln /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so
|
|
||||||
echo "triton flash"
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
GPUS_PER_NODE=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader | wc -l)
|
|
||||||
# GPUS_PER_NODE=1
|
|
||||||
echo "Using ${GPUS_PER_NODE} GPU each machine"
|
|
||||||
|
|
||||||
|
|
||||||
if [[ ${args["model_unique"]} == "" ]]; then
|
|
||||||
MODEL_UNIQUE=${JEEVES_JOB_ID} # 写入的位置,没传的话自动构造
|
|
||||||
# JOBID+CreateTime, 本次run的唯一标识符。在白箱里可以通过/projects/${PROJECTID}-${PROJECTNAME}/checkpoints/${MODEL_UNIQUE} 拿到 checkpoint
|
|
||||||
# 通过/projects/${PROJECTID}-${PROJECTNAME}/tensorboard/${MODEL_UNIQUE} 拿到 tensorboard
|
|
||||||
else
|
|
||||||
MODEL_UNIQUE=${args["model_unique"]} # 给了写入的位置
|
|
||||||
fi
|
|
||||||
echo "model_unique: "$MODEL_UNIQUE
|
|
||||||
|
|
||||||
# --------------- 运行参数 ---------------
|
|
||||||
|
|
||||||
OPTS+=" --model-config model_configs/"${args['config']}".json" # [CHANGE]
|
|
||||||
OPTS+=" --batch-size ${args["batch_size"]}"
|
|
||||||
OPTS+=" --train-iters ${args["train_iters"]}"
|
|
||||||
OPTS+=" --save-iters ${args["save_iters"]}"
|
|
||||||
OPTS+=" --save-name fm9g_live_checkpoint"
|
|
||||||
OPTS+=" --max-length ${args["max_length"]}"
|
|
||||||
OPTS+=" --lr ${args["lr"]}"
|
|
||||||
OPTS+=" --inspect-iters ${args["inspect_iters"]}"
|
|
||||||
OPTS+=" --warmup-iters ${args["warmup_iters"]}"
|
|
||||||
OPTS+=" --drop-iters ${args["drop_iters"]}"
|
|
||||||
OPTS+=" --lr_scheduler ${args["lr_scheduler"]}"
|
|
||||||
OPTS+=" --offload"
|
|
||||||
OPTS+=" --vocab ./tokenizer/vocab.txt"
|
|
||||||
OPTS+=" --flash ${args["flash"]}"
|
|
||||||
OPTS+=" --tensorboard_all_tasks ${args["tensorboard_all_tasks"]}"
|
|
||||||
OPTS+=" --ignore_cuda_oom ${args["ignore_cuda_oom"]}"
|
|
||||||
OPTS+=" --stop_when_end ${args["stop_when_end"]}"
|
|
||||||
OPTS+=" --only_run_dataloader ${args["only_run_dataloader"]}"
|
|
||||||
OPTS+=" --eps ${args["eps"]}"
|
|
||||||
OPTS+=" --strict_state_dict ${args["strict_state_dict"]}"
|
|
||||||
OPTS+=" --only_load_model ${args["only_load_model"]}"
|
|
||||||
OPTS+=" --resume_no_optimze ${args["resume_no_optimze"]}"
|
|
||||||
OPTS+=" --tokenizer_path ${args["tokenizer_path"]}"
|
|
||||||
OPTS+=" --weight-decay 0.1"
|
|
||||||
OPTS+=" --tp-size ${args["tp_size"]}"
|
|
||||||
OPTS+=" --parallel_load_datastate ${args["parallel_load_datastate"]}"
|
|
||||||
OPTS+=" --load_dataloader_ckpt ${args["load_dataloader_ckpt"]}"
|
|
||||||
OPTS+=" --drop_begin ${args["drop_begin"]}"
|
|
||||||
OPTS+=" --drop_rate ${args["drop_rate"]}"
|
|
||||||
OPTS+=" --use_checkpoint ${args["use_checkpoint"]}"
|
|
||||||
|
|
||||||
if [[ ${args["load_grad"]} == "True" ]]; then
|
|
||||||
OPTS+=" --load-grad"
|
|
||||||
OPTS+=" --grad-ckpt-num ${args["grad_ckpt_num"]}"
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
if [[ ${args["async_save"]} == "True" ]]; then
|
|
||||||
OPTS+=" --async_save"
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
if [[ ${args["dataloader"]} == "indexed" ]]; then
|
|
||||||
OPTS+=" --dataloader_num_threads ${args["dataloader_num_threads"]}"
|
|
||||||
OPTS+=" --dataloader_prefetch ${args["dataloader_prefetch"]}"
|
|
||||||
OPTS+=" --dataloader_num_workers ${args["dataloader_num_workers"]}"
|
|
||||||
OPTS+=" --dataloader_prefetch_factor ${args["dataloader_prefetch_factor"]}"
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
# --------------- 写文件路径 ---------------
|
|
||||||
## checkpoint
|
|
||||||
if [[ ${args["save"]} == "True" ]]; then
|
|
||||||
|
|
||||||
OPTS+=" --save ./data/checkpoints/${MODEL_UNIQUE}/"
|
|
||||||
OPTS+=" --save-model ./not_exist/${MODEL_UNIQUE}/"
|
|
||||||
else
|
|
||||||
echo "won't save model"
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
## logs,/local/logs 等价于 ./datalogs(软链)
|
|
||||||
mkdir -p ./data/checkpoints/logs/${MODEL_UNIQUE}
|
|
||||||
OPTS+=" --log-dir ./data/checkpoints/logs/${MODEL_UNIQUE}"
|
|
||||||
OPTS+=" --tensorboard ./data/tensorboard/${args["exp_group"]}${MODEL_UNIQUE}/"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if [[ ${args["local"]} == "True" ]]; then
|
|
||||||
current_dir=$(pwd)
|
|
||||||
OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
|
|
||||||
else
|
|
||||||
current_dir=$(pwd)
|
|
||||||
OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
|
|
||||||
echo "Platform config:"${PLATFORM_CONFIG_PATH}
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
## checkpoint,兼容 CHECKPOINT 和 LATEST_CHECKPOINT。debug 时建议不加载 checkpoint,启动会比较快
|
|
||||||
if [ "${args["resume_ckpt"]}" != "" ]; then
|
|
||||||
OPTS+=" --load ./data/checkpoints/${MODEL_UNIQUE}/${args["resume_ckpt"]}"
|
|
||||||
else
|
|
||||||
echo "No checkpoint to load"
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
filename="pretrain_dragonfly"
|
|
||||||
|
|
||||||
if [[ ${args["local"]} == "True" ]]; then
|
|
||||||
PRETRAIN_ENTRY="$filename.py"
|
|
||||||
else
|
|
||||||
PRETRAIN_ENTRY="$filename.py"
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
GPUS_PER_NODE=8
|
|
||||||
NNODES=1
|
|
||||||
RANK=0
|
|
||||||
MASTER_ENDPOINT=g3006
|
|
||||||
MASTER_PORT=12345
|
|
||||||
#CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --master_addr=${MASTER_ENDPOINT} --master_port=${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
|
|
||||||
CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ENDPOINT}:${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
|
|
||||||
|
|
||||||
echo "-------final CMD is------"
|
|
||||||
echo "${CMD}"
|
|
||||||
echo "-------final CMD end------"
|
|
||||||
|
|
||||||
$CMD
|
|
|
@ -1,9 +0,0 @@
|
||||||
{
|
|
||||||
"pretrain": {
|
|
||||||
"train_iters": 20000,
|
|
||||||
"batch_size": 1,
|
|
||||||
"max_length": 4096,
|
|
||||||
"n_gpus": 8,
|
|
||||||
"lr": 1e-5
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,7 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
#
|
|
||||||
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
|
||||||
#
|
|
||||||
# @author: ouzebin <ouzebin@zhihu.com>
|
|
||||||
# @date: 2023/08/07
|
|
|
@ -1,8 +0,0 @@
|
||||||
{
|
|
||||||
"folders": [
|
|
||||||
{
|
|
||||||
"path": "../.."
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"settings": {}
|
|
||||||
}
|
|
|
@ -1,105 +0,0 @@
|
||||||
import torch
|
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
|
||||||
|
|
||||||
|
|
||||||
class DragonflyConfig(PretrainedConfig):
|
|
||||||
model_type = "fm9g_dragonfly"
|
|
||||||
keys_to_ignore_at_inference = ["past_key_values"]
|
|
||||||
attribute_map = {
|
|
||||||
"num_key_value_heads": "num_kv_heads",
|
|
||||||
"hidden_act": "activate_fn",
|
|
||||||
"hidden_size": "dim_model",
|
|
||||||
"num_attention_heads": "num_heads",
|
|
||||||
"intermediate_size": "dim_ff",
|
|
||||||
"num_hidden_layers": "num_layers",
|
|
||||||
"vocab_size": "vocab_size",
|
|
||||||
"rms_norm_eps": "eps",
|
|
||||||
"scale_emb": "scale_emb",
|
|
||||||
"scale_depth": "scale_depth",
|
|
||||||
"scale": "scale",
|
|
||||||
"attention_scale": "attention_scale",
|
|
||||||
"qk_norm": "qk_norm",
|
|
||||||
"ffn_gated": "ffn_gated",
|
|
||||||
} # model specific to common
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vocab_size=122753, # TODO: do we need to change to 122880 = 960 * 128?
|
|
||||||
dim_model=4096,
|
|
||||||
num_heads=32,
|
|
||||||
num_kv_heads=32,
|
|
||||||
dim_head=128,
|
|
||||||
dim_ff=11008,
|
|
||||||
num_layers=32,
|
|
||||||
dropout_p=0.0,
|
|
||||||
activate_fn="silu",
|
|
||||||
scale=False,
|
|
||||||
scale_emb: float = 1.0,
|
|
||||||
scale_depth: float = -1,
|
|
||||||
dim_model_base: int = 256,
|
|
||||||
eps=1e-5,
|
|
||||||
init_std=0.02,
|
|
||||||
dtype="bf16",
|
|
||||||
base=10000,
|
|
||||||
qk_norm=False,
|
|
||||||
tie_lm_head=False,
|
|
||||||
max_length=8192,
|
|
||||||
pose_prob=0.0,
|
|
||||||
pose_scaling_factor=1,
|
|
||||||
rope_scaling_type="",
|
|
||||||
rope_scaling_factor=1,
|
|
||||||
orig_max_length=8192,
|
|
||||||
tp=0,
|
|
||||||
use_checkpoint=True,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
self.vocab_size = vocab_size
|
|
||||||
self.dim_model = dim_model
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.num_kv_heads = num_kv_heads
|
|
||||||
self.dim_head = dim_head
|
|
||||||
self.dim_ff = dim_ff
|
|
||||||
self.num_layers = num_layers
|
|
||||||
self.dropout_p = dropout_p
|
|
||||||
self.activate_fn = activate_fn
|
|
||||||
self.scale = scale
|
|
||||||
self.scale_emb = scale_emb
|
|
||||||
self._dtype = dtype
|
|
||||||
self.dim_model_base = dim_model_base
|
|
||||||
self.scale_depth = scale_depth
|
|
||||||
self.eps = eps
|
|
||||||
self.init_std = init_std
|
|
||||||
self.base = base
|
|
||||||
self.qk_norm = qk_norm
|
|
||||||
self.tie_lm_head = tie_lm_head
|
|
||||||
self.use_bfloat16 = True if self._dtype == "bf16" else False
|
|
||||||
self.pose_prob = pose_prob
|
|
||||||
self.pose_scaling_factor = pose_scaling_factor
|
|
||||||
self.rope_scaling_type = rope_scaling_type
|
|
||||||
self.rope_scaling_factor = rope_scaling_factor
|
|
||||||
self.max_length = max_length
|
|
||||||
self.orig_max_length = orig_max_length
|
|
||||||
self.use_checkpoint = use_checkpoint
|
|
||||||
print("use_checkpoint", self.use_checkpoint)
|
|
||||||
self.tp = tp
|
|
||||||
super().__init__(architectures=["fm9gDragonflyForCausalLM"])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def scale_width(
|
|
||||||
self,
|
|
||||||
):
|
|
||||||
if self.scale:
|
|
||||||
return self.dim_model / self.dim_model_base
|
|
||||||
else:
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dtype(
|
|
||||||
self,
|
|
||||||
): # -> Any | None:
|
|
||||||
if self._dtype == "bf16":
|
|
||||||
return torch.bfloat16
|
|
||||||
elif self._dtype == "fp16":
|
|
||||||
return torch.half
|
|
||||||
elif self._dtype == "float32":
|
|
||||||
return torch.float
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1 +0,0 @@
|
||||||
from .pretrain_indexed import MixedIndexedDataset
|
|
|
@ -1,74 +0,0 @@
|
||||||
import logging
|
|
||||||
from multiprocessing import Lock
|
|
||||||
|
|
||||||
from flask import Flask
|
|
||||||
from flask import jsonify
|
|
||||||
from flask import request
|
|
||||||
|
|
||||||
app = Flask(__name__)
|
|
||||||
|
|
||||||
# 获取 Werkzeug 日志记录器并设置日志级别
|
|
||||||
log = logging.getLogger("werkzeug")
|
|
||||||
log.setLevel(logging.WARNING)
|
|
||||||
|
|
||||||
|
|
||||||
class GlobalAvgTokensStat(object):
|
|
||||||
def __init__(self, decay_factor: float = 0.98):
|
|
||||||
self._avg_tokens = {}
|
|
||||||
self.decay_factor = decay_factor
|
|
||||||
self.lock = Lock()
|
|
||||||
self.task_locks = {}
|
|
||||||
|
|
||||||
def set_avg_tokens(self, task_name, avg_tokens):
|
|
||||||
self._register_task_lock_helper(task_name)
|
|
||||||
with self.task_locks[task_name]:
|
|
||||||
self._avg_tokens[task_name] = avg_tokens
|
|
||||||
|
|
||||||
def update_avg_tokens_by_ema(self, task_name, length):
|
|
||||||
self._register_task_lock_helper(task_name)
|
|
||||||
with self.task_locks[task_name]:
|
|
||||||
if task_name in self._avg_tokens and self._avg_tokens[task_name] > 0:
|
|
||||||
self._avg_tokens[task_name] = self._avg_tokens[task_name] * self.decay_factor + length * (
|
|
||||||
1 - self.decay_factor
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self._avg_tokens[task_name] = length
|
|
||||||
|
|
||||||
def get_avg_tokens(self, task_name):
|
|
||||||
self._register_task_lock_helper(task_name)
|
|
||||||
with self.task_locks[task_name]:
|
|
||||||
return self._avg_tokens.get(task_name, -1)
|
|
||||||
|
|
||||||
def _register_task_lock_helper(self, task_name):
|
|
||||||
if task_name not in self.task_locks:
|
|
||||||
with self.lock:
|
|
||||||
if task_name not in self.task_locks:
|
|
||||||
self.task_locks[task_name] = Lock()
|
|
||||||
|
|
||||||
|
|
||||||
global_avg_tokens_stat = GlobalAvgTokensStat()
|
|
||||||
|
|
||||||
|
|
||||||
@app.route("/avg_tokens/<path:task_name>", methods=["GET"])
|
|
||||||
def get_avg_tokens(task_name):
|
|
||||||
global global_avg_tokens_stat
|
|
||||||
avg_tokens = global_avg_tokens_stat.get_avg_tokens(task_name)
|
|
||||||
return jsonify({"avg_tokens": avg_tokens})
|
|
||||||
|
|
||||||
|
|
||||||
@app.route("/avg_tokens/<path:task_name>", methods=["POST"])
|
|
||||||
def set_avg_tokens(task_name):
|
|
||||||
global global_avg_tokens_stat
|
|
||||||
action = request.args.get("action", "update", type=str)
|
|
||||||
length = request.args.get("length", -1, type=int)
|
|
||||||
if action == "set":
|
|
||||||
global_avg_tokens_stat.set_avg_tokens(task_name, length)
|
|
||||||
elif action == "update":
|
|
||||||
global_avg_tokens_stat.update_avg_tokens_by_ema(task_name, length)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown action: {action}")
|
|
||||||
return jsonify({"status": "ok"})
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
app.run(port=5000, debug=True)
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue