add fm9g 2b and 8b models

This commit is contained in:
anrongqiao 2024-07-15 14:27:10 +08:00
parent 03c55e1fee
commit cfd2fca57c
125 changed files with 302309 additions and 245416 deletions

View File

@ -1,16 +0,0 @@
{
"vocab_size": 119696,
"dropout_p": 0.0,
"eps": 1e-05,
"half": true,
"use_flash_attn": true,
"flash_attn_mask_shape": "2d",
"dim_model": 4096,
"dim_ff": 12288,
"dim_head": 128,
"num_heads": 32,
"num_kv_heads": 32,
"num_layers": 48,
"activate_fn": "silu",
"scale": false
}

View File

@ -1,14 +0,0 @@
{
"vocab_size": 119696,
"dropout_p": 0.0,
"eps": 1e-05,
"half": true,
"dim_model": 4096,
"dim_ff": 11008,
"dim_head": 128,
"num_heads": 32,
"num_kv_heads": 32,
"num_layers": 32,
"activate_fn": "silu",
"scale": false
}

File diff suppressed because it is too large Load Diff

View File

@ -1,12 +0,0 @@
[
{
"dataset_name": "wikipedia",
"task_name": "wikipedia",
"weight": 1.0,
"path": "path/to/data",
"incontext_weight": [
1.0
],
"transforms": "/home/USERNAME/cpm9g/apps/cpm9g/config/datasets/wikipedia/script_cpmc.py"
}
]

View File

@ -1,9 +0,0 @@
import random
def rand(n: int, r: random.Random):
return int(r.random() * n)
def transform(data, num_sample: int, r: random.Random):
return {"input": "", "output": data["text"]}

View File

@ -1,485 +0,0 @@
import inspect
import json
import math
import os
import re
import sys
import time
from collections import defaultdict
from typing import Any
from typing import Dict
from typing import List
from typing import Union
import bmtrain as bmt
import torch
sys.path.insert(0, "/home/wangshuo1/code/9G-Train")
from cpm.arguments import get_args
from cpm.cpm9g.models import CPM9G
from cpm.cpm9g.models import CPM9GConfig
from cpm.cpm9g.tokenizers import CPM9GTokenizer
from cpm.cpm9g.training_tasks import MixedDataset
from cpm.utils import allgather_objects
from cpm.utils import exporter
from cpm.utils import logger
from cpm.utils import LogManager
def get_tokenizer(args):
tokenizer = CPM9GTokenizer(path=args.vocab)
return tokenizer
def get_model(args):
config = CPM9GConfig.from_json_file(args.model_config)
config.tp = 1 if args.tp != 1 else 0
if args.flash == "none":
config.use_flash_attn = False
else:
config.use_flash_attn = True
if args.flash == "1d":
config.flash_attn_mask_shape = "1d"
else:
config.flash_attn_mask_shape = "2d"
if args.flash == "triton":
config.flash_impl = "triton"
elif args.flash == "cuda":
config.flash_impl = "cuda"
model = CPM9G(config)
if args.load is not None:
bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
bmt.load(model, args.load)
else:
bmt.print_rank("args.load is None, start to initialize parameters")
bmt.init_parameters(model)
return model
def get_optimizer(args, model):
for name, para in model.named_parameters():
# if not ('input_embedding' in name or 'lm_head' in name):
# para.requires_grad_(False)
bmt.print_rank(name, para.requires_grad)
if args.offload:
optimizer = bmt.optim.AdamOffloadOptimizer(
model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay
)
else:
optimizer = bmt.optim.AdamOptimizer(model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay)
if args.load is not None and args.load_grad:
start = time.time()
print(
sum([1 if re.search(r"-{}.rank-\d+.opt".format(args.start_step), i) else 0 for i in os.listdir(args.save)])
)
if (
sum([1 if re.search(r"-{}.rank-\d+.opt".format(args.start_step), i) else 0 for i in os.listdir(args.save)])
== bmt.world_size()
):
file_name = os.path.join(
args.save,
args.save_name + "-{}.rank-{}.opt".format(args.start_step, bmt.rank()),
)
print(file_name)
if os.path.exists(file_name):
print("start to load grad ckpt {}".format(file_name))
states = torch.load(file_name)
optimizer.load_state_dict(states)
logger.info("load grad in {:.2f}s".format(time.time() - start))
return optimizer
class Cosine(bmt.lr_scheduler.WarmupLRScheduler):
r"""
After a warmup period during which learning rate increases linearly between 0 and the start_lr,
The decay period performs :math:`\text{lr}=\text{start_lr}\times \dfrac{1+\cos \left( \pi \cdot \dfrac{\text{num_iter}-\text{warmup_iter}}{\text{end_iter}-\text{warmup_iter}}\right)}{2}`
"""
def get_lr_warmup(self, num_iter) -> float:
return self.start_lr * num_iter / self.warmup_iter
def get_lr_decay(self, num_iter) -> float:
progress = (num_iter - self.warmup_iter) / max(1, (self.end_iter - self.warmup_iter))
return max(self.start_lr * 0.1, self.start_lr * (0.1 + 0.45 * (1.0 + math.cos(progress * math.pi))))
def get_learning_rate_scheduler(args, optimizer):
if args.lr_decay_iters is None:
args.lr_decay_iters = args.train_iters
# lr_scheduler = bmt.lr_scheduler.Noam(
lr_scheduler = Cosine(
optimizer,
start_lr=args.lr,
warmup_iter=args.warmup_iters,
end_iter=args.lr_decay_iters,
num_iter=args.start_step,
)
return lr_scheduler
def setup_model_and_optimizer(args):
start = time.time()
model = get_model(args)
logger.info("load model in {:.2f}s".format(time.time() - start))
start = time.time()
tokenizer = get_tokenizer(args)
bmt.synchronize()
logger.info("load tokenizer in {:.2f}s".format(time.time() - start))
start = time.time()
optimizer = get_optimizer(args, model)
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
bmt.synchronize()
logger.info("load lr_scheduler in {:.2f}s".format(time.time() - start))
return tokenizer, model, optimizer, lr_scheduler
def initialize():
args = get_args(pretrain=True)
bmt.init_distributed(seed=args.seed, zero_level=3)
if args.save is not None:
os.makedirs(args.save, exist_ok=True)
if args.load is not None:
if args.start_step == 0:
args.start_step = (int)(re.search("(\d+).pt", args.load)[1])
return args
def see_memory(detail=False):
if detail:
res = torch.cuda.memory_summary()
else:
res = (
round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024), 2),
round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024), 2),
)
torch.cuda.reset_peak_memory_stats()
return res
def add_mem_time(info, mem_usage, tim_usage):
torch.cuda.synchronize()
bmt.synchronize()
mem_usage[info] = see_memory()
tim_usage[info] = time.time()
return mem_usage, tim_usage
class LossSpikeDetector:
def __init__(self, log_path: str) -> None:
self._last_loss: Dict[str, float] = {}
self._last_data: List[Any] = [None]
self._log_path = log_path
def update_data(self, data: Any):
self._last_data.append(data)
if len(self._last_data) > 2:
self._last_data = self._last_data[-2:]
def update_loss(self, iteration: int, loss_map: Dict[str, float]):
loss_spike_result = []
for task, loss in loss_map.items():
if task in self._last_loss:
if loss > self._last_loss[task] * 3:
# loss spike!
loss_spike_result.append(
{
"prev": self._last_loss[task],
"curr": loss,
"task": task,
}
)
self._last_loss[task] = float(loss)
if len(loss_spike_result) > 0:
self._write_log(iteration, self._last_data[-1], loss_spike_result)
def _write_log(self, iteration: int, data: Any, result: List[Dict[str, Any]]):
while True:
try:
with open(self._log_path, "a", encoding="utf-8") as fp:
fp.write("=" * 20)
fp.write("\nloss spike at {}\n".format(iteration))
fp.write("{}\n".format(json.dumps(result, indent=4, ensure_ascii=False)))
fp.write("data: \n")
for d in data:
fp.write("{}\n".format(json.dumps(d, indent=4, ensure_ascii=False)))
fp.write("\n\n")
break
except Exception as e:
print("cannot output log to the file {}", self._log_path)
def pretrain(
args,
tokenizer: CPM9GTokenizer,
model: CPM9G,
optimizer,
lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
):
average_time = bmt.utils.AverageRecorder()
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100)
optim_manager = bmt.optim.OptimManager(
loss_scale=None if args.bf16 else args.loss_scale,
loss_scale_steps=args.loss_scale_steps,
loss_scale_factor=2,
max_loss_scale=args.max_loss_scale,
min_loss_scale=args.min_loss_scale,
)
optim_manager.add_optimizer(optimizer, lr_scheduler)
start_step = args.start_step
lsd = LossSpikeDetector("./log/debug/spile.%d.log" % bmt.rank())
if args.tensorboard is not None and bmt.rank() == 0:
import distutils.version # noqa: F401
from tensorboardX import SummaryWriter
if not os.path.exists(args.tensorboard):
os.makedirs(args.tensorboard)
writer = SummaryWriter(log_dir=args.tensorboard)
if args.log_dir is not None and bmt.rank() == 0:
log_mgr = LogManager(args.log_dir)
global_token_pass = 0.0
global_world_size = bmt.world_size()
dataloader = MixedDataset(args.dataset, args.batch_size, args.max_length, tokenizer, unpad=(args.flash == "cuda"))
if args.load is not None:
dataset_states_path = args.load.replace(".pt", ".data")
if os.path.exists(dataset_states_path):
start = time.time()
bmt.print_rank("start to load data ckpt")
dataset_states = torch.load(dataset_states_path)
logger.info("load data ckpt in {:.2f}s".format(time.time() - start))
start = time.time()
missing = dataloader.load_state_dict(dataset_states)
logger.info("load state dict in {:.2f}s".format(time.time() - start))
if len(missing) > 0:
bmt.print_rank("Missing keys when loading dataset states: ", missing)
else:
bmt.print_rank("cannot find data ckpt {}".format(dataset_states_path))
dataloader.start()
bmt.print_rank("finish dataset start")
try:
total = 0
hash = {}
for iteration, data in enumerate(dataloader):
iteration = iteration + start_step + 1
input_ids = torch.from_numpy(data["inputs"]).cuda().to(torch.int32)
input_length = torch.from_numpy(data["length"]).cuda().to(torch.int32)
targets = torch.from_numpy(data["target"]).cuda().to(torch.int32)
task_ids = torch.from_numpy(data["task_ids"]).cuda().to(torch.int32)
task_names = data["task_names"]
lsd.update_data(data["raw_data"])
if args.flash == "cuda":
cu_seqlens = torch.from_numpy(data["cu_seqlens"]).cuda().to(torch.int32)
max_seqlen = data["max_seqlen"]
position_ids = torch.from_numpy(data["position_ids"]).cuda().to(torch.int32)
else:
input_ids = torch.from_numpy(data["inputs"]).cuda().to(torch.int32)
input_context = torch.zeros_like(input_ids).cuda().bool()
input_span = torch.from_numpy(data["spans"]).cuda().to(torch.int32)
# ===========
optim_manager.zero_grad()
# torch.cuda.empty_cache()
mem_usage = {}
tim_usage = {}
mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)
# bmt.print_rank(torch.cuda.max_memory_allocated())
# ===========
if args.flash == "cuda":
logits, _ = model(
input_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
position_ids=position_ids,
)
else:
logits, _ = model(
input_ids,
input_length,
input_context,
input_span,
)
mem_usage, tim_usage = add_mem_time("forward_1", mem_usage, tim_usage)
loss = loss_func(logits.view(-1, logits.size(-1)), targets.view(-1))
global_loss = bmt.sum_loss(loss).item()
mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)
# bmt.print_rank(torch.cuda.max_memory_allocated())
# ===========
optim_manager.backward(loss)
mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)
# bmt.print_rank(torch.cuda.max_memory_allocated())
# ===========
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
optim_manager.step()
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
# bmt.print_rank(torch.cuda.max_memory_allocated())
# ==========
iter_time = tim_usage["optim"] - tim_usage["init"]
average_time.record(iter_time)
with torch.no_grad():
task_num = len(task_names)
targets_tmp = targets.expand(task_num, -1, -1)
task = torch.arange(task_num, dtype=torch.int32, device="cuda")[:, None, None]
targets_tmp = torch.where(
task_ids == task,
targets_tmp,
torch.scalar_tensor(-100, dtype=torch.int32, device="cuda"),
)
task_loss_map: Dict[str, float] = {}
task_loss_tot: Dict[str, float] = {}
for i in range(task_num):
task_loss_map[task_names[i]] = loss_func(
logits.view(-1, logits.size(-1)), targets_tmp[i, :].view(-1)
).item()
task_loss_tot[task_names[i]] = (targets_tmp[i, :].view(-1) >= 0).sum().float().item()
gatherd_task_loss_map: List[Dict[str, float]] = allgather_objects(task_loss_map)
gatherd_task_loss_tot: List[Dict[str, float]] = allgather_objects(task_loss_tot)
global_task_loss_map: Dict[str, Union[List[float], float]] = {}
global_task_loss_tot: Dict[str, Union[List[float], float]] = {}
for idx, local_task_loss_map in enumerate(gatherd_task_loss_map):
for task_name, task_loss in local_task_loss_map.items():
if task_name not in global_task_loss_map:
global_task_loss_map[task_name] = []
global_task_loss_map[task_name].append(task_loss)
for task_name, task_tot in gatherd_task_loss_tot[idx].items():
if task_name not in global_task_loss_tot:
global_task_loss_tot[task_name] = []
global_task_loss_tot[task_name].append(task_tot)
task_loss_map = {}
for task_name in sorted(list(global_task_loss_map.keys())):
avg_loss = 0.0
sum_token = sum(global_task_loss_tot[task_name])
for loss, token in zip(global_task_loss_map[task_name], global_task_loss_tot[task_name]):
avg_loss += loss * token / sum_token
task_loss_map[task_name] = avg_loss
local_total_rate = torch.Tensor([input_length.float().mean() / args.max_length]).cuda()
local_total_rate = bmt.sum_loss(local_total_rate).item()
global_token_pass += global_world_size * local_total_rate * args.max_length * args.batch_size
avg_time = average_time.value
lsd.update_loss(iteration, task_loss_map)
for task_id in data["task_ids"]:
for task in task_id:
if task != -1:
if not data["task_names"][task] in hash:
hash[data["task_names"][task]] = 0
hash[data["task_names"][task]] += 1.0
total += 1.0
gathered_hash = allgather_objects(hash)
sum_total = sum(allgather_objects(total))
final_hash = defaultdict(int)
for local_hash in gathered_hash:
for task, num in local_hash.items():
final_hash[task] += num
# for i in final_hash:
# bmt.print_rank(i, final_hash[i] / sum_total)
# bmt.print_rank("=========================================")
train_info = {
"time": tim_usage["init"],
"iteration": iteration,
"loss": global_loss,
"lr": lr_scheduler.current_lr,
"lr_scale": int(optim_manager.loss_scale),
"time_usage": tim_usage,
"mem_usage": mem_usage,
"avg_time": avg_time,
"token_max": local_total_rate,
"token_pass": global_token_pass,
"throughout": args.max_length * args.batch_size * local_total_rate / avg_time,
"grad_norm": grad_norm.item(),
"mask_max": ((targets >= 0).sum(-1).float().mean() / args.max_length).item(),
"num_gpus": global_world_size,
"task_loss": task_loss_map,
}
# bmt.print_rank(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
bmt.print_rank(
(
"| Iter: {:6d} | loss: {:.4f} | lr: {:.4e}, scale: {:10.4f} | avg_time: {:.4f}; cur_time:{:.4f}={:.4f}+{:.4f} |"
+ " token/max: {:.4f} | mask/max: {:.4f} | grad_norm: {:.4f} | mem: {:.2f} |"
).format(
iteration,
global_loss,
lr_scheduler.current_lr,
int(optim_manager.loss_scale),
iter_time,
tim_usage["optim"] - tim_usage["init"],
tim_usage["backward"] - tim_usage["init"],
tim_usage["optim"] - tim_usage["backward"],
input_length.float().mean() / args.max_length / (args.batch_size if args.flash == "cuda" else 1),
(targets >= 0).sum(-1).float().mean()
/ args.max_length
/ (args.batch_size if args.flash == "cuda" else 1),
grad_norm,
max(mem_usage["forward"][1], mem_usage["backward"][1]),
)
)
bmt.print_rank(
"| "
+ " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
+ " |"
)
if iteration % args.inspect_iters == 0:
model_inspect = bmt.inspect.inspect_model(model, "*")
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
train_info["model_inspect"] = model_inspect
if args.log_dir is not None and bmt.rank() == 0:
log_mgr.write(**train_info)
if args.tensorboard is not None and bmt.rank() == 0:
writer.add_scalar("Loss/train", global_loss, iteration)
writer.add_scalar("Optimizer/lr", lr_scheduler.current_lr, iteration)
writer.add_scalar("Optimizer/scale", optim_manager.loss_scale, iteration)
writer.add_scalar("Optimizer/grad_norm", grad_norm.item(), iteration)
for task_name, loss in task_loss_map.items():
writer.add_scalar("Loss/train/{}".format(task_name), loss, iteration)
# -------- save file. If need to backup by Klara platform, use export.xx_save --------
if args.save is not None and iteration % args.save_iters == 0:
exporter.export(model, dataloader, optimizer, iteration, args, final_save=False)
if iteration >= args.train_iters:
break
except Exception as e:
print(f"train loop err: {e}")
raise e
finally:
dataloader.close()
exporter.export(model, dataloader, optimizer, -1, args, final_save=False)
def main():
args = initialize()
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
bmt.print_rank("finish loading")
pretrain(args, tokenizer, model, optimizer, lr_scheduler)
if __name__ == "__main__":
main()

View File

@ -1,60 +0,0 @@
#! /bin/bash
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8
#SBATCH --cpus-per-task=8
#SBATCH --mem=512GB
# use 8 GPU for example, pretrain may need 32 GPU
export MASTER_ADDR=`hostname`
export MASTER_PORT=12345
mkdir -p /home/${USERNAME}/logs/debug
mkdir -p /home/${USERNAME}/logs/tensorboard/cpm9g/
cd apps/cpm9g
CONFIG_NAME="config/11b"
# --------------- 运行参数 ---------------
OPTS=""
OPTS+=" --model-config ${CONFIG_NAME}/config.json"
OPTS+=" --vocab ${CONFIG_NAME}/vocab.txt"
OPTS+=" --batch-size 4"
OPTS+=" --train-iters 400000"
OPTS+=" --save-iters 250"
OPTS+=" --save-name cpm9g_checkpoint"
OPTS+=" --max-length 4096"
OPTS+=" --lr 1.5e-5"
OPTS+=" --inspect-iters 100"
OPTS+=" --warmup-iters 2000"
OPTS+=" --lr-decay-style noam"
OPTS+=" --weight-decay 0.1"
OPTS+=" --clip-grad 1.0"
OPTS+=" --loss-scale 1048576"
OPTS+=" --loss-scale-steps 32"
OPTS+=" --offload"
OPTS+=" --flash cuda"
# OPTS+=" --load-grad"
# --------------- 写文件路径 ---------------
## checkpoint
OPTS+=" --save /home/${USERNAME}/checkpoints/cpm9g/"
OPTS+=" --save-model /home/${USERNAME}/models/cpm9g/"
## logs/local/logs 等价于 /data/logs软链
OPTS+=" --log-dir /home/${USERNAME}/logs/train/"
OPTS+=" --tensorboard /home/${USERNAME}/tensorboard/cpm9g/"`date +"%Y%m%d%H%M%S"`
# --------------- 读文件路径 ---------------
OPTS+=" --dataset config/datasets.json"
OPTS+=" --load ${CHECKPOINT}"
OPTS+=" --start-step 1"
# --------------- 透传参数 ---------------
OPTS+=" $@"
# --------------- 最终指令 ---------------
CMD="torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} pretrain_cpm9g.py ${OPTS}"
echo "${CMD}"
$CMD

View File

@ -1,59 +0,0 @@
#! /bin/bash
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8
#SBATCH --cpus-per-task=8
# use 8 GPU for example, pretrain may need 32 GPU
export MASTER_ADDR=`hostname`
export MASTER_PORT=12345
mkdir -p /home/${USERNAME}/logs/debug
mkdir -p /home/${USERNAME}/logs/tensorboard/cpm9g/
cd apps/cpm9g
CONFIG_NAME="config/7b"
# --------------- 运行参数 ---------------
OPTS=""
OPTS+=" --model-config ${CONFIG_NAME}/config.json"
OPTS+=" --vocab ${CONFIG_NAME}/vocab.txt"
OPTS+=" --batch-size 4"
OPTS+=" --train-iters 400000"
OPTS+=" --save-iters 250"
OPTS+=" --save-name cpm9g_checkpoint"
OPTS+=" --max-length 4096"
OPTS+=" --lr 1.5e-5"
OPTS+=" --inspect-iters 100"
OPTS+=" --warmup-iters 2000"
OPTS+=" --lr-decay-style noam"
OPTS+=" --weight-decay 0.1"
OPTS+=" --clip-grad 1.0"
OPTS+=" --loss-scale 1048576"
OPTS+=" --loss-scale-steps 32"
OPTS+=" --offload"
OPTS+=" --flash cuda"
# OPTS+=" --load-grad"
# --------------- 写文件路径 ---------------
## checkpoint
OPTS+=" --save /home/${USERNAME}/checkpoints/cpm9g/"
OPTS+=" --save-model /home/${USERNAME}/models/cpm9g/"
## logs/local/logs 等价于 /data/logs软链
OPTS+=" --log-dir /home/${USERNAME}/logs/train/"
OPTS+=" --tensorboard /home/${USERNAME}/tensorboard/cpm9g/"`date +"%Y%m%d%H%M%S"`
# --------------- 读文件路径 ---------------
OPTS+=" --dataset config/datasets.json"
OPTS+=" --load ${CHECKPOINT}"
OPTS+=" --start-step 1"
# --------------- 透传参数 ---------------
OPTS+=" $@"
# --------------- 最终指令 ---------------
CMD="torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} pretrain_cpm9g.py ${OPTS}"
echo "${CMD}"
$CMD

View File

@ -1,484 +0,0 @@
# coding=utf-8
# Copyright 2022 The OpenBMB team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import math
import os
import re
import sys
import time
from typing import Any
from typing import Dict
from typing import List
from typing import Union
import bmtrain as bmt
import torch
sys.path.insert(0, "/home/wangshuo1/code/9G-Train")
from cpm.arguments import get_args
from cpm.cpm9g.models import CPM9G
from cpm.cpm9g.models import CPM9GConfig
from cpm.cpm9g.tokenizers import CPM9GTokenizer
from cpm.cpm9g.training_tasks import FinetuneDataset
from cpm.utils import allgather_objects
from cpm.utils import logger
import shutil
def get_tokenizer(args):
tokenizer = CPM9GTokenizer(path=args.vocab)
return tokenizer
def get_model(args):
config = CPM9GConfig.from_json_file(args.model_config)
if args.flash == "none":
config.use_flash_attn = False
else:
config.use_flash_attn = True
if args.flash == "1d":
config.flash_attn_mask_shape = "1d"
else:
config.flash_attn_mask_shape = "2d"
if args.flash == "triton":
config.flash_impl = "triton"
elif args.flash == "cuda":
config.flash_impl = "cuda"
model = CPM9G(config)
if args.load is not None:
bmt.init_parameters(model)
bmt.synchronize()
bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
bmt.load(model, args.load, strict=False)
model_inspect = bmt.inspect.inspect_model(model, "*")
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
else:
bmt.print_rank("args.load is None, start to initialize parameters")
bmt.init_parameters(model)
return model
def get_optimizer(args, model):
if args.offload:
optimizer = bmt.optim.AdamOffloadOptimizer(
model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay
)
else:
optimizer = bmt.optim.AdamOptimizer(model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay)
if args.load is not None and args.load_grad:
start = time.time()
print(
sum(
[
1
if i.find(".opt") != -1 and i.find("-{}.rank".format(args.start_step % (args.save_iters * 5))) != -1
else 0
for i in os.listdir(args.save)
]
)
)
if (
sum(
[
1
if i.find(".opt") != -1 and i.find("-{}.rank".format(args.start_step % (args.save_iters * 5))) != -1
else 0
for i in os.listdir(args.save)
]
)
== bmt.world_size()
):
file_name = os.path.join(
args.save,
args.save_name + "-{}.rank-{}.opt".format(args.start_step % (args.save_iters * 5), bmt.rank()),
)
print(file_name)
if os.path.exists(file_name):
print("start to load grad ckpt {}".format(file_name))
states = torch.load(file_name)
optimizer.load_state_dict(states)
logger.info("load grad in {:.2f}s".format(time.time() - start))
return optimizer
class Cosine(bmt.lr_scheduler.WarmupLRScheduler):
r"""
After a warmup period during which learning rate increases linearly between 0 and the start_lr,
The decay period performs :math:`\text{lr}=\text{start_lr}\times \dfrac{1+\cos \left( \pi \cdot \dfrac{\text{num_iter}-\text{warmup_iter}}{\text{end_iter}-\text{warmup_iter}}\right)}{2}`
"""
def get_lr_warmup(self, num_iter) -> float:
return self.start_lr * num_iter / self.warmup_iter
def get_lr_decay(self, num_iter) -> float:
progress = (num_iter - self.warmup_iter) / max(1, (self.end_iter - self.warmup_iter))
return max(self.start_lr * 0.1, self.start_lr * (0.1 + 0.45 * (1.0 + math.cos(progress * math.pi))))
def get_learning_rate_scheduler(args, optimizer):
if args.lr_decay_iters is None:
args.lr_decay_iters = args.train_iters
# lr_scheduler = bmt.lr_scheduler.Noam(
lr_scheduler = Cosine(
optimizer,
start_lr=args.lr,
warmup_iter=args.warmup_iters,
end_iter=args.lr_decay_iters,
num_iter=args.start_step,
)
return lr_scheduler
def setup_model_and_optimizer(args):
start = time.time()
model = get_model(args)
logger.info("load model in {:.2f}s".format(time.time() - start))
start = time.time()
tokenizer = get_tokenizer(args)
bmt.synchronize()
logger.info("load tokenizer in {:.2f}s".format(time.time() - start))
start = time.time()
optimizer = get_optimizer(args, model)
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
bmt.synchronize()
logger.info("load lr_scheduler in {:.2f}s".format(time.time() - start))
return tokenizer, model, optimizer, lr_scheduler
def initialize():
args = get_args(finetune=True)
# hack
if "checkpointing" in inspect.signature(bmt.init_distributed).parameters:
bmt.init_distributed(checkpointing=False, seed=args.seed)
else:
bmt.init_distributed(seed=args.seed)
if args.save is not None:
os.makedirs(args.save, exist_ok=True)
# if args.load is not None:
# if args.start_step == 0:
# args.start_step = (int)(re.search("(\d+).pt", args.load)[1])
return args
def see_memory(detail=False):
if detail:
res = torch.cuda.memory_summary()
else:
res = (
round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024), 2),
round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2),
)
torch.cuda.reset_peak_memory_stats()
return res
def add_mem_time(info, mem_usage, tim_usage):
torch.cuda.synchronize()
mem_usage[info] = see_memory()
tim_usage[info] = time.time()
return mem_usage, tim_usage
class LossSpikeDetector:
def __init__(self, log_path: str) -> None:
self._last_loss: Dict[str, float] = {}
self._last_data: List[Any] = [None]
self._log_path = log_path
def update_data(self, data: Any):
self._last_data.append(data)
if len(self._last_data) > 2:
self._last_data = self._last_data[-2:]
def update_loss(self, iteration: int, loss_map: Dict[str, float]):
loss_spike_result = []
for task, loss in loss_map.items():
if task in self._last_loss:
if loss > self._last_loss[task] * 3:
# loss spike!
loss_spike_result.append(
{
"prev": self._last_loss[task],
"curr": loss,
"task": task,
}
)
self._last_loss[task] = float(loss)
if len(loss_spike_result) > 0:
self._write_log(iteration, self._last_data[-1], loss_spike_result)
def _write_log(self, iteration: int, data: Any, result: List[Dict[str, Any]]):
return
with open(self._log_path, "a", encoding="utf-8") as fp:
fp.write("=" * 20)
fp.write("\nloss spike at {}\n".format(iteration))
fp.write("{}\n".format(json.dumps(result, indent=4, ensure_ascii=False)))
fp.write("data: \n")
for d in data:
fp.write("{}\n".format(json.dumps(d, indent=4, ensure_ascii=False)))
fp.write("\n\n")
def finetune(
args,
bin_file: str,
tokenizer: CPM9GTokenizer,
model: CPM9G,
optimizer,
lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
):
average_time = bmt.utils.AverageRecorder()
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100)
optim_manager = bmt.optim.OptimManager(loss_scale=args.loss_scale, loss_scale_steps=args.loss_scale_steps, loss_scale_factor=2, max_loss_scale=args.max_loss_scale, min_loss_scale=args.min_loss_scale,)
optim_manager.add_optimizer(optimizer, lr_scheduler)
if args.tensorboard is not None and bmt.rank() == 0:
import distutils.version # noqa: F401
from tensorboardX import SummaryWriter
if not os.path.exists(args.tensorboard):
os.makedirs(args.tensorboard)
writer = SummaryWriter(log_dir=args.tensorboard)
global_token_pass = 0.0
global_world_size = bmt.world_size()
for epoch in range(args.epoch):
epoch = epoch + 1
last_data = None
dataloader = FinetuneDataset(
bin_file, args.batch_size, args.max_length, tokenizer, unpad=(args.flash == "cuda"), task_name="task", drop_last=True
)
optim_manager.zero_grad()
for iteration, data in enumerate(dataloader):
iteration = iteration + 1
skip_this_batch = False
if data is None:
if last_data is None:
raise RuntimeError(
"Dataset is too small, please use a smaller batch size or sequence length!"
)
data = last_data # use last data
skip_this_batch = True
else:
last_data = data
assert data["inputs"].shape[0] == args.batch_size
input_ids = torch.from_numpy(data["inputs"]).cuda().to(torch.int32)
input_length = torch.from_numpy(data["length"]).cuda().to(torch.int32)
targets = torch.from_numpy(data["target"]).cuda().long()
# bmt.print_rank(input_ids[0].tolist())
# bmt.print_rank(targets[0].tolist())
# bmt.print_rank(data["spans"].tolist())
# bmt.print_rank(tokenizer.decode(input_ids[0]))
# bmt.print_rank(tokenizer.path)
# bmt.synchronize()
# exit()
task_ids = torch.from_numpy(data["task_ids"]).cuda().to(torch.int32)
task_names = data["task_names"]
if args.flash == "cuda":
cu_seqlens = torch.from_numpy(data["cu_seqlens"]).cuda().to(torch.int32)
max_seqlen = data["max_seqlen"]
position_ids = torch.from_numpy(data["position_ids"]).cuda().to(torch.int32)
else:
input_context = torch.zeros_like(input_ids).cuda().bool()
input_span = torch.from_numpy(data["spans"]).cuda().to(torch.int32)
# ===========
# optim_manager.zero_grad()
# torch.cuda.empty_cache()
mem_usage = {}
tim_usage = {}
mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)
# ===========
if args.flash == "cuda":
logits, _ = model(
input_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
position_ids=position_ids,
)
else:
logits, _ = model(
input_ids,
input_length,
input_context,
input_span,
)
mem_usage, tim_usage = add_mem_time("forward_1", mem_usage, tim_usage)
loss = loss_func(logits.view(-1, logits.size(-1)), targets.view(-1))
if skip_this_batch:
loss = loss * 0
mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)
# ===========
optim_manager.backward(loss)
mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)
# ===========
if iteration % args.gradient_accumulation_steps == 0:
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
optim_manager.step()
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
optim_manager.zero_grad()
else:
grad_norm = None
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
# ==========
iter_time = tim_usage["optim"] - tim_usage["init"]
average_time.record(iter_time)
with torch.no_grad():
task_num = len(task_names)
targets_tmp = targets.expand(task_num, -1, -1)
task = torch.arange(task_num, dtype=torch.long, device="cuda")[:, None, None]
targets_tmp = torch.where(
task_ids == task,
targets_tmp,
torch.scalar_tensor(-100, dtype=torch.long, device="cuda"),
)
task_loss_map: Dict[str, float] = {}
if not skip_this_batch:
for i in range(task_num):
task_loss = loss_func(logits.view(-1, logits.size(-1)), targets_tmp[i, :].view(-1))
task_loss_map[task_names[i]] = task_loss.item()
gatherd_task_loss_map: List[Dict[str, float]] = allgather_objects(task_loss_map)
global_task_loss_map: Dict[str, Union[List[float], float]] = {}
for local_task_loss_map in gatherd_task_loss_map:
for task_name, task_loss in local_task_loss_map.items():
if task_name not in global_task_loss_map:
global_task_loss_map[task_name] = []
global_task_loss_map[task_name].append(task_loss)
task_loss_map = {}
for task_name in sorted(list(global_task_loss_map.keys())):
avg_loss = sum(global_task_loss_map[task_name]) / len(global_task_loss_map[task_name])
task_loss_map[task_name] = avg_loss
local_total_rate = torch.Tensor([input_length.float().mean() / args.max_length]).cuda()
local_total_rate = bmt.sum_loss(local_total_rate).item()
global_token_pass += global_world_size * local_total_rate * args.max_length * args.batch_size
avg_time = average_time.value
train_info = {
"time": tim_usage["init"],
"epoch": epoch,
"iteration": iteration,
"loss": task_loss_map[args.task_name],
"lr": lr_scheduler.current_lr,
"lr_scale": int(optim_manager.loss_scale),
"time_usage": tim_usage,
"mem_usage": mem_usage,
"avg_time": avg_time,
"token_max": local_total_rate,
"token_pass": global_token_pass,
"throughout": args.max_length * args.batch_size * local_total_rate / avg_time,
"grad_norm": 0 if grad_norm == None else grad_norm.item(),
"mask_max": ((targets >= 0).sum(-1).float().mean() / args.max_length).item(),
"num_gpus": global_world_size,
"task_loss": task_loss_map,
}
# bmt.print_rank(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
bmt.print_rank(
(
"| Epoch: {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f} | lr: {:.6e} | scale: {:10.0f} | time: {:.1f} |"
+ " token/max: {:.3f} | mask/max: {:.3f} | grad_norm: {:.3f}"
).format(
epoch,
iteration,
args.train_iters,
task_loss_map[args.task_name],
lr_scheduler.current_lr,
int(optim_manager.loss_scale),
avg_time,
input_length.float().mean() / args.max_length,
(targets >= 0).sum(-1).float().mean() / args.max_length,
0 if grad_norm == None else grad_norm,
)
)
bmt.print_rank(
"| "
+ " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
+ " |"
)
if iteration % args.inspect_iters == 0:
model_inspect = bmt.inspect.inspect_model(model, "*")
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
train_info["model_inspect"] = model_inspect
# save_folder_name = f"{args.save}{epoch}{iteration}"
# model_fname = os.path.join(save_folder_name, f"{args.save_name}-iter-{iteration}.pt")
# os.makedirs(os.path.dirname(model_fname), exist_ok=True)
# if bmt.rank() == 0:
# shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
# shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))
if args.tensorboard is not None and bmt.rank() == 0:
writer.add_scalar(f"Loss/train/{epoch}", task_loss_map[args.task_name], iteration)
writer.add_scalar(f"Optimizer/lr/{epoch}", lr_scheduler.current_lr, iteration)
writer.add_scalar(f"Optimizer/scale/{epoch}", optim_manager.loss_scale, iteration)
writer.add_scalar(f"Optimizer/grad_norm/{epoch}", 0 if grad_norm == None else grad_norm.item(), iteration)
# if iteration % 10 == 0:
# save_folder_name = f"{args.save}{epoch}"
# model_fname = os.path.join(save_folder_name, f"{args.save_name}-epoch-{epoch}.pt")
# os.makedirs(os.path.dirname(model_fname), exist_ok=True)
# bmt.save(model, model_fname)
# if bmt.rank() == 0:
# shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
# shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))
save_folder_name = f"{args.save}-epoch-{epoch}"
model_fname = os.path.join(save_folder_name, f"{args.save_name}-epoch-{epoch}.pt")
os.makedirs(os.path.dirname(model_fname), exist_ok=True)
bmt.save(model, model_fname)
if bmt.rank() == 0:
shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))
def main():
args = initialize()
# To Be Specified
bin_file = args.dataset
bmt.print_rank(f"dataset: {bin_file}")
tokenizer = get_tokenizer(args)
dataloader = FinetuneDataset(
bin_file, args.batch_size, args.max_length, tokenizer, unpad=(args.flash == "cuda"), task_name="task", drop_last=True
)
bmt.print_rank(f"#batch: {len(dataloader)}")
total_steps = len(dataloader) * args.epoch
setattr(args, 'train_iters', int(total_steps))
setattr(args, 'warmup_iters', max(int(total_steps*0.02), args.warmup_iters))
bmt.print_rank(json.dumps(vars(args), indent=2, sort_keys=True))
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
bmt.print_rank("finish loading")
finetune(args, bin_file, tokenizer, model, optimizer, lr_scheduler)
if __name__ == "__main__":
main()

View File

@ -1,55 +0,0 @@
#! /bin/bash
#SBATCH --partition=gpu3-1
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8
#SBATCH --cpus-per-task=8
export MASTER_ADDR=g3002
export MASTER_PORT=12345
CPM_PATH="/data/groups/QY_LLM_Core/projects/202311-release/Models/11B-Chat/9G-Train"
CONFIG_NAME="${CPM_PATH}/apps/cpm9g/config/11b"
EXP_PATH=./exp
mkdir -p $EXP_PATH
MODEL_NAME="cpm9g-11b-sft"
OPTS=""
OPTS+=" --model-config ${CONFIG_NAME}/config.json"
OPTS+=" --vocab ${CONFIG_NAME}/vocab.txt"
OPTS+=" --train-iters 10000"
OPTS+=" --inspect-iters 200"
OPTS+=" --warmup-iters 500"
OPTS+=" --lr-decay-style cosine"
OPTS+=" --weight-decay 0.1"
OPTS+=" --clip-grad 1.0"
OPTS+=" --loss-scale 1048576"
OPTS+=" --max-loss-scale 33554432"
OPTS+=" --min-loss-scale 1"
OPTS+=" --loss-scale-steps 32"
OPTS+=" --offload"
OPTS+=" --batch-size 4"
OPTS+=" --max-length 4096"
OPTS+=" --lr 2e-5"
OPTS+=" --start-step 0"
OPTS+=" --epoch 8"
OPTS+=" --load /data/groups/QY_LLM_Core/models/20231010/11b-base/11b.pt"
OPTS+=" --dataset /data/groups/QY_LLM_Core/datasets/sft/20231025/merge_qy_sft_bin"
# TODO 这些 /data 在启元机器上需要改成 /home 下的路径
OPTS+=" --save ${EXP_PATH}/checkpoints"
OPTS+=" --save-name ${MODEL_NAME}"
# OPTS+=" --tensorboard /data/logs/tensorboard/${MODEL_NAME}/${CUR_DATE}/"
# OPTS+=" --flash triton"
# OPTS+=" --flash cuda"
# OPTS+=" --load-grad"
OPTS+=" $@"
CMD="torchrun --nnodes=4 --nproc_per_node=8 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} ${CPM_PATH}/apps/cpm9g/sft_cpm9g.py ${OPTS}"
echo "${CMD}"
$CMD

View File

@ -1,515 +0,0 @@
# coding=utf-8
# Copyright 2022 The OpenBMB team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import math
import os
import re
import sys
import time
from typing import Any
from typing import Dict
from typing import List
from typing import Union
import bmtrain as bmt
import torch
sys.path.insert(0, "/home/wangshuo1/code/9G-Train")
from cpm.arguments import get_args
from cpm.cpm9g.models import CPM9G
from cpm.cpm9g.models import CPM9GConfig
from cpm.cpm9g.tokenizers import CPM9GTokenizer
from cpm.cpm9g.training_tasks import FinetuneDataset
from cpm.utils import allgather_objects
from cpm.utils import logger
import shutil
import opendelta as od
from opendelta import LoraModel, AdapterModel, CompacterModel, LowRankAdapterModel, BitFitModel, ParallelAdapterModel
from opendelta.utils.inspect import inspect_optimizer_statistics
from bigmodelvis import Visualization
def get_tokenizer(args):
tokenizer = CPM9GTokenizer(path=args.vocab)
return tokenizer
def get_model(args):
config = CPM9GConfig.from_json_file(args.model_config)
if args.flash == "none":
config.use_flash_attn = False
else:
config.use_flash_attn = True
if args.flash == "1d":
config.flash_attn_mask_shape = "1d"
else:
config.flash_attn_mask_shape = "2d"
if args.flash == "triton":
config.flash_impl = "triton"
elif args.flash == "cuda":
config.flash_impl = "cuda"
model = CPM9G(config)
if args.load is not None:
bmt.init_parameters(model)
bmt.synchronize()
bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
bmt.load(model, args.load, strict=False)
model_inspect = bmt.inspect.inspect_model(model, "*")
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
else:
bmt.print_rank("args.load is None, start to initialize parameters")
bmt.init_parameters(model)
if args.delta_type != None:
from bigmodelvis import Visualization #0813
if bmt.rank() == 0:
Visualization(model).structure_graph()
print("finetuned layers: ")
print(args.lora_layer)
if args.delta_type == "lora":
delta_model = LoraModel(backbone_model=model, modified_modules=args.lora_layer, backend='bmt', lora_r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout)
if bmt.rank() == 0:
print("Before freeze: ")
delta_model.log()
delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
if bmt.rank() == 0:
print("After freeze: ")
delta_model.log()
return model
def get_optimizer(args, model):
if args.offload:
optimizer = bmt.optim.AdamOffloadOptimizer(
model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay
)
else:
optimizer = bmt.optim.AdamOptimizer(model.parameters(), betas=(0.9, 0.95), weight_decay=args.weight_decay)
if args.load is not None and args.load_grad:
start = time.time()
print(
sum(
[
1
if i.find(".opt") != -1 and i.find("-{}.rank".format(args.start_step % (args.save_iters * 5))) != -1
else 0
for i in os.listdir(args.save)
]
)
)
if (
sum(
[
1
if i.find(".opt") != -1 and i.find("-{}.rank".format(args.start_step % (args.save_iters * 5))) != -1
else 0
for i in os.listdir(args.save)
]
)
== bmt.world_size()
):
file_name = os.path.join(
args.save,
args.save_name + "-{}.rank-{}.opt".format(args.start_step % (args.save_iters * 5), bmt.rank()),
)
print(file_name)
if os.path.exists(file_name):
print("start to load grad ckpt {}".format(file_name))
states = torch.load(file_name)
optimizer.load_state_dict(states)
logger.info("load grad in {:.2f}s".format(time.time() - start))
return optimizer
class Cosine(bmt.lr_scheduler.WarmupLRScheduler):
r"""
After a warmup period during which learning rate increases linearly between 0 and the start_lr,
The decay period performs :math:`\text{lr}=\text{start_lr}\times \dfrac{1+\cos \left( \pi \cdot \dfrac{\text{num_iter}-\text{warmup_iter}}{\text{end_iter}-\text{warmup_iter}}\right)}{2}`
"""
def get_lr_warmup(self, num_iter) -> float:
return self.start_lr * num_iter / self.warmup_iter
def get_lr_decay(self, num_iter) -> float:
progress = (num_iter - self.warmup_iter) / max(1, (self.end_iter - self.warmup_iter))
return max(self.start_lr * 0.1, self.start_lr * (0.1 + 0.45 * (1.0 + math.cos(progress * math.pi))))
def get_learning_rate_scheduler(args, optimizer):
if args.lr_decay_iters is None:
args.lr_decay_iters = args.train_iters
# lr_scheduler = bmt.lr_scheduler.Noam(
lr_scheduler = Cosine(
optimizer,
start_lr=args.lr,
warmup_iter=args.warmup_iters,
end_iter=args.lr_decay_iters,
num_iter=args.start_step,
)
return lr_scheduler
def setup_model_and_optimizer(args):
start = time.time()
model = get_model(args)
logger.info("load model in {:.2f}s".format(time.time() - start))
start = time.time()
tokenizer = get_tokenizer(args)
bmt.synchronize()
logger.info("load tokenizer in {:.2f}s".format(time.time() - start))
start = time.time()
optimizer = get_optimizer(args, model)
if args.delta_type != None:
inspect_optimizer_statistics(optimizer)
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
bmt.synchronize()
logger.info("load lr_scheduler in {:.2f}s".format(time.time() - start))
return tokenizer, model, optimizer, lr_scheduler
def initialize():
args = get_args(finetune=True)
# hack
if "checkpointing" in inspect.signature(bmt.init_distributed).parameters:
bmt.init_distributed(checkpointing=False, seed=args.seed)
else:
bmt.init_distributed(seed=args.seed)
if args.save is not None:
os.makedirs(args.save, exist_ok=True)
# if args.load is not None:
# if args.start_step == 0:
# args.start_step = (int)(re.search("(\d+).pt", args.load)[1])
return args
def see_memory(detail=False):
if detail:
res = torch.cuda.memory_summary()
else:
res = (
round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024), 2),
round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2),
)
torch.cuda.reset_peak_memory_stats()
return res
def add_mem_time(info, mem_usage, tim_usage):
torch.cuda.synchronize()
mem_usage[info] = see_memory()
tim_usage[info] = time.time()
return mem_usage, tim_usage
class LossSpikeDetector:
def __init__(self, log_path: str) -> None:
self._last_loss: Dict[str, float] = {}
self._last_data: List[Any] = [None]
self._log_path = log_path
def update_data(self, data: Any):
self._last_data.append(data)
if len(self._last_data) > 2:
self._last_data = self._last_data[-2:]
def update_loss(self, iteration: int, loss_map: Dict[str, float]):
loss_spike_result = []
for task, loss in loss_map.items():
if task in self._last_loss:
if loss > self._last_loss[task] * 3:
# loss spike!
loss_spike_result.append(
{
"prev": self._last_loss[task],
"curr": loss,
"task": task,
}
)
self._last_loss[task] = float(loss)
if len(loss_spike_result) > 0:
self._write_log(iteration, self._last_data[-1], loss_spike_result)
def _write_log(self, iteration: int, data: Any, result: List[Dict[str, Any]]):
return
with open(self._log_path, "a", encoding="utf-8") as fp:
fp.write("=" * 20)
fp.write("\nloss spike at {}\n".format(iteration))
fp.write("{}\n".format(json.dumps(result, indent=4, ensure_ascii=False)))
fp.write("data: \n")
for d in data:
fp.write("{}\n".format(json.dumps(d, indent=4, ensure_ascii=False)))
fp.write("\n\n")
def finetune(
args,
bin_file: str,
tokenizer: CPM9GTokenizer,
model: CPM9G,
optimizer,
lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
):
average_time = bmt.utils.AverageRecorder()
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100)
optim_manager = bmt.optim.OptimManager(loss_scale=args.loss_scale, loss_scale_steps=args.loss_scale_steps, loss_scale_factor=2, max_loss_scale=args.max_loss_scale, min_loss_scale=args.min_loss_scale,)
optim_manager.add_optimizer(optimizer, lr_scheduler)
if args.tensorboard is not None and bmt.rank() == 0:
import distutils.version # noqa: F401
from tensorboardX import SummaryWriter
if not os.path.exists(args.tensorboard):
os.makedirs(args.tensorboard)
writer = SummaryWriter(log_dir=args.tensorboard)
global_token_pass = 0.0
global_world_size = bmt.world_size()
for epoch in range(args.epoch):
epoch = epoch + 1
last_data = None
dataloader = FinetuneDataset(
bin_file, args.batch_size, args.max_length, tokenizer, unpad=(args.flash == "cuda"), task_name="task", drop_last=True
)
for iteration, data in enumerate(dataloader):
iteration = iteration + 1
skip_this_batch = False
if data is None:
if last_data is None:
raise RuntimeError(
"Dataset is too small, please use a smaller batch size or sequence length!"
)
data = last_data # use last data
skip_this_batch = True
else:
last_data = data
assert data["inputs"].shape[0] == args.batch_size
input_ids = torch.from_numpy(data["inputs"]).cuda().to(torch.int32)
input_length = torch.from_numpy(data["length"]).cuda().to(torch.int32)
targets = torch.from_numpy(data["target"]).cuda().long()
# bmt.print_rank(input_ids[0].tolist())
# bmt.print_rank(targets[0].tolist())
# bmt.print_rank(data["spans"].tolist())
# bmt.print_rank(tokenizer.decode(input_ids[0]))
# bmt.print_rank(tokenizer.path)
# bmt.synchronize()
# exit()
task_ids = torch.from_numpy(data["task_ids"]).cuda().to(torch.int32)
task_names = data["task_names"]
if args.flash == "cuda":
cu_seqlens = torch.from_numpy(data["cu_seqlens"]).cuda().to(torch.int32)
max_seqlen = data["max_seqlen"]
position_ids = torch.from_numpy(data["position_ids"]).cuda().to(torch.int32)
else:
input_context = torch.zeros_like(input_ids).cuda().bool()
input_span = torch.from_numpy(data["spans"]).cuda().to(torch.int32)
# ===========
optim_manager.zero_grad()
# torch.cuda.empty_cache()
mem_usage = {}
tim_usage = {}
mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)
# ===========
if args.flash == "cuda":
logits, _ = model(
input_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
position_ids=position_ids,
)
else:
logits, _ = model(
input_ids,
input_length,
input_context,
input_span,
)
mem_usage, tim_usage = add_mem_time("forward_1", mem_usage, tim_usage)
loss = loss_func(logits.view(-1, logits.size(-1)), targets.view(-1))
if skip_this_batch:
loss = loss * 0
mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)
# ===========
optim_manager.backward(loss)
mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)
# ===========
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
optim_manager.step()
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
# ==========
iter_time = tim_usage["optim"] - tim_usage["init"]
average_time.record(iter_time)
with torch.no_grad():
task_num = len(task_names)
targets_tmp = targets.expand(task_num, -1, -1)
task = torch.arange(task_num, dtype=torch.long, device="cuda")[:, None, None]
targets_tmp = torch.where(
task_ids == task,
targets_tmp,
torch.scalar_tensor(-100, dtype=torch.long, device="cuda"),
)
task_loss_map: Dict[str, float] = {}
if not skip_this_batch:
for i in range(task_num):
task_loss = loss_func(logits.view(-1, logits.size(-1)), targets_tmp[i, :].view(-1))
task_loss_map[task_names[i]] = task_loss.item()
gatherd_task_loss_map: List[Dict[str, float]] = allgather_objects(task_loss_map)
global_task_loss_map: Dict[str, Union[List[float], float]] = {}
for local_task_loss_map in gatherd_task_loss_map:
for task_name, task_loss in local_task_loss_map.items():
if task_name not in global_task_loss_map:
global_task_loss_map[task_name] = []
global_task_loss_map[task_name].append(task_loss)
task_loss_map = {}
for task_name in sorted(list(global_task_loss_map.keys())):
avg_loss = sum(global_task_loss_map[task_name]) / len(global_task_loss_map[task_name])
task_loss_map[task_name] = avg_loss
local_total_rate = torch.Tensor([input_length.float().mean() / args.max_length]).cuda()
local_total_rate = bmt.sum_loss(local_total_rate).item()
global_token_pass += global_world_size * local_total_rate * args.max_length * args.batch_size
avg_time = average_time.value
train_info = {
"time": tim_usage["init"],
"epoch": epoch,
"iteration": iteration,
"loss": task_loss_map[args.task_name],
"lr": lr_scheduler.current_lr,
"lr_scale": int(optim_manager.loss_scale),
"time_usage": tim_usage,
"mem_usage": mem_usage,
"avg_time": avg_time,
"token_max": local_total_rate,
"token_pass": global_token_pass,
"throughout": args.max_length * args.batch_size * local_total_rate / avg_time,
"grad_norm": grad_norm.item(),
"mask_max": ((targets >= 0).sum(-1).float().mean() / args.max_length).item(),
"num_gpus": global_world_size,
"task_loss": task_loss_map,
}
# bmt.print_rank(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
bmt.print_rank(
(
"| Epoch: {:3d} | Iter: {:6d}/{:6d} | loss: {:.4f} | lr: {:.6e} | scale: {:10.0f} | time: {:.1f} |"
+ " token/max: {:.3f} | mask/max: {:.3f} | grad_norm: {:.3f}"
).format(
epoch,
iteration,
args.train_iters,
task_loss_map[args.task_name],
lr_scheduler.current_lr,
int(optim_manager.loss_scale),
avg_time,
input_length.float().mean() / args.max_length,
(targets >= 0).sum(-1).float().mean() / args.max_length,
grad_norm,
)
)
bmt.print_rank(
"| "
+ " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
+ " |"
)
if iteration % args.inspect_iters == 0:
model_inspect = bmt.inspect.inspect_model(model, "*")
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
train_info["model_inspect"] = model_inspect
# save_folder_name = f"{args.save}{epoch}{iteration}"
# model_fname = os.path.join(save_folder_name, f"{args.save_name}-iter-{iteration}.pt")
# os.makedirs(os.path.dirname(model_fname), exist_ok=True)
# if bmt.rank() == 0:
# shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
# shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))
if args.tensorboard is not None and bmt.rank() == 0:
writer.add_scalar(f"Loss/train/{epoch}", task_loss_map[args.task_name], iteration)
writer.add_scalar(f"Optimizer/lr/{epoch}", lr_scheduler.current_lr, iteration)
writer.add_scalar(f"Optimizer/scale/{epoch}", optim_manager.loss_scale, iteration)
writer.add_scalar(f"Optimizer/grad_norm/{epoch}", grad_norm.item(), iteration)
# if iteration % 10 == 0:
# save_folder_name = f"{args.save}{epoch}"
# model_fname = os.path.join(save_folder_name, f"{args.save_name}-epoch-{epoch}.pt")
# os.makedirs(os.path.dirname(model_fname), exist_ok=True)
# bmt.save(model, model_fname)
# if bmt.rank() == 0:
# shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
# shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))
save_folder_name = f"{args.save}/epoch-{epoch}"
model_fname = os.path.join(save_folder_name, f"{args.save_name}.pt")
os.makedirs(os.path.dirname(model_fname), exist_ok=True)
state_dict = model.state_dict()
if args.delta_type == None or args.save_origin_model == True :
print("saving base model...")
bmt.save(model, model_fname)
if args.delta_type != None and bmt.rank() == 0:
print("saving delta model...")
torch.save(state_dict, os.path.join(save_folder_name, f"{args.save_name}-delta.pt"))
if bmt.rank() == 0:
shutil.copy(args.model_config, os.path.join(save_folder_name, "config.json"))
shutil.copy(args.vocab, os.path.join(save_folder_name, "vocabs.txt"))
def main():
args = initialize()
# To Be Specified
bin_file = args.dataset
bmt.print_rank(f"dataset: {bin_file}")
tokenizer = get_tokenizer(args)
dataloader = FinetuneDataset(
bin_file, args.batch_size, args.max_length, tokenizer, unpad=(args.flash == "cuda"), task_name="task", drop_last=True
)
bmt.print_rank(f"#batch: {len(dataloader)}")
total_steps = len(dataloader) * args.epoch
setattr(args, 'train_iters', int(total_steps))
setattr(args, 'warmup_iters', max(int(total_steps*0.02), args.warmup_iters))
bmt.print_rank(json.dumps(vars(args), indent=2, sort_keys=True))
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
bmt.print_rank("finish loading")
finetune(args, bin_file, tokenizer, model, optimizer, lr_scheduler)
if __name__ == "__main__":
main()

View File

@ -1,63 +0,0 @@
#! /bin/bash
#SBATCH --partition=gpu3-1
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8
#SBATCH --cpus-per-task=8
export MASTER_ADDR=`hostname`
export MASTER_PORT=12345
echo $MASTER_ADDR
CPM_PATH="./9G-Train"
CONFIG_NAME="${CPM_PATH}/apps/cpm9g/config/11b"
EXP_PATH=./exp
mkdir -p $EXP_PATH
MODEL_NAME="cpm9g-11b-sft"
OPTS=""
OPTS+=" --model-config ${CONFIG_NAME}/config.json"
OPTS+=" --vocab ${CONFIG_NAME}/vocab.txt"
OPTS+=" --train-iters 10000"
OPTS+=" --inspect-iters 200"
OPTS+=" --warmup-iters 500"
OPTS+=" --lr-decay-style cosine"
OPTS+=" --weight-decay 0.1"
OPTS+=" --clip-grad 1.0"
OPTS+=" --loss-scale 1048576"
OPTS+=" --max-loss-scale 33554432"
OPTS+=" --min-loss-scale 1"
OPTS+=" --loss-scale-steps 32"
OPTS+=" --offload"
OPTS+=" --batch-size 4"
OPTS+=" --max-length 4096"
OPTS+=" --lr 2e-5"
OPTS+=" --start-step 0"
OPTS+=" --epoch 8"
OPTS+=" --load /data/groups/QY_LLM_Core/models/20231010/11b-base/11b.pt"
OPTS+=" --dataset /data/groups/QY_LLM_Core/datasets/sft/20231025/merge_qy_sft_bin"
# TODO 这些 /data 在启元机器上需要改成 /home 下的路径
OPTS+=" --save ${EXP_PATH}/checkpoints"
OPTS+=" --save-name ${MODEL_NAME}"
# OPTS+=" --tensorboard /data/logs/tensorboard/${MODEL_NAME}/${CUR_DATE}/"
# OPTS+=" --flash triton"
# OPTS+=" --flash cuda"
# OPTS+=" --load-grad"
OPTS+=" --delta-tuning" #开启delta-tuning
OPTS+=" --delta-type lora" #目前仅支持lora
OPTS+=" --lora-r 64" #lora矩阵的维度默认为8
OPTS+=" --lora-dropout 0.05" #默认为0
OPTS+=" --lora-alpha 64" #lora对原模型的影响比例默认为8
OPTS+=" --lora-layer project_q project_v project_k w_0 w_1 w_out" #参与lora的线性层默认为project_q,project_k
OPTS+=" --save-origin-model" #是否在每个epoch储存基座模型
OPTS+=" $@"
CMD="torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} ${CPM_PATH}/apps/cpm9g/sft_cpm9g_delta.py ${OPTS}"
echo "${CMD}"
$CMD

View File

@ -1,5 +0,0 @@
from .models import CPM9G
from .models import CPM9GConfig
from .tokenizers import CPM9GTokenizer
from .training_tasks import MixedDataset

View File

@ -1,16 +0,0 @@
{
"vocab_size": 119696,
"dropout_p": 0.0,
"eps": 1e-05,
"half": true,
"use_flash_attn": false,
"flash_attn_mask_shape": "1d",
"dim_model": 4096,
"dim_ff": 12288,
"dim_head": 128,
"num_heads": 32,
"num_kv_heads": 32,
"num_layers": 32,
"activate_fn": "silu",
"scale": false
}

View File

@ -1,658 +0,0 @@
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
import numpy as np
import torch
import torch.nn.functional as F
from ...generation import apply_repetition_penalty
from ...generation import BeamHypotheses
from ...generation import top_k_top_p_filtering
from ...utils import pad
from ..models import CPM9GTorch
from ..tokenizers.cpm9g import CPM9GTokenizer
class CPM9GGeneration:
def __init__(
self, model: CPM9GTorch, tokenizer: CPM9GTokenizer, max_in_len=1024, use_nbce: bool = False
):
model.eval()
self.model = model
self.tokenizer = tokenizer
self.max_in_len = max_in_len
def _convert_to_tensors(self, data: Any):
input_ids = self.tokenizer.encode(
data["input"]
) # [self.tokenizer.bos_token_id] + self.tokenizer.encode(data["input"])
model_input = {}
model_input["input_ids"] = torch.tensor(input_ids[: self.max_in_len], dtype=torch.int32).unsqueeze(0)
model_input["context"] = torch.zeros(
(model_input["input_ids"].shape[0], model_input["input_ids"].shape[1]), dtype=torch.int16
)
model_input["span"] = torch.ones((model_input["input_ids"].shape[1],), dtype=torch.int16).unsqueeze(0)
model_input["length"] = torch.tensor([model_input["input_ids"].shape[1]], dtype=torch.int16).unsqueeze(0)
return model_input
def _process_list(self, data_list: List[Any]):
input_tensors = list(map(self._convert_to_tensors, data_list))
keys = set(input_tensors[0].keys())
padded = {}
for key in keys:
padded[key] = pad(input_tensors, key, padding_side="left").cuda()
return padded
def generate(self, data_list, **kwargs):
origin_data_list = data_list.copy()
model_inputs = self._process_list(data_list)
with torch.inference_mode():
result_ids = self._decode(model_inputs, **kwargs)
return result_ids
def _decode(self, model_inputs, **kwargs):
raise NotImplementedError("_decode is not implemented.")
class CPM9GBeamSearch(CPM9GGeneration):
def _decode(
self,
model_inputs,
beam_size=4,
max_length=100,
repetition_penalty=1.2,
repetition_window=None,
):
"""
Beam search
Args:
model_inputs (dict): input ids.
beam_size (int, optional, defaults to 3): beam size of beam search.
generate_length (int, optional, defaults to 100): maximum generation length.
repetition_penalty (float, optional, defaults to 1.0): repetition penalty coefficient, 1.0 means no penalty.
repetition_window (int, optional, defaults to None): window size of repetition penalty, None means that all output tokens are penalized.
""" # noqa: E501
# generate_length + 1 for EOS token
max_length += 1
# expand dimmension
batch_size = model_inputs["input_ids"].size(0)
input: torch.Tensor = (
model_inputs["input_ids"]
.unsqueeze(1)
.expand(batch_size, beam_size, -1)
.contiguous()
.view(batch_size * beam_size, -1)
)
length = (
model_inputs["length"]
.squeeze(1)
.unsqueeze(1)
.expand(batch_size, beam_size)
.contiguous()
.view(
batch_size * beam_size,
)
)
span: torch.Tensor = (
model_inputs["span"]
.unsqueeze(1)
.expand(batch_size, beam_size, -1)
.contiguous()
.view(batch_size * beam_size, -1)
)
context: torch.Tensor = (
model_inputs["context"]
.unsqueeze(1)
.expand(batch_size, beam_size, -1)
.contiguous()
.view(batch_size * beam_size, -1)
)
done = [False for _ in range(batch_size)]
beam_scores = torch.zeros((batch_size, beam_size), dtype=torch.float, device=input.device)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view(-1)
# generated hypotheses
generated_hyps = [
BeamHypotheses(beam_size, max_length, length_penalty=1, early_stopping=False) for _ in range(batch_size)
]
pred_start_index = input.size(-1)
past_key_values = None
for i in range(max_length + 1):
if i == 0:
logits, _, past_key_values = self.model.inference(
input=input,
context=context,
span=span,
length=length,
past_key_values=past_key_values,
)
else:
logits, _, past_key_values = self.model.inference(
input=input[:, -1:],
context=context,
span=span,
length=length,
past_key_values=past_key_values,
)
# skip all steps when we are done with each sentence
if all(done):
break
# (batch * beam, seqlen, model_dim)
logits = logits[:, -1, :]
if i == 0:
logits[:, self.tokenizer.bos_token_id] = -float("inf")
# logits[:, self.tokenizer.newline_id] = -float("inf")
apply_repetition_penalty(
logits,
batch_size,
beam_size,
input,
repetition_penalty,
pred_start_index,
input.size(-1) - 1,
repetition_window,
)
scores = F.log_softmax(logits, dim=-1)
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * beam_size, vocab_size)
# re-organize to group the beam together (we are keeping top hypothesis accross beams)
next_scores = next_scores.view(batch_size, -1) # (batch_size, beam_size * vocab_size)
next_scores, next_words = torch.topk(next_scores, 2 * beam_size, dim=1, largest=True, sorted=True)
assert next_scores.size() == next_words.size() == (batch_size, 2 * beam_size)
next_batch_beam = []
for sent_id in range(batch_size):
# if we are done with this sentence
done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item(), i)
if done[sent_id]:
next_batch_beam.extend([(0, 0, 0)] * beam_size) # pad the batch
continue
# next sentence beam content
next_sent_beam = []
# next words for this sentence
for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
# get beam and word IDs
beam_id = torch.div(idx, scores.size(-1), rounding_mode="floor")
word_id = idx % scores.size(-1)
# end of sentence, or next word
if word_id == self.tokenizer.bos_token_id or i == max_length:
generated_hyps[sent_id].add(
input[sent_id * beam_size + beam_id, pred_start_index:].clone().cpu().tolist(),
value.item(),
)
else:
next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))
# the beam for next step is full
if len(next_sent_beam) == beam_size:
break
# update next beam content
assert len(next_sent_beam) == 0 if i == max_length else beam_size
if len(next_sent_beam) == 0:
next_sent_beam = [(0, 0, 0)] * beam_size # pad the batch
next_batch_beam.extend(next_sent_beam)
assert len(next_batch_beam) == beam_size * (sent_id + 1)
# we have reached the last step
if i == max_length:
break
# sanity check / prepare next batch
assert len(next_batch_beam) == batch_size * beam_size
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
beam_words = input.new([x[1] for x in next_batch_beam])
beam_idx = torch.tensor([x[2] for x in next_batch_beam], device=input.device).long()
# re-order batch and internal states
input = input[beam_idx, :]
past_key_values["buffer"] = [list(each) if each is not None else each for each in past_key_values["buffer"]] # type: ignore # noqa: E501
for key_value_layer in past_key_values["buffer"]:
if key_value_layer is not None:
key_value_layer[0] = key_value_layer[0][beam_idx]
key_value_layer[1] = key_value_layer[1][beam_idx]
input = torch.cat([input, beam_words.unsqueeze(1)], dim=-1)
context = torch.cat(
[context, context[:, -1:]],
dim=-1,
)
length += 1
span = torch.cat([span, span[:, -1:]], dim=-1)
# select the best hypotheses
results = []
for i, hypotheses in enumerate(generated_hyps):
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
results.append(best_hyp)
result_text = list(map(self.tokenizer.decode, results))
return result_text
class CPM9GRandomSampling(CPM9GGeneration):
def _decode(
self,
model_inputs,
max_length=100,
top_k=0,
top_p=0.9,
temperature=0.9,
repetition_penalty=1.0,
repetition_window=None,
**kwargs,
):
"""
Top-k and top-p sampling.
Args:
model_inputs (dict): input ids
generate_length (int, optional, defaults to 100): maximum generation length
top_k (int, optional, defaults to 0): keep only top k tokens with highest probability. 0 means keeping all tokens.
top_p (int, optional, defaults to 0.9): keep the top tokens with cumulative probability >= top_p.
temperature (int, optional, defaults to 0.9): the value that can cool down the logits distribution.
repetition_penalty (float, optional, defaults to 1.0): repetition penalty coefficient, 1.0 means no penalty.
repetition_window (int, optional, defaults to None): window size of repetition penalty, None means that all output tokens are penalized.
""" # noqa: E501
# generate_length + 1 for EOS token
max_length += 1
input = model_inputs["input_ids"]
context = model_inputs["context"]
length = model_inputs["length"].squeeze(1)
span = model_inputs["span"]
batch_size = input.size(0)
pred_start_index = input.size(-1)
past_key_values = None
done = [False for _ in range(batch_size)]
results = [None for _ in range(batch_size)]
for i in range(max_length):
if i == 0:
logits, _, past_key_values = self.model.inference(
input=input,
context=context,
length=length,
span=span,
past_key_values=past_key_values,
)
else:
logits, _, past_key_values = self.model.inference(
input=input[:, -1:],
context=context,
length=length,
span=span,
past_key_values=past_key_values,
)
logits = logits[:, -1, :]
if i == 0:
logits[:, self.tokenizer.bos_token_id] = -float("inf")
# logits[:, self.tokenizer.newline_id] = -float("inf")
apply_repetition_penalty(
logits,
batch_size,
1,
input,
repetition_penalty,
pred_start_index,
input.size(-1) - 1,
repetition_window,
)
logits = logits / temperature
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
for idx in range(batch_size):
if not done[idx] and (next_token[idx].item() == self.tokenizer.bos_token_id or i == max_length - 1):
done[idx] = True
results[idx] = input[idx, pred_start_index:].clone().cpu().tolist() # type: ignore # noqa: E501
if sum(done) == batch_size:
break
# update input ids
input = torch.cat([input, next_token], dim=-1)
length += 1
context = torch.cat(
[context, context[:, -1:]],
dim=-1,
)
span = torch.cat(
[span, span[:, -1:]],
dim=-1,
)
result_text = list(map(self.tokenizer.decode, results))
return result_text
class CPM9GBeamSearchNBCE(CPM9GGeneration):
def _decode(
self,
model_inputs,
beam_size=5,
max_length=100,
repetition_penalty=1.0,
repetition_window=None,
):
"""
Beam search
Args:
model_inputs (dict): input ids.
beam_size (int, optional, defaults to 3): beam size of beam search.
generate_length (int, optional, defaults to 100): maximum generation length.
repetition_penalty (float, optional, defaults to 1.0): repetition penalty coefficient, 1.0 means no penalty.
repetition_window (int, optional, defaults to None): window size of repetition penalty, None means that all output tokens are penalized.
""" # noqa: E501
# generate_length + 1 for EOS token
max_length += 1
# expand dimmension
batch_size = model_inputs["input_ids"].size(0)
input: torch.Tensor = (
model_inputs["input_ids"]
.unsqueeze(1)
.expand(batch_size, beam_size, -1)
.contiguous()
.view(batch_size * beam_size, -1)
)
length = (
model_inputs["length"]
.squeeze(1)
.unsqueeze(1)
.expand(batch_size, beam_size)
.contiguous()
.view(
batch_size * beam_size,
)
)
span: torch.Tensor = (
model_inputs["span"]
.unsqueeze(1)
.expand(batch_size, beam_size, -1)
.contiguous()
.view(batch_size * beam_size, -1)
)
context: torch.Tensor = (
model_inputs["context"]
.unsqueeze(1)
.expand(batch_size, beam_size, -1)
.contiguous()
.view(batch_size * beam_size, -1)
)
done = [False]
beam_scores = torch.zeros((1, beam_size), dtype=torch.float, device=input.device)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view(-1)
# generated hypotheses
generated_hyps = [
BeamHypotheses(beam_size, max_length, length_penalty=1, early_stopping=False) for _ in range(1)
]
pred_start_index = input.size(-1)
past_key_values = None
for i in range(max_length + 1):
if i == 0:
logits, _, past_key_values = self.model.inference(
input=input,
context=context,
span=span,
length=length,
past_key_values=past_key_values,
)
else:
logits, _, past_key_values = self.model.inference(
input=input[:, -1:],
context=context,
span=span,
length=length,
past_key_values=past_key_values,
)
# skip all steps when we are done with each sentence
if all(done):
break
# (batch * beam, seqlen, model_dim)
assert logits.size(0) > 1, "nbce needs to ensure that the length of logits 0 is greater than 1"
logits = NBCE(logits) # [vocab_size]
logits = logits.tile(beam_size, 1)
if i == 0:
logits[:, self.tokenizer.bos_token_id] = -float("inf")
apply_repetition_penalty(
logits,
1,
beam_size,
input,
repetition_penalty,
pred_start_index,
input.size(-1) - 1,
repetition_window,
)
scores = F.log_softmax(logits, dim=-1)
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * beam_size, vocab_size)
# re-organize to group the beam together (we are keeping top hypothesis accross beams)
next_scores = next_scores.view(1, -1) # (batch_size, beam_size * vocab_size)
next_scores, next_words = torch.topk(next_scores, 2 * beam_size, dim=1, largest=True, sorted=True)
assert next_scores.size() == next_words.size() == (1, 2 * beam_size)
next_batch_beam = []
for sent_id in range(1):
# if we are done with this sentence
done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item(), i)
if done[sent_id]:
next_batch_beam.extend([(0, 0, 0)] * beam_size) # pad the batch
continue
# next sentence beam content
next_sent_beam = []
# next words for this sentence
for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
# get beam and word IDs
beam_id = torch.div(idx, scores.size(-1), rounding_mode="floor")
word_id = idx % scores.size(-1)
# end of sentence, or next word
if word_id == self.tokenizer.bos_token_id or i == max_length:
generated_hyps[sent_id].add(
input[sent_id * beam_size + beam_id, pred_start_index:].clone().cpu().tolist(),
value.item(),
)
else:
next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))
# the beam for next step is full
if len(next_sent_beam) == beam_size:
break
# update next beam content
assert len(next_sent_beam) == 0 if i == max_length else beam_size
if len(next_sent_beam) == 0:
next_sent_beam = [(0, 0, 0)] * beam_size # pad the batch
next_batch_beam.extend(next_sent_beam)
assert len(next_batch_beam) == beam_size * (sent_id + 1)
# we have reached the last step
if i == max_length:
break
# sanity check / prepare next batch
assert len(next_batch_beam) == batch_size * beam_size
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
beam_words = input.new([x[1] for x in next_batch_beam])
beam_idx = torch.tensor([x[2] for x in next_batch_beam], device=input.device).long()
beam_idx *= batch_size
# re-order batch and internal states
input = input[beam_idx, :]
past_key_values["buffer"] = [list(each) if each is not None else each for each in past_key_values["buffer"]] # type: ignore # noqa: E501
for key_value_layer in past_key_values["buffer"]:
if key_value_layer is not None:
key_value_layer[0] = key_value_layer[0][beam_idx]
key_value_layer[1] = key_value_layer[1][beam_idx]
input = torch.cat([input, beam_words.unsqueeze(1)], dim=-1)
context = torch.cat(
[context, torch.zeros((context.size(0), 1), dtype=torch.int16, device=context.device)],
dim=-1,
)
length = past_key_values["buffer_length"]
length = torch.cat(
[length, torch.ones((length.size(0), 1), dtype=torch.int16, device=length.device)],
dim=-1,
)
span = torch.cat([span, span[:, -1:]], dim=-1)
# select the best hypotheses
results = []
for i, hypotheses in enumerate(generated_hyps):
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
results.append(best_hyp)
result_text = list(map(self.tokenizer.decode, results))
return result_text
class CPM9GRandomSamplingNBCE(CPM9GGeneration):
def _decode(
self,
model_inputs,
max_length=100,
top_k=0,
top_p=0.9,
temperature=0.9,
repetition_penalty=1.0,
repetition_window=None,
**kwargs,
):
"""
Top-k and top-p sampling.
Args:
model_inputs (dict): input ids
generate_length (int, optional, defaults to 100): maximum generation length
top_k (int, optional, defaults to 0): keep only top k tokens with highest probability. 0 means keeping all tokens.
top_p (int, optional, defaults to 0.9): keep the top tokens with cumulative probability >= top_p.
temperature (int, optional, defaults to 0.9): the value that can cool down the logits distribution.
repetition_penalty (float, optional, defaults to 1.0): repetition penalty coefficient, 1.0 means no penalty.
repetition_window (int, optional, defaults to None): window size of repetition penalty, None means that all output tokens are penalized.
""" # noqa: E501
# generate_length + 1 for EOS token
max_length += 1
input = model_inputs["input_ids"]
context = model_inputs["context"]
length = model_inputs["length"].squeeze(1)
span = model_inputs["span"]
batch_size = input.size(0)
pred_start_index = input.size(-1)
past_key_values = None
done = [False]
results = [None]
for i in range(max_length):
if i == 0:
logits, _, past_key_values = self.model.inference(
input=input,
context=context,
length=length,
span=span,
past_key_values=past_key_values,
)
else:
logits, _, past_key_values = self.model.inference(
input=input[:, -1:],
context=context,
length=length,
span=span,
past_key_values=past_key_values,
)
assert logits.size(0) > 1, "nbce needs to ensure that the length of logits 0 is greater than 1"
logits = NBCE(logits) # [vocab_size]
logits = logits[None]
if i == 0:
logits[:, self.tokenizer.bos_token_id] = -float("inf")
# logits[:, self.tokenizer.newline_id] = -float("inf")
apply_repetition_penalty(
logits,
1,
1,
input,
repetition_penalty,
pred_start_index,
input.size(-1) - 1,
repetition_window,
)
logits = logits / temperature
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
for idx in range(1):
if not done[idx] and (next_token[idx].item() == self.tokenizer.bos_token_id or i == max_length - 1):
done[idx] = True
results[idx] = input[idx, pred_start_index:].clone().cpu().tolist() # type: ignore # noqa: E501
if sum(done) == 1:
break
next_token = next_token.tile(batch_size, 1)
# update input ids
input = torch.cat([input, next_token], dim=-1)
length = past_key_values["buffer_length"]
length = torch.cat(
[length, torch.ones((length.size(0), 1), dtype=torch.int32, device=length.device)],
dim=-1,
)
# length += 1
context = torch.cat(
[context, torch.zeros((context.size(0), 1), dtype=torch.int16, device=context.device)],
dim=-1,
)
span = torch.cat(
[span, torch.zeros((span.size(0), 1), dtype=torch.int32, device=span.device)],
dim=-1,
)
result_text = list(map(self.tokenizer.decode, results))
return result_text

View File

@ -1,3 +0,0 @@
from .cpm9g import CPM9G
from .cpm9g import CPM9GConfig
from .cpm9g_torch import CPM9GTorch

View File

@ -1,272 +0,0 @@
from typing import List
from typing import Optional
from typing import Tuple
import bmtrain as bmt
import torch
import torch.nn.functional as F
from typing_extensions import TypedDict
from ...layers import Embedding
from ...layers import Encoder
from ...layers import RotaryEmbeddingESM
from ...utils import Config
from ...utils import gradient_shrink
class CPM9GInferenceState(TypedDict):
buffer_context: torch.Tensor
buffer_sample_ids: torch.Tensor
buffer: List[Tuple[torch.Tensor, torch.Tensor]]
class CPM9GConfig(Config):
def __init__(
self,
vocab_size=32000,
dim_model=4096,
num_heads=32,
num_kv_heads=32,
dim_head=128,
dim_ff=11008,
num_layers=32,
dropout_p=0.0,
activate_fn="silu",
scale=True,
eps=1e-5,
half: bool = True,
bf16: bool = False,
mask_modules: Optional[List[Tuple[bool, bool]]] = None,
use_flash_attn: bool = True,
flash_attn_mask_shape="1d",
flash_impl="cuda",
base=10000,
tp=0,
disabled_checkpoint=None,
):
super().__init__()
self.vocab_size = vocab_size
self.dim_model = dim_model
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.dim_head = dim_head
self.dim_ff = dim_ff
self.num_layers = num_layers
self.dropout_p = dropout_p
self.activate_fn = activate_fn
self.scale = scale
self.eps = eps
if half:
if bf16:
self.dtype = torch.bfloat16
else:
self.dtype = torch.half
else:
self.dtype = torch.float
self.flash_impl = flash_impl
self.mask_modules = mask_modules
self.use_flash_attn = use_flash_attn
self.flash_attn_mask_shape = flash_attn_mask_shape
self.base = base
self.tp = tp
self.disabled_checkpoint = disabled_checkpoint
class CPM9G(bmt.DistributedModule):
def __init__(self, config: CPM9GConfig):
super().__init__()
self.encoder = Encoder(
num_layers=config.num_layers,
dim_model=config.dim_model,
dim_ff=config.dim_ff,
num_heads=config.num_heads,
num_kv_heads=config.num_kv_heads,
dim_head=config.dim_head,
activate_fn=config.activate_fn,
dtype=config.dtype,
eps=config.eps,
dropout_p=config.dropout_p,
scale=config.scale,
mask_modules=config.mask_modules,
use_flash_attn=config.use_flash_attn,
tp=config.tp,
disabled_checkpoint=config.disabled_checkpoint,
)
self.input_embedding = Embedding(
vocab_size=config.vocab_size,
embedding_size=config.dim_model,
scale=config.scale,
dtype=config.dtype,
init_std=0.02,
)
self.position_bias = RotaryEmbeddingESM(
dim=config.dim_head, dtype=config.dtype, base=config.base, persistent=False, mixed_precision=True
)
self.lm_head = Embedding(
vocab_size=config.vocab_size,
embedding_size=config.dim_model,
scale=config.scale,
dtype=config.dtype,
init_std=0.02,
)
self.flash_impl = config.flash_impl
self.use_flash_attn = config.use_flash_attn
self.flash_attn_mask_shape = config.flash_attn_mask_shape
self.config = config
def forward(
self,
input: torch.Tensor, # (batch, seqlen) int32
length: torch.Tensor = None, # (batch) int32
context: torch.Tensor = None, # (batch, seqlen) bool
span: torch.Tensor = None, # (batch, seqlen) int32
cu_seqlens: torch.Tensor = None, # (real_batch+2) int32
max_seqlen: int = None,
position_ids: torch.Tensor = None, # (batch, seqlen) int32
):
batch = input.size(0)
seqlen = input.size(1)
device = input.device
if length is not None and length.dim() == 1:
length = torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) < length[:, None]
# processing masks and position bias bucket
if not self.use_flash_attn or (self.flash_attn_mask_shape == "2d" and self.flash_impl == "triton"):
with torch.no_grad():
# directional mask
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(
-1, 1
)
# context mask
attention_mask = context[:, None, :] | (
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
)
# span mask
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
# length mask
attention_mask = length.view(batch, seqlen, 1) & length.view(batch, 1, seqlen) & attention_mask
else:
attention_mask = None
hidden_states = self.input_embedding(input)
if self.config.tp:
with torch.no_grad():
if length is not None:
length = bmt.distributed.all_gather(length, comm=bmt.config["tp_comm"]).flatten(0, 1)
if attention_mask is not None:
attention_mask = bmt.distributed.all_gather(attention_mask, comm=bmt.config["tp_comm"]).flatten(
0, 1
)
if cu_seqlens is not None:
lens = bmt.distributed.all_gather(
torch.tensor([cu_seqlens.numel()]).cuda(), comm=bmt.config["tp_comm"]
)
mx = lens.max().item()
cu_seq = torch.zeros(mx).int().cuda()
cu_seq[: cu_seqlens.numel()] = cu_seqlens
all_seq = bmt.distributed.all_gather(cu_seq, comm=bmt.config["tp_comm"])
cu_seqlens = [0]
for i, l in enumerate(lens):
for j in range(1, l - 1):
cu_seqlens.append(all_seq[i][j] + i * seqlen)
cu_seqlens.append((i + 1) * seqlen)
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).cuda()
if max_seqlen is not None:
max_seqlen = bmt.distributed.all_reduce(
torch.tensor([max_seqlen]).cuda(), "max", comm=bmt.config["tp_comm"]
)[0].item()
if position_ids is not None:
position_ids = bmt.distributed.all_gather(position_ids, comm=bmt.config["tp_comm"]).flatten(0, 1)
if self.use_flash_attn:
if self.flash_attn_mask_shape == "1d":
hidden_states = self.encoder(
hidden_states,
attention_mask=None,
position_bias=self.position_bias,
pos_bias_type="rotary",
length_mask=length,
)
else:
if self.flash_impl == "triton":
mask = attention_mask.unsqueeze(dim=1).contiguous()
attention_mask_bias = torch.zeros_like(mask, device="cuda", dtype=torch.float16)
attention_mask_bias[mask == False] -= torch.inf
else:
attention_mask_bias = None
assert cu_seqlens is not None, "cu_seqlens are needed in Flash Attention cuda impl"
hidden_states = self.encoder(
hidden_states,
attention_mask=None,
position_bias=self.position_bias,
pos_bias_type="rotary",
length_mask=None,
attention_mask_bias=attention_mask_bias,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
position_ids=position_ids,
)
else:
hidden_states = self.encoder(
hidden_states, attention_mask=attention_mask, position_bias=self.position_bias, pos_bias_type="rotary"
)
logits = self.lm_head.projection(hidden_states)
return logits, hidden_states
def inference(
self,
input: torch.Tensor, # (batch, len_q) int32
length: torch.Tensor, # (batch) int32
context: torch.Tensor, # (batch, seqlen) int16
span: torch.Tensor, # (batch, seqlen) int32
past_key_values: Optional[CPM9GInferenceState] = None,
) -> Tuple[torch.Tensor, torch.Tensor, CPM9GInferenceState]:
batch = input.size(0)
len_q = input.size(1)
len_buffer = 0
if past_key_values is None:
present_buffer = None
else:
present_buffer = past_key_values["buffer"]
len_buffer = present_buffer[0][0].shape[-2]
seqlen = len_buffer + len_q
with torch.no_grad():
device = input.device
if length.dim() == 1:
length = torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) < length[:, None]
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(-1, 1)
# context mask
attention_mask = context[:, None, :] | (
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
)
# span mask
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
# length mask
attention_mask = length.view(batch, seqlen, 1) & length.view(batch, 1, seqlen) & attention_mask
hidden_states = self.input_embedding(input)
hidden_states, present_key_values, _ = self.encoder(
hidden_states,
attention_mask=attention_mask[:, len_buffer:],
position_bias=self.position_bias,
use_cache=True,
past_key_values=present_buffer,
pos_bias_type="rotary",
)
logits = self.lm_head.projection(hidden_states)
return (
logits,
hidden_states,
{"buffer": present_key_values},
)

View File

@ -1,186 +0,0 @@
from typing import List
from typing import Optional
from typing import Tuple
import torch
import torch.nn.functional as F
from typing_extensions import TypedDict
from ...native_layers import Embedding
from ...native_layers import Encoder
from ...native_layers import RotaryEmbeddingESM
from ...utils import Config
from ...utils import gradient_shrink
from .cpm9g import CPM9GConfig
class CPM9GInferenceState(TypedDict):
buffer_context: torch.Tensor
buffer_sample_ids: torch.Tensor
buffer: List[Tuple[torch.Tensor, torch.Tensor]]
class CPM9GTorch(torch.nn.Module):
def __init__(self, config: CPM9GConfig):
super().__init__()
self.encoder = Encoder(
num_layers=config.num_layers,
dim_model=config.dim_model,
dim_ff=config.dim_ff,
num_heads=config.num_heads,
num_kv_heads=config.num_kv_heads,
dim_head=config.dim_head,
activate_fn=config.activate_fn,
dtype=config.dtype,
eps=config.eps,
dropout_p=config.dropout_p,
scale=config.scale,
mask_modules=config.mask_modules,
use_flash_attn=config.use_flash_attn,
)
self.input_embedding = Embedding(
vocab_size=config.vocab_size,
embedding_size=config.dim_model,
scale=config.scale,
dtype=config.dtype,
init_std=0.02,
)
self.position_bias = RotaryEmbeddingESM(
dim=config.dim_head, dtype=config.dtype, base=config.base, persistent=False, mixed_precision=True
)
self.lm_head = Embedding(
vocab_size=config.vocab_size,
embedding_size=config.dim_model,
scale=config.scale,
dtype=config.dtype,
init_std=0.02,
)
self.flash_impl = False
self.use_flash_attn = False
self.flash_attn_mask_shape = "1d"
def forward(
self,
input: torch.Tensor, # (batch, seqlen) int32
length: torch.Tensor = None, # (batch) int32
context: torch.Tensor = None, # (batch, seqlen) bool
span: torch.Tensor = None, # (batch, seqlen) int32
cu_seqlens: torch.Tensor = None, # (real_batch+2) int32
max_seqlen: int = None,
position_ids: torch.Tensor = None, # (batch, seqlen) int32
):
batch = input.size(0)
seqlen = input.size(1)
device = input.device
if length is not None and length.dim() == 1:
length = torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) < length[:, None]
# processing masks and position bias bucket
if not self.use_flash_attn or (self.flash_attn_mask_shape == "2d" and self.flash_impl == "triton"):
with torch.no_grad():
# directional mask
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(
-1, 1
)
# context mask
attention_mask = context[:, None, :] | (
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
)
# span mask
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
# length mask
attention_mask = length.view(batch, seqlen, 1) & length.view(batch, 1, seqlen) & attention_mask
hidden_states = self.input_embedding(input)
if self.use_flash_attn:
if self.flash_attn_mask_shape == "1d":
hidden_states = self.encoder(
hidden_states,
attention_mask=None,
position_bias=self.position_bias,
pos_bias_type="rotary",
length_mask=length,
)
else:
if self.flash_impl == "triton":
mask = attention_mask.unsqueeze(dim=1).contiguous()
attention_mask_bias = torch.zeros_like(mask, device="cuda", dtype=torch.float16)
attention_mask_bias[mask == False] -= torch.inf
else:
attention_mask_bias = None
assert cu_seqlens is not None, "cu_seqlens are needed in Flash Attention cuda impl"
hidden_states = self.encoder(
hidden_states,
attention_mask=None,
position_bias=self.position_bias,
pos_bias_type="rotary",
length_mask=None,
attention_mask_bias=attention_mask_bias,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
position_ids=position_ids,
)
else:
hidden_states = self.encoder(
hidden_states, attention_mask=attention_mask, position_bias=self.position_bias, pos_bias_type="rotary"
)
logits = self.lm_head.projection(hidden_states)
return logits, hidden_states
def inference(
self,
input: torch.Tensor, # (batch, len_q) int32
length: torch.Tensor, # (batch) int32
context: torch.Tensor, # (batch, seqlen) int16
span: torch.Tensor, # (batch, seqlen) int32
past_key_values: Optional[CPM9GInferenceState] = None,
) -> Tuple[torch.Tensor, torch.Tensor, CPM9GInferenceState]:
batch = input.size(0)
len_q = input.size(1)
len_buffer = 0
if past_key_values is None:
present_buffer = None
else:
present_buffer = past_key_values["buffer"]
len_buffer = present_buffer[0][0].shape[-2]
seqlen = len_buffer + len_q
with torch.no_grad():
device = input.device
if length.dim() == 1:
length = (torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) + length[:, None]) >= seqlen
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(-1, 1)
# context mask
attention_mask = context[:, None, :] | (
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
)
# span mask
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
# length mask
attention_mask = length.view(batch, seqlen, 1) & length.view(batch, 1, seqlen) & attention_mask
hidden_states = self.input_embedding(input)
hidden_states, present_key_values, _ = self.encoder(
hidden_states,
attention_mask=attention_mask[:, len_buffer:],
position_bias=self.position_bias,
use_cache=True,
past_key_values=present_buffer,
pos_bias_type="rotary",
)
logits = self.lm_head.projection(hidden_states)
return (
logits,
hidden_states,
{"buffer": present_key_values},
)

View File

@ -1 +0,0 @@
from .cpm9g import CPM9GTokenizer

View File

@ -1,2 +0,0 @@
from .pretrain import MixedDataset
from .finetune import FinetuneDataset

View File

@ -1,87 +0,0 @@
import bmtrain as bmt
from ...dataset import SimpleDataset
from ..tokenizers import CPM9GTokenizer
from .pretrain import _MixedDatasetBatchPacker
from .pretrain import _MixedDatasetConfig
from .pretrain import CPM9GBatch
class FinetuneDataset:
def __init__(
self,
dataset_path: str,
batch_size: int,
max_length: int,
tokenizer: CPM9GTokenizer,
unpad: bool = False,
task_name: str = "task",
drop_last: bool = False,
) -> None:
self._world_size = bmt.world_size()
self._rank = bmt.rank()
self._batch_size = batch_size
self._unpad = unpad
self._max_length = max_length
self._packer = _MixedDatasetBatchPacker(batch_size * self._world_size, max_length, tokenizer, unpad)
self._drop_last = drop_last
ds = SimpleDataset(dataset_path, shuffle=False)
self._ds_cfg: _MixedDatasetConfig = {
"weight": 1.0,
"path": dataset_path,
"transforms": [],
"task_name": task_name,
"dataset_name": "finetune",
"incontext_weight": [1.0],
"lines": len(ds),
"dataset": ds,
}
self._nlines = len(ds)
self._nbytes = ds.get_bytes()
def __batch_iter(self):
while True:
try:
batch = self._packer.add_data(self._ds_cfg)
except EOFError:
break
if batch is None:
continue
yield batch
if len(self._packer) > 0:
batch = self._packer.pack_batch(force=True)
if not self._drop_last:
yield batch
self._ds_cfg["dataset"]._repeat_times = 0
def __iter__(self):
if not self._unpad:
batch_st = self._batch_size * self._rank
batch_end = self._batch_size * (self._rank + 1)
for batch in self.__batch_iter():
batch_size = batch["inputs"].shape[0]
if batch_size <= batch_st:
yield None
else:
ret: CPM9GBatch = {
kw: val[batch_st:batch_end] # type: ignore
for kw, val in batch.items()
if kw not in ["task_names", "raw_data", "cu_seqlens", "max_seqlen"]
} # type: ignore
ret["task_names"] = batch["task_names"]
ret["raw_data"] = batch["raw_data"]
ret["cu_seqlens"] = batch["cu_seqlens"]
ret["max_seqlen"] = batch["max_seqlen"]
yield ret
else:
for batch in self.__batch_iter():
assert batch["inputs"].shape[0] == 1
yield batch
def __len__(self):
iter_count = int(self._nbytes // (3.6 * self._batch_size * self._world_size * self._max_length))
if not self._drop_last:
iter_count += 1
return iter_count

View File

@ -1,736 +0,0 @@
import importlib.machinery
import importlib.util
import json
import multiprocessing
import os
import random
import time
import types
from collections import OrderedDict
from queue import Empty
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Union
import bmtrain as bmt
import numpy as np
import torch
from numpy.typing import NDArray
from typing_extensions import TypedDict
from ...dataset import DistributedDataset
from ..tokenizers import CPM9GTokenizer
class _MixedDatasetConfig(TypedDict):
weight: float
path: str
transforms: Union[List[Dict[str, Any]], str]
task_name: str
dataset_name: str
incontext_weight: List[float]
lines: int
dataset: DistributedDataset
CPM9GInputType = Union[str, Dict[str, "CPM9GInputType"]]
class _TransformFuncDict(TypedDict):
loader: importlib.machinery.SourceFileLoader
module: types.ModuleType
last_m: float
_TransformFunction = Callable[[CPM9GInputType, int, random.Random], CPM9GInputType]
class CPM9GBatch(TypedDict):
inputs: NDArray[np.int32]
length: NDArray[np.int32]
context: NDArray[np.bool_]
sample_ids: NDArray[np.int32]
spans: NDArray[np.int32]
target: NDArray[np.int32]
task_ids: NDArray[np.int32]
task_names: List[str]
raw_data: List[Any]
def convert_data_to_id(tokenizer: CPM9GTokenizer, data: Any):
input_ids = tokenizer.encode(data["input"])
output_ids = tokenizer.encode(data["output"])
ids = [tokenizer.bos_id] + input_ids + output_ids + [tokenizer.eos_id]
ids = np.array(ids, dtype=np.int32)
context = np.zeros((ids.shape[0],), dtype=np.int8)
context[: len(input_ids) + 1] = 1
return ids, context
def _dataset_identity(c: _MixedDatasetConfig):
return "{}.{}".format(c["task_name"], c["dataset_name"])
class _MixedDatasetBatchPacker:
def __init__(
self,
batch_size: int,
max_length: int,
tokenizer: CPM9GTokenizer,
unpad: bool = False,
) -> None:
self._batch_size = batch_size
self._max_length = max_length
self._clip_length = max_length
self.tokenizer = tokenizer
self._transform_func_table: Dict[str, _TransformFuncDict] = {}
self._unpad = unpad
if unpad:
self._max_length = max_length * batch_size
self._batch_size = 1
self._inputs: List[NDArray[np.int32]] = []
self._context: List[NDArray[np.int8]] = []
self._sample_ids: List[NDArray[np.int32]] = []
self._spans: List[List[int]] = []
self._task_ids: List[List[str]] = []
self._raw_data: List[List[Any]] = []
def __len__(self):
return len(self._inputs)
def apply_transform(
self,
data: CPM9GInputType,
transform: Union[Dict[str, Any], Callable[[CPM9GInputType], CPM9GInputType], None],
) -> CPM9GInputType:
if transform is None:
return data
if not isinstance(transform, dict):
return transform(data) # transform function
mapping_list: List[Tuple[str, str]] = []
def _walk_transform_dict(data: Union[Dict[str, Any], str], prefix: str = ""):
if isinstance(data, dict):
for k, v in data.items():
if len(prefix) > 0:
_walk_transform_dict(v, prefix + "." + k)
else:
_walk_transform_dict(v, k)
else:
assert isinstance(data, str), "Invalid transform {}".format(data)
mapping_list.append((prefix, data))
_walk_transform_dict(transform)
expanded_mapping_list: List[Tuple[str, Any]] = []
def _expand_mapping(data: CPM9GInputType, stars: List[str], path: List[str], target: List[str]):
if len(path) == 0:
num_stars = 0
for it in target:
if it == "*":
num_stars += 1
if num_stars != len(stars):
raise ValueError("Invalid transform {}".format(".".join(target)))
nw_tgt = []
num_stars = 0
for it in target:
if it == "*":
nw_tgt.append(stars[num_stars])
num_stars += 1
else:
nw_tgt.append(it)
expanded_mapping_list.append((".".join(nw_tgt), data))
else:
if not isinstance(data, dict):
raise ValueError("Invalid data {}".format(data))
if path[0] == "*":
for k, v in data.items():
_expand_mapping(v, stars + [k], path[1:], target)
else:
_expand_mapping(data[path[0]], stars, path[1:], target)
# expand mapping list
for tgt, src in mapping_list:
if src.startswith("$"):
# copy from src
_expand_mapping(data, [], src[1:].split("."), tgt.split("."))
else:
if "*" in tgt:
raise ValueError("Constant value is not allowed to have `*` in prefix")
expanded_mapping_list.append((tgt, src))
ret = {}
for tgt, val in expanded_mapping_list:
tgt = tgt.split(".")
cur = ret
while len(tgt) > 1:
cur = cur[tgt[0]]
tgt = tgt[1:]
cur[tgt[0]] = val
return ret
def data_to_id(self, data: Any):
return convert_data_to_id(self.tokenizer, data)
def _ensure_transform_function(self, module_name: str, transform_script_path: str) -> _TransformFunction:
module_name = "cpm_live.transforms.{}".format(module_name)
if transform_script_path not in self._transform_func_table:
loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
spec = importlib.util.spec_from_loader(loader.name, loader)
if spec is None:
raise RuntimeError("spec is none! {}".format(module_name))
mod = importlib.util.module_from_spec(spec)
self._transform_func_table[transform_script_path] = {
"loader": loader,
"module": mod,
"last_m": 0,
}
transform_script_info = self._transform_func_table[transform_script_path]
curr_m_time = float(transform_script_info["loader"].path_stats(transform_script_path)["mtime"])
if curr_m_time > transform_script_info["last_m"]:
transform_script_info["last_m"] = curr_m_time
transform_script_info["loader"].exec_module(transform_script_info["module"])
transform_func = getattr(transform_script_info["module"], "transform", None)
if transform_func is None:
def _empty_transform_func(data: CPM9GInputType, num_sample: int, r: random.Random):
raise NotImplementedError("Transform func for dataset {} not implemented".format(module_name))
return _empty_transform_func
else:
return transform_func
def build_instance(self, config: _MixedDatasetConfig):
_sample_weight = np.array(config["incontext_weight"], dtype=np.float32)
_sample_weight = _sample_weight / _sample_weight.sum()
num_incontext = np.random.choice(_sample_weight.shape[0], p=_sample_weight)
ds = config["dataset"]
transforms = config["transforms"]
if isinstance(transforms, str):
if not os.path.exists(transforms):
raise RuntimeError("transform script {} file not exists".format(transforms))
# load transform script
transform_func = self._ensure_transform_function(_dataset_identity(config), transforms)
seed = random.random()
def _transform(data: CPM9GInputType):
r = random.Random(seed)
return transform_func(data, num_incontext, r)
transform = _transform
elif len(transforms) == 0:
transform = None
else:
transform = transforms[np.random.choice(len(transforms))]
raw_data = {}
while True:
inp = ds.read()
inp = self.apply_transform(inp, transform)
# if None, skip this one
if inp is None:
continue
input_ids, context = self.data_to_id(inp)
if input_ids.shape[0] > self._clip_length:
continue # too long
input_ids = input_ids[: self._clip_length]
context = context[: self._clip_length]
raw_data["input"] = inp
raw_data["samples"] = []
break
sample_ids = np.zeros(input_ids.shape, dtype=np.int32)
i = 0
while True:
if i == num_incontext or input_ids.shape[0] >= self._clip_length:
break # early break
sample = ds.read()
sample = self.apply_transform(sample, transform)
# if sample is None, skip this one
if sample is None:
continue
sample_input_ids, _ = self.data_to_id(sample)
if input_ids.shape[0] + sample_input_ids.shape[0] > self._clip_length:
break # too long, break
raw_data["samples"].append(sample)
input_ids = np.concatenate([input_ids, sample_input_ids], axis=0)
context = np.concatenate([context, np.ones(sample_input_ids.shape, dtype=np.int8)], axis=0)
sample_ids = np.concatenate([sample_ids, np.full(sample_input_ids.shape, i + 1, dtype=np.int32)], axis=0)
i += 1
return (
input_ids,
context,
sample_ids,
raw_data,
)
def pack_batch(self, force: bool = False, unilm=False) -> CPM9GBatch:
# pack batch
if len(self._inputs) < self._batch_size:
if not force:
raise RuntimeError("Batch insufficient")
batch_size = len(self._inputs)
else:
batch_size = self._batch_size
max_length = self._max_length # self._spans[0][-1] if self._unpad else self._max_length
inputs = np.zeros((batch_size, max_length), dtype=np.int32)
context = np.zeros((batch_size, max_length), dtype=np.int8)
sample_ids = np.zeros((batch_size, max_length), dtype=np.int32)
tgt = np.full((batch_size, max_length), -100, dtype=np.int32)
spans = np.zeros((batch_size, max_length), dtype=np.int32)
length = np.zeros((batch_size,), dtype=np.int32)
task_ids = np.full((batch_size, max_length), -1, dtype=np.int32)
if self._spans[0][-1] != max_length:
cu_seqlens = np.array([0] + self._spans[0] + [max_length], dtype=np.int32)
else:
cu_seqlens = np.array([0] + self._spans[0], dtype=np.int32)
max_seqlen = int(np.max(cu_seqlens[1:] - cu_seqlens[:-1]))
position_ids = np.zeros((batch_size, max_length), dtype=np.int32)
all_task_names: Set[str] = set()
for i in range(batch_size):
for task_name in self._task_ids[i]:
all_task_names.add(task_name)
task_names: List[str] = list(all_task_names)
task_name_to_id = {name: i for i, name in enumerate(task_names)}
raw_data_list: List[Any] = []
for i in range(batch_size):
instance_length = self._inputs[i].shape[0]
inputs[i, :instance_length] = self._inputs[i]
sample_ids[i, :instance_length] = self._sample_ids[i]
if unilm:
context[i, :instance_length] = self._context[i]
span_begin = 0
for span_id, (span_end, task_name) in enumerate(zip(self._spans[i], self._task_ids[i])):
spans[i, span_begin:span_end] = span_id
position_ids[i, span_begin:span_end] = np.arange(span_end - span_begin)
task_ids[i, span_begin:span_end] = task_name_to_id[task_name]
span_begin = span_end
length[i] = instance_length
raw_data_list.extend(self._raw_data[i])
for j in range(instance_length):
if j > 1 and self._context[i][j] == 0:
if self._inputs[i][j] != self.tokenizer.bos_id and self._inputs[i][j - 1] != self.tokenizer.eos_id:
tgt[i, j - 1] = self._inputs[i][j]
self._inputs = self._inputs[batch_size:]
self._context = self._context[batch_size:]
self._sample_ids = self._sample_ids[batch_size:]
self._spans = self._spans[batch_size:]
self._task_ids = self._task_ids[batch_size:]
self._raw_data = self._raw_data[batch_size:]
return {
"inputs": inputs,
"length": length,
"context": context > 0,
"sample_ids": sample_ids,
"spans": spans,
"cu_seqlens": cu_seqlens,
"max_seqlen": max_seqlen,
"position_ids": position_ids,
"target": tgt,
"task_ids": task_ids,
"task_names": task_names,
"raw_data": raw_data_list,
}
def add_data(self, config: _MixedDatasetConfig) -> Optional[CPM9GBatch]:
(
input_ids,
context,
sample_ids,
raw_data,
) = self.build_instance(config)
# add to batch
best_fit: Union[None, int] = None
best_fit_space: Union[None, int] = None
for i in range(len(self._inputs)):
space = self._max_length - self._inputs[i].shape[0]
if input_ids.shape[0] <= space:
if best_fit_space is None:
best_fit = i
best_fit_space = space
elif best_fit_space > space:
best_fit = i
best_fit_space = space
if best_fit is None:
# add a new instance
self._inputs.append(input_ids)
self._context.append(context)
self._sample_ids.append(sample_ids)
self._spans.append([input_ids.shape[0]])
self._task_ids.append([config["task_name"]])
self._raw_data.append([raw_data])
else:
# add to existing instance
self._inputs[best_fit] = np.concatenate([self._inputs[best_fit], input_ids], axis=0)
self._context[best_fit] = np.concatenate([self._context[best_fit], context], axis=0)
self._sample_ids[best_fit] = np.concatenate([self._sample_ids[best_fit], sample_ids], axis=0)
self._spans[best_fit].append(self._inputs[best_fit].shape[0])
self._task_ids[best_fit].append(config["task_name"])
self._raw_data[best_fit].append(raw_data)
if len(self._inputs) > self._batch_size:
return self.pack_batch()
else:
return None # not ready
class _MixedDatasetConfigMananger:
def __init__(self, config_path: str) -> None:
self._config_path: str = config_path
self._config: Union[List[_MixedDatasetConfig], None] = None
self._last_m = 0
def changed(self):
while True:
try:
m_time = os.stat(self._config_path).st_mtime
if m_time > self._last_m:
# try to load new config
try:
self._config = json.load(open(self._config_path, "r", encoding="utf-8"))
except Exception:
# failed to load config
return False
# new config loaded
self._last_m = m_time
return True
return False
except Exception:
print("Error: reading info list! {}".format(self._config_path))
time.sleep(30)
def get_config(self) -> List[_MixedDatasetConfig]:
if self._config is None:
if not self.changed():
raise RuntimeError("Failed to load config")
if self._config is None:
raise RuntimeError("Failed to load config")
return self._config
def _mixed_dataset_process(
config_path: str,
q_cmd: multiprocessing.Queue,
q_cmd_out: multiprocessing.Queue,
q_data: multiprocessing.Queue,
rank: int,
world_size: int,
packer: _MixedDatasetBatchPacker,
max_repeat_times: int = None,
random_state=None,
):
for _ in range(rank):
np.random.random()
random.setstate(random_state)
# ignore SIGINT
import signal
signal.signal(signal.SIGINT, signal.SIG_IGN)
config_base_path = os.path.dirname(os.path.abspath(config_path))
# fix libgomp threading deadlock after fork(), use single-thread in data process.
# REF: https://github.com/pytorch/pytorch/issues/17199
torch.set_num_threads(1)
def _convert_to_abs_path(transform_path: str):
if transform_path.startswith("/"):
return transform_path
else:
return os.path.join(config_base_path, transform_path)
def _build_sample_weights(config: List[_MixedDatasetConfig]):
if len(config) == 0:
return np.array([], dtype=np.float32)
weights = [c["weight"] * c["lines"] for c in config]
weights = np.array(weights, dtype=np.float32)
sm_weight = weights.sum()
if sm_weight > 0:
weights = weights / sm_weight
return weights
else:
raise RuntimeError("Empty datasets")
cfg_mgr = _MixedDatasetConfigMananger(config_path)
config = cfg_mgr.get_config()
for c in config:
ds = DistributedDataset(
_convert_to_abs_path(c["path"]),
rank,
world_size,
max_repeat_times=max_repeat_times,
)
c["lines"] = ds._nlines
c["dataset"] = ds
if "weight" not in c:
c["weight"] = 1.0
if "transforms" not in c:
c["transforms"] = []
elif isinstance(c["transforms"], str):
c["transforms"] = _convert_to_abs_path(c["transforms"])
if "incontext_weight" not in c:
c["incontext_weight"] = [1.0]
weights = _build_sample_weights(config)
should_stop = False
should_start = False
while not should_stop:
# update config first
if cfg_mgr.changed():
path_ds_map: Dict[str, _MixedDatasetConfig] = {}
nw_path_set: Set[str] = set()
# load new config
nw_config = cfg_mgr.get_config()
# build path -> dataset map
for c in config:
path_ds_map[_dataset_identity(c)] = c
# add new datasets
for c in nw_config:
if _dataset_identity(c) in path_ds_map:
# update values only
if "weight" in c:
path_ds_map[_dataset_identity(c)]["weight"] = c["weight"]
if "transform" in c:
if isinstance(c["transforms"], str):
path_ds_map[_dataset_identity(c)]["transforms"] = _convert_to_abs_path(c["transforms"])
else:
path_ds_map[_dataset_identity(c)]["transforms"] = c["transforms"]
if "incontext_weight" in c:
path_ds_map[_dataset_identity(c)]["incontext_weight"] = c["incontext_weight"]
else:
# new dataset
ds = DistributedDataset(
_convert_to_abs_path(c["path"]),
rank,
world_size,
max_repeat_times=max_repeat_times,
)
c["lines"] = ds._nlines
c["dataset"] = ds
if "weight" not in c:
c["weight"] = 1.0
if "transforms" not in c:
c["transforms"] = []
elif isinstance(c["transforms"], str):
c["transforms"] = _convert_to_abs_path(c["transforms"])
if "incontext_weight" not in c:
c["incontext_weight"] = [1.0]
path_ds_map[_dataset_identity(c)] = c
nw_path_set.add(_dataset_identity(c))
# remove unused datasets
for c in config:
if _dataset_identity(c) not in nw_path_set:
del path_ds_map[_dataset_identity(c)]
config: List[_MixedDatasetConfig] = []
for c in nw_config:
config.append(path_ds_map[_dataset_identity(c)])
del path_ds_map
del nw_path_set
del nw_config
weights = _build_sample_weights(config)
# get cmds
while True:
try:
cmd = q_cmd.get_nowait()
except Empty:
break
if cmd == "stop":
should_stop = True
q_cmd_out.put(True)
break
elif cmd == "state_dict":
ret = OrderedDict()
for c in config:
ds_name = _dataset_identity(c)
ret[ds_name] = c["dataset"]._state_dict()
q_cmd_out.put(ret)
elif cmd == "load_state_dict":
state_dict = q_cmd.get()
missing = []
for idx, c in enumerate(config):
ds_name = _dataset_identity(c)
print(f"loading {idx}/{len(config)} {ds_name}")
if ds_name in state_dict:
c["dataset"].load_state_dict(state_dict[ds_name], strict=False)
else:
# new dataset
missing.append(ds_name)
q_cmd_out.put(missing)
elif cmd == "start":
should_start = True
q_cmd_out.put(True)
else:
raise RuntimeError("Unknown command: {}".format(cmd))
if should_stop:
break
if not should_start:
# wait for start cmd
time.sleep(1)
continue
if len(config) == 0:
# no dataset available
time.sleep(1)
continue
if q_data.full():
# queue full
time.sleep(1)
continue
# sample a dataset
ds_id: int = 0
while True:
ds_id = np.random.choice(weights.shape[0], p=weights)
if config[ds_id]["dataset"]._nlines != config[ds_id]["lines"]:
# dataset size changed
for c in config:
c["lines"] = c["dataset"]._nlines
weights = _build_sample_weights(config)
continue
else:
break
batch = packer.add_data(config[ds_id])
if batch is not None:
# new batch comming
q_data.put(batch)
# clean queue
while True:
try:
q_data.get_nowait()
except Empty:
break
class MixedDataset:
def __init__(
self,
config_path: str,
batch_size: int,
max_length: int,
tokenizer: CPM9GTokenizer,
unpad: bool = False,
max_repeat_times: int = None,
) -> None:
self._q_cmd = multiprocessing.Queue()
self._q_cmd_out = multiprocessing.Queue()
self._q_data = multiprocessing.Queue(maxsize=2)
self._packer = _MixedDatasetBatchPacker(batch_size, max_length, tokenizer, unpad)
self._p = multiprocessing.Process(
target=_mixed_dataset_process,
args=(
config_path,
self._q_cmd,
self._q_cmd_out,
self._q_data,
bmt.rank(),
bmt.world_size(),
self._packer,
max_repeat_times,
random.getstate(),
),
)
self._p.start()
self._closed = False
def close(self):
if not self._closed:
self._closed = True
self._q_cmd.put("stop")
assert self._q_cmd_out.get(), "Failed to stop process"
self._p.join()
@property
def closed(self):
return self._closed
def start(self):
self._q_cmd.put("start")
return self._q_cmd_out.get()
def state_dict(self):
self._q_cmd.put("state_dict")
states = self._q_cmd_out.get()
if not isinstance(states, OrderedDict):
raise RuntimeError("Invalid state dict {}".format(states))
if bmt.world_size() == 1:
for val in states.values():
val["states"].unsqueeze_(0)
val["block"].unsqueeze_(0)
return states
ret = OrderedDict()
for k, v in states.items():
num_unused_block = v["states"].size(0)
gpu_num_unused_block = torch.tensor([num_unused_block], dtype=torch.long).cuda()
max_unused_blocks = bmt.distributed.all_reduce(gpu_num_unused_block, op="max").cpu().item()
if max_unused_blocks == 0:
max_unused_blocks = 1
gpu_states = torch.full((max_unused_blocks,), -1, dtype=torch.long).cuda()
gpu_states[:num_unused_block] = v["states"].cuda()
gpu_block = v["block"].cuda()
global_states = bmt.distributed.all_gather(gpu_states).cpu() # (world_size, max_unused_blocks)
global_block = bmt.distributed.all_gather(gpu_block).cpu() # (world_size, 4)
ret[k] = {"states": global_states, "block": global_block}
return ret
def load_state_dict(self, data: OrderedDict, strict: bool = False):
self._q_cmd.put("load_state_dict")
self._q_cmd.put(data)
missing = self._q_cmd_out.get()
if strict:
if len(missing) > 0:
raise RuntimeError("Missing dataset state: {}".format(missing))
return missing
def get(self) -> CPM9GBatch:
ret: CPM9GBatch = self._q_data.get(timeout=300) # type: ignore
if not isinstance(ret, dict):
raise RuntimeError("Invalid data {}".format(ret))
return ret
def __iter__(self):
while True:
yield self.get()
def __del__(self):
if not self.closed:
try:
self.close()
except Exception:
pass

File diff suppressed because it is too large Load Diff

View File

@ -1,57 +0,0 @@
import itertools
import os
import random
from typing import List
import bmtrain as bmt
import torch
class ListDataset(torch.utils.data.Dataset):
"""
同时支持 map-style iterable-style
"""
def __init__(
self, data_list: List, distributed: bool = False, shuffle: bool = True, infinite: bool = False
) -> None:
super(ListDataset, self).__init__()
if distributed:
rank = bmt.rank()
world_size = bmt.world_size()
self.data_list = list(itertools.islice(data_list, rank, None, world_size))
else:
self.data_list = data_list
self.shuffle = shuffle
self.infinite = infinite
self.idx = 0
if shuffle:
self._shuffle()
def __iter__(self):
return self
def __next__(self):
if self.idx >= len(self):
if self.infinite:
if self.shuffle:
self._shuffle()
self.idx = 0
else:
raise StopIteration
data = self.data_list[self.idx]
self.idx += 1
return data
def __getitem__(self, idx):
return self.data_list[idx]
def __len__(self):
return len(self.data_list)
def _shuffle(self):
random.shuffle(self.data_list)
def read(self):
return self.__next__()

View File

@ -1,4 +0,0 @@
from .generation_utils import apply_repetition_penalty
from .generation_utils import BeamHypotheses
from .generation_utils import NBCE
from .generation_utils import top_k_top_p_filtering

View File

@ -1,127 +0,0 @@
import torch
import torch.nn.functional as F
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
# This function has been mostly taken from huggingface conversational ai code at
# https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
if top_k > 0:
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
batch_size = logits.size()[0]
if top_p > 0.0:
logits = logits.view(batch_size, -1).contiguous()
for index in range(len(logits)):
sorted_logits, sorted_indices = torch.sort(logits[index].view(-1), descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[index][indices_to_remove] = filter_value
logits = logits.view(batch_size, -1).contiguous()
return logits
def apply_repetition_penalty(
logits,
batch_size,
num_beams,
prev_output_tokens,
repetition_penalty,
start_idx=None,
end_idx=None,
window_size=None,
):
# only conduct repetition penalty for the output
assert repetition_penalty >= 1, "repetition penalty coefficient should >= 1"
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
for i in range(batch_size * num_beams):
if start_idx is None or end_idx is None:
output_tokens = prev_output_tokens[i].tolist()
else:
if end_idx >= start_idx:
if window_size:
output_tokens = prev_output_tokens[i][
max(start_idx, end_idx + 1 - window_size) : end_idx + 1
].tolist()
else:
output_tokens = prev_output_tokens[i][start_idx : end_idx + 1].tolist()
else:
output_tokens = []
for previous_token in set(output_tokens):
# if score < 0 then repetition penalty has to
# multiplied to reduce the previous token probability
if logits[i, previous_token] < 0:
logits[i, previous_token] *= repetition_penalty
else:
logits[i, previous_token] /= repetition_penalty
class BeamHypotheses:
def __init__(self, n_hyp, max_len, length_penalty, early_stopping):
"""
Initialize n-best list of hypotheses.
"""
self.max_len = max_len
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.n_hyp = n_hyp
self.hyp = []
self.worst_score = 1e9
def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.hyp)
def add(self, hyp, sum_logprobs):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / len(hyp) ** self.length_penalty
if len(self) < self.n_hyp or score > self.worst_score:
self.hyp.append((score, hyp))
if len(self) > self.n_hyp:
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
del self.hyp[sorted_scores[0][1]]
self.worst_score = sorted_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs, cur_len):
"""
If there are enough hypotheses and that none of the hypotheses being generated
can become better than the worst one in the heap, then we are done with this sentence.
"""
if len(self) < self.n_hyp:
return False
elif self.early_stopping:
return True
else:
return self.worst_score >= best_sum_logprobs / cur_len**self.length_penalty
def NBCE(logits):
"""
Naive Bayes-based Context Extension
blog: https://www.kexue.fm/archives/9617
"""
beta = 0.25
logits = logits[:, -1] # bsh -> bh
logits = logits - logits.logsumexp(dim=-1, keepdims=True)
k = (logits.exp() * logits).sum(dim=-1)[1:].argmax() + 1
logits_max = logits[k]
logits_uncond = logits[0]
logits = (1 + beta) * logits_max - beta * logits_uncond
return logits

View File

@ -1,12 +0,0 @@
from .attention import Attention
from .blocks import TransformerBlock
from .embedding import Embedding
from .embedding import EmbeddingExt
from .feedforward import FeedForward
from .layernorm import LayerNorm
from .linear import Linear
from .position_embedding import BucketPositionBias
from .position_embedding import RotaryEmbedding
from .position_embedding import RotaryEmbeddingESM
from .position_embedding import SegmentPositionEmbedding
from .transformer import Encoder

View File

@ -1,134 +0,0 @@
import math
from typing import Optional
from typing import Tuple
import torch
from einops import rearrange
from .linear import Linear
class Attention(torch.nn.Module):
def __init__(
self,
dim_model: int,
num_heads: int,
num_kv_heads: int,
dim_head: int,
dtype: torch.dtype = torch.half,
dropout_p: Optional[float] = None,
use_flash_attn: bool = False,
scale: bool = True,
) -> None:
super().__init__()
self.dim_model = dim_model
self.num_heads = num_heads
self.dim_head = dim_head
self.num_kv_heads = num_kv_heads
self.head_groups = num_heads // num_kv_heads
self.project_q = Linear(self.dim_model, self.num_heads * self.dim_head, dtype=dtype, scale=scale)
self.project_k = Linear(self.dim_model, self.num_kv_heads * self.dim_head, dtype=dtype, scale=scale)
self.project_v = Linear(self.dim_model, self.num_kv_heads * self.dim_head, dtype=dtype, scale=scale)
self.attention_out = Linear(self.num_heads * self.dim_head, self.dim_model, dtype=dtype, scale=scale)
self.softmax = torch.nn.Softmax(dim=-1)
if dropout_p is not None:
self.dropout = torch.nn.Dropout(p=dropout_p)
else:
self.dropout = None
# if use_flash_attn:
# self.core_attention_flash = FlashSelfAttention(causal=False, attention_dropout=0.0)
# self.use_flash_attn = use_flash_attn
def forward(
self,
hidden_q: torch.Tensor,
hidden_kv: torch.Tensor,
attention_mask: torch.BoolTensor,
position_bias: torch.Tensor,
use_cache: bool = False,
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
pos_bias_type: Optional[str] = "relative",
length_mask: Optional[torch.Tensor] = None,
context_mask: Optional[torch.Tensor] = None,
):
"""
Args:
hidden_q (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): Indices of input sequence tokens. It will be embedded by model's internal embedding lookup matrix.
hidden_kv (:obj:`torch.Tensor` of shape ``(batch, len_k, dim_model)``): Length of input sequence before padding.
attention_mask (:obj:`torch.Tensor` of shape ``(batch, len_q, len_k)``): Used to avoid performing attention on padding token indices.
position_bias(:obj:`torch.Tensor` of shape ``(num_heads, len_q, len_k)`` or ``(1, num_heads, len_k, len_q)``): Provide positional information about tensor `key_value` and `query`.
Return:
out (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): The attention output.
""" # noqa: E501
batch_size = hidden_q.size(0)
len_q = hidden_q.size(1)
len_k = hidden_kv.size(1)
h_q = self.project_q(hidden_q)
h_k = self.project_k(hidden_kv)
h_v = self.project_v(hidden_kv)
h_q = h_q / math.sqrt(math.sqrt(self.dim_head))
h_k = h_k / math.sqrt(math.sqrt(self.dim_head))
h_q = h_q.view(batch_size, len_q, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
h_k = h_k.view(batch_size, len_k, self.num_kv_heads, self.dim_head).permute(0, 2, 1, 3)
h_v = h_v.view(batch_size, len_k, self.num_kv_heads, self.dim_head).permute(0, 2, 1, 3)
if pos_bias_type == "rotary":
# b h s d
h_q, h_k = position_bias(h_q, h_k, -2, offset=past_kv[0].size(-2) if past_kv is not None else 0)
if past_kv is not None:
h_k = torch.cat([past_kv[0], h_k], dim=-2)
h_v = torch.cat([past_kv[1], h_v], dim=-2)
len_k = h_k.size(-2)
# (b, n_h, len_q, d_h) @ (b, n_h, d_h, len_k) -> (b, n_h, len_q, len_k)
if self.head_groups == 1:
score = torch.matmul(h_q, h_k.transpose(-1, -2)) # / math.sqrt(self.dim_head) moved to line 75~76
else:
score = torch.matmul(
h_q.reshape(batch_size, self.num_kv_heads, self.head_groups * len_q, self.dim_head),
h_k.transpose(-1, -2),
).view(batch_size, self.num_heads, len_q, len_k)
if pos_bias_type == "relative":
score = score + position_bias
score = torch.masked_fill(
score,
attention_mask.view(batch_size, 1, len_q, len_k) == False,
torch.scalar_tensor(float("-inf"), device=score.device, dtype=score.dtype),
)
score = self.softmax(score)
score = torch.masked_fill(
score,
attention_mask.view(batch_size, 1, len_q, len_k) == False,
torch.scalar_tensor(0, device=score.device, dtype=score.dtype),
)
if self.dropout is not None:
score = self.dropout(score)
# (b, n_kv_h, n_h_groups*len_q, len_k) @ (b, n_kv_h, len_k, d_h) -> (b, n_kv_h, n_h_groups*len_q, d_h) -> (b, n_h, len_q, d_h)
score = torch.matmul(score.view(batch_size, self.num_kv_heads, self.head_groups * len_q, len_k), h_v).view(
batch_size, self.num_heads, len_q, self.dim_head
)
score = score.view(batch_size, self.num_heads, len_q, self.dim_head).permute(0, 2, 1, 3)
score = score.contiguous().view(batch_size, len_q, self.num_heads * self.dim_head)
score = self.attention_out(score)
if use_cache:
return score, (h_k, h_v)
else:
return score

View File

@ -1,279 +0,0 @@
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from .attention import Attention
from .feedforward import FeedForward
from .layernorm import LayerNorm
from .position_embedding import RotaryEmbedding
from .position_embedding import RotaryEmbeddingESM
class SelfAttentionBlock(torch.nn.Module):
"""The whole cross-attention block. A sequence of operation. Consists of layernorm, self-attention and residual connection.
Args:
dim_model (int): main dimension of modules in transformer blocks.
num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
dtype (optional): Defaults to torch.half.
eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5.
dropout_p (float, optional): Defaults to 0.
""" # noqa: E501
def __init__(
self,
dim_model: int,
num_heads: int,
dim_head: int,
num_kv_heads: int,
dtype=torch.half,
eps: float = 1e-6,
dropout_p: Optional[float] = None,
scale: bool = True,
use_flash_attn: bool = False,
):
super().__init__()
self.layernorm_before_attention = LayerNorm(
dim_model,
dtype=dtype,
eps=eps,
)
self.self_attention = Attention(
dim_model=dim_model,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
dim_head=dim_head,
dtype=dtype,
dropout_p=dropout_p,
scale=scale,
use_flash_attn=use_flash_attn,
)
if dropout_p:
self.dropout = torch.nn.Dropout(dropout_p)
else:
self.dropout = None
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_bias: Union[torch.Tensor, RotaryEmbedding, RotaryEmbeddingESM] = None,
use_cache: bool = False,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
pos_bias_type: Optional[str] = "relative",
length_mask: Optional[torch.Tensor] = None,
context_mask: Optional[torch.Tensor] = None,
):
"""
Args:
hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Input of self-attention block. It can be the embedding of a batch of sequences.
attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_self, seq_self)``): Avoid invalid areas to participate in the calculation.
position_bias (:obj:`torch.Tensor` of shape ``(num_heads, seq_self, seq_self)``): Provide positional information to self-attention block.
Return:
:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of attention block.
""" # noqa: E501
x = self.layernorm_before_attention(hidden_states)
x = self.self_attention(
x,
x,
attention_mask,
position_bias,
use_cache,
past_key_value,
pos_bias_type=pos_bias_type,
length_mask=length_mask,
context_mask=context_mask,
)
if use_cache:
x, current_key_value = x
else:
current_key_value = None
if self.dropout is not None:
x = self.dropout(x)
hidden_states = hidden_states + x
if use_cache:
return hidden_states, current_key_value
else:
return hidden_states
class FFNBlock(torch.nn.Module):
"""The whole feed-forward block. A sequence of operation. Consists of layernorm, feed-forward and residual connection.
Args:
dim_model (int): main dimension of modules in transformer blocks.
dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
dtype (optional): Defaults to torch.half.
eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5.
dropout_p (float, optional): Defaults to 0.
""" # noqa: E501
def __init__(
self,
dim_model: int,
dim_ff: int,
activate_fn: str,
dtype=torch.half,
eps: float = 1e-6,
dropout_p: Optional[float] = 0,
scale: bool = True,
):
super().__init__()
self.layernorm_before_ffn = LayerNorm(
dim_model,
dtype=dtype,
eps=eps,
)
self.ffn = FeedForward(
dim_model,
dim_ff,
activate_fn=activate_fn,
dtype=dtype,
dropout_p=dropout_p,
scale=scale,
)
if dropout_p:
self.dropout = torch.nn.Dropout(dropout_p)
else:
self.dropout = None
def forward(
self,
hidden_states: torch.Tensor,
):
"""
Args:
hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Hidden states before feed forward layer.
Return:
:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of feed-forward block
""" # noqa: E501
x = self.layernorm_before_ffn(hidden_states)
x = self.ffn(x)
if self.dropout is not None:
x = self.dropout(x)
hidden_states = hidden_states + x
return hidden_states
class TransformerBlock(torch.nn.Module):
"""The whole transformer block. A sequence of operation. Consists of self-attention block[, cross-attention block] and feed-forward block.
Args:
dim_model (int): main dimension of modules in transformer blocks.
dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
dtype (optional): Defaults to torch.half.
eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5.
dropout_p (float, optional): Defaults to 0.
""" # noqa: E501
def __init__(
self,
dim_model: int,
dim_ff: int,
num_heads: int,
num_kv_heads: int,
dim_head: int,
activate_fn: str = "gelu",
dtype=torch.half,
eps: float = 1e-6,
dropout_p: Optional[float] = None,
scale: bool = True,
mask_att: bool = False,
mask_ffn: bool = False,
use_flash_attn: bool = False,
):
super().__init__()
self.mask_att = mask_att
self.mask_ffn = mask_ffn
if not self.mask_att:
self.self_att = SelfAttentionBlock(
dim_model=dim_model,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
dim_head=dim_head,
dtype=dtype,
eps=eps,
dropout_p=dropout_p,
scale=scale,
use_flash_attn=use_flash_attn,
)
if not self.mask_ffn:
self.ffn = FFNBlock(
dim_model=dim_model,
dim_ff=dim_ff,
activate_fn=activate_fn,
dtype=dtype,
eps=eps,
dropout_p=dropout_p,
scale=scale,
)
def forward(
self,
self_hidden_states: torch.Tensor,
self_attention_mask: torch.Tensor,
self_position_bias: Optional[torch.Tensor] = None,
use_cache: bool = False,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
pos_bias_type: Optional[str] = "relative",
length_mask: Optional[torch.Tensor] = None,
context_mask: Optional[torch.Tensor] = None,
):
"""
Args:
self_hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Input of transformer block(self-attention block). It can be the raw embedding of a batch of sequences.
self_attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_self, seq_self)``): Avoid invalid areas to participate in the calculation of self-attention.
self_position_bias (:obj:`torch.Tensor` of shape ``(num_heads, seq_self, seq_self)``): Provide positional information to self-attention block.
Return:
:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of transformer block.
""" # noqa: E501
# (batch, dim_model, seq_self)
current_key_value = None
if not self.mask_att:
hidden_states = self.self_att(
self_hidden_states,
attention_mask=self_attention_mask,
position_bias=self_position_bias,
use_cache=use_cache,
past_key_value=past_key_value,
pos_bias_type=pos_bias_type,
length_mask=length_mask,
context_mask=context_mask,
)
if use_cache:
hidden_states, current_key_value = hidden_states
else:
hidden_states = self_hidden_states
# (batch, dim_model, seq_self)
if not self.mask_ffn:
hidden_states = self.ffn(hidden_states)
if use_cache:
return hidden_states, current_key_value
else:
return hidden_states

View File

@ -1,100 +0,0 @@
import math
from typing import Optional
import torch
import torch.nn.functional as F
from .position_embedding import RotaryEmbedding
class Embedding(torch.nn.Module):
def __init__(
self,
vocab_size: int,
embedding_size: int,
dtype: torch.dtype = torch.half,
scale: bool = True,
init_mean: float = 0.0,
init_std: float = 1,
):
super().__init__()
self.dim_model = embedding_size
self.weight = torch.nn.parameter.Parameter(torch.empty(vocab_size, embedding_size, dtype=dtype))
self.scale = scale
def forward(self, ids: torch.Tensor):
"""
Args:
ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens.
Return:
:obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output.
""" # noqa: E501
if self.scale:
embeds = F.embedding(ids, self.weight) / math.sqrt(self.dim_model)
else:
embeds = F.embedding(ids, self.weight)
return embeds.clone()
def projection(self, x: torch.Tensor):
"""
Projection based on embedding's weight. For example, embedding map vocab_size to embed_size, than projection map embed_size back to vocab_size.
Args:
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection
Returns:
:obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output.
""" # noqa: E501
if self.scale:
logits = F.linear(x / math.sqrt(self.dim_model), self.weight)
else:
logits = F.linear(x, self.weight)
return logits
class EmbeddingExt(torch.nn.Module):
def __init__(
self,
vocab_size: int,
embedding_size: int,
dtype: torch.dtype = torch.half,
init_mean: float = 0.0,
init_std: float = 1,
distance_scale: int = 16,
):
super().__init__()
self.dim_model = embedding_size
self.rotary_emb = RotaryEmbedding(dim=embedding_size, distance_scale=distance_scale, dtype=dtype)
self.weight = torch.nn.parameter.Parameter(
torch.empty(vocab_size, embedding_size, dtype=dtype),
)
def forward(self, ids: torch.Tensor, ids_sub: torch.Tensor):
"""
Args:
ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens.
ids (:obj:`torch.Tensor` of shape ``(batch_size)``): Subscript of input sequence tokens.
Return:
:obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output.
""" # noqa: E501
embeds = F.embedding(ids, self.weight) / math.sqrt(self.dim_model)
return self.rotary_emb(embeds, ids_sub)
def projection(self, x: torch.Tensor, ext_table: Optional[torch.Tensor] = None):
"""
Projection based on embedding's weight. For example, embedding map vocab_size to embed_size, than projection map embed_size back to vocab_size.
Args:
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection
ext_table (:obj:`torch.Tensor` of shape ``(ext_table_size, dim_model)``): Ext vocab table.
Returns:
:obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_size + ext_table_size)``: The projection output.
""" # noqa: E501
logits = F.linear(x / math.sqrt(self.dim_model), self.weight)
if ext_table is not None:
logits_ext = F.linear(x, ext_table)
logits = torch.cat([logits, logits_ext], dim=-1)
return logits

View File

@ -1,120 +0,0 @@
from typing import Optional
import torch
from .linear import Linear
class DenseGatedACT(torch.nn.Module):
def __init__(
self,
dim_in: int,
dim_ff: int,
dtype=torch.half,
activate_fn: str = "gelu",
scale: bool = True,
):
super().__init__()
self.w_0 = Linear(
dim_in=dim_in,
dim_out=dim_ff,
dtype=dtype,
scale=scale,
scale_before=False,
)
self.w_1 = Linear(
dim_in=dim_in,
dim_out=dim_ff,
dtype=dtype,
scale=scale,
scale_before=False,
)
if activate_fn == "gelu":
self.act = torch.nn.GELU()
elif activate_fn == "silu":
self.act = torch.nn.functional.silu
else:
raise NotImplementedError(f"{activate_fn} is not supported")
def forward(self, x: torch.Tensor):
"""Transform an input tensor from one feature space to another via a nonlinear operation
Args:
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): Tensor that will be subject to nonlinear operations.
Return:
out (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_ff)``)
""" # noqa: E501
gate_score = self.act(self.w_0(x))
x = self.w_1(x)
x = gate_score * x
return x
class FeedForward(torch.nn.Module):
r"""FeedForward module
Args:
dim_in (int): input dimension.
dim_ff (int): middle dimension.
dim_out (int, optional): output dimension. Defaults to None, which means dim_in = dim_out.
dtype (optional): Defaults to torch.half.
init_mean (float, optional): mean of :math:`\mathbf{W}\sim\mathcal{N}(\text{mean}, \text{std}^2)` for fully-connected module used in feed-forward layer. Defaults to 0.
init_std (float, optional): std of :math:`\mathbf{W}\sim\mathcal{N}(\text{mean}, \text{std}^2)` for fully-connected module used in feed-forward layer. Defaults to 0.02.
bias (bool, optional): whether to use bias term in fully-connected layers used in feed-forward module. Defaults to False.
activate_fn (str, optional): Defaults to `gated_gelu`.
dropout_p (int, optional): Defaults to 0.
""" # noqa: E501
def __init__(
self,
dim_model: int,
dim_ff: int,
activate_fn: str = "gelu",
dtype=torch.half,
dropout_p: Optional[float] = None,
scale: bool = True,
):
super().__init__()
self.w_in = DenseGatedACT(
dim_in=dim_model,
dim_ff=dim_ff,
activate_fn=activate_fn,
dtype=dtype,
scale=scale,
)
if dropout_p is not None:
self.dropout = torch.nn.Dropout(dropout_p)
else:
self.dropout = None
self.w_out = Linear(
dim_in=dim_ff,
dim_out=dim_model,
dtype=dtype,
scale=scale,
scale_before=False,
)
def forward(self, x: torch.Tensor):
"""
Args:
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of feed-forward module.
Return:
:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of feed-forward module.
""" # noqa: E501
x = self.w_in(x)
if self.dropout is not None:
x = self.dropout(x)
x = self.w_out(x)
return x

View File

@ -1,37 +0,0 @@
import torch
@torch.jit.script
def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
old_dtype = hidden.dtype
variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
hidden = (hidden * torch.rsqrt(variance + eps)).to(old_dtype)
return hidden * weight
class LayerNorm(torch.nn.Module):
"""RMS LayerNorm"""
def __init__(
self,
dim_norm: int,
dtype: torch.dtype = torch.half,
eps: float = 1e-6,
init_var: float = 1.0,
):
super().__init__()
self.eps = eps
self.dim_norm = dim_norm
self.weight = torch.nn.parameter.Parameter(torch.full((dim_norm,), init_var, dtype=dtype))
def forward(self, x: torch.Tensor):
"""
Args:
x (:obj:`torch.Tensor` of shape ``(batch_size, seq_len, dim_norm)``): Input tensor that need to be normalized.
Return:
:obj:`torch.Tensor` of shape ``(batch_size, seq_len, dim_norm)``: The layernorm output.
"""
assert x.size(-1) == self.dim_norm
return rms_layernorm(x, self.weight, self.eps)

View File

@ -1,44 +0,0 @@
import math
import torch
import torch.nn.functional as F
class Linear(torch.nn.Module):
def __init__(
self,
dim_in: int,
dim_out: int,
dtype: torch.dtype = torch.half,
init_mean: float = 0.0,
init_std: float = 1,
scale: bool = True,
scale_before: bool = False,
):
super().__init__()
self.dim_in = self.in_features = dim_in
self.dim_out = self.out_features = dim_out
self.scale = scale
self.scale_before = scale_before
self.weight = torch.nn.parameter.Parameter(torch.empty((dim_out, dim_in), dtype=dtype))
torch.nn.init.normal_(self.weight, mean=init_mean, std=init_std)
def forward(self, x: torch.Tensor):
"""
Args:
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer
Returns:
:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y.
""" # noqa: E501
if self.scale:
if self.scale_before:
x = x / math.sqrt(self.dim_in)
x = F.linear(x, self.weight)
else:
x = F.linear(x, self.weight)
x = x / math.sqrt(self.dim_in)
else:
x = F.linear(x, self.weight)
return x

View File

@ -1,286 +0,0 @@
import math
from typing import Tuple
from typing import Union
import torch
import torch.nn.functional as F
class SegmentPositionEmbedding(torch.nn.Module):
def __init__(
self,
num_heads: int,
num_segments: int = 1,
num_buckets: int = 32,
max_distance: int = 128,
bidirectional: bool = False,
dtype: torch.dtype = torch.half,
init_mean: float = 0.0,
init_std: float = 1,
):
super().__init__()
self.num_heads = num_heads
self.num_buckets = num_buckets
self.max_distance = max_distance
self.bidirectional = bidirectional
self.num_segments = num_segments
self.relative_attention_bias = torch.nn.parameter.Parameter(
torch.empty(num_segments * num_segments + num_buckets, num_heads, dtype=dtype)
)
def forward(
self,
key_pos: torch.Tensor,
query_pos: torch.Tensor,
key_segment: torch.Tensor,
query_segment: torch.Tensor,
):
with torch.no_grad():
batch = key_pos.size(0)
keylen = key_pos.size(1)
querylen = query_pos.size(1)
assert key_pos.size(0) == query_pos.size(0)
assert keylen == key_segment.size(1) and querylen == query_segment.size(1)
key_pos = key_pos.view(batch, -1, keylen)
query_pos = query_pos.view(batch, querylen, -1)
key_segment = key_segment.view(batch, -1, keylen)
query_segment = query_segment.view(batch, querylen, -1)
relative_position_bucket = self._segment_relative_position_bucket(query_segment, key_segment)
relative_position_bucket = relative_position_bucket + self.num_buckets # 与相对位置编码区间不重叠
# b*q*k
absolute_position_bucket = self._position_bucket(
torch.arange(keylen, dtype=torch.int32, device=relative_position_bucket.device)[None, :]
- torch.arange(querylen, dtype=torch.int32, device=relative_position_bucket.device)[:, None],
bidirectional=self.bidirectional,
num_buckets=self.num_buckets,
max_distance=self.max_distance,
)
relative_position_bucket = torch.where(
(key_segment == query_segment),
absolute_position_bucket[None, :, :],
relative_position_bucket,
)
# (batch, len_q, len_k)
# (batch, len_q, len_k, num_heads)
embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
# (batch, num_heads, len_q, len_k)
embeds = embeds.permute(0, 3, 1, 2).contiguous()
return embeds
def _segment_relative_position_bucket(self, query_segment, key_segment):
return query_segment * self.num_segments + key_segment
def _position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128):
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
relative_position = torch.abs(relative_position)
else:
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
max_exact = num_buckets // 2
is_small = relative_position < max_exact
relative_postion_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.int32)
relative_postion_if_large = torch.min(
relative_postion_if_large,
torch.full_like(relative_postion_if_large, num_buckets - 1),
)
relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_postion_if_large)
return relative_buckets
class BucketPositionBias(torch.nn.Module):
def __init__(
self,
num_heads: int,
num_buckets: int = 32,
num_segment_bucket: int = 32,
max_distance: int = 128,
dtype: torch.dtype = torch.half,
init_mean: float = 0.0,
init_std: float = 1,
) -> None:
super().__init__()
self.num_heads = num_heads
self.num_buckets = num_buckets
self.num_segment_bucket = num_segment_bucket
self.max_distance = max_distance
self.relative_attention_bias = torch.nn.parameter.Parameter(
torch.empty(num_buckets + num_segment_bucket, num_heads, dtype=dtype)
)
def forward(
self,
query_pos: torch.Tensor, # (batch, len_q)
key_pos: torch.Tensor, # (batch, len_k)
rel_buckets: torch.Tensor, # (batch, len_q, len_k)
):
with torch.no_grad():
batch = key_pos.size(0)
keylen = key_pos.size(1)
querylen = query_pos.size(1)
assert key_pos.size(0) == query_pos.size(0)
assert rel_buckets.size(0) == batch and rel_buckets.size(1) == querylen and rel_buckets.size(2) == keylen
relative_position_bucket = rel_buckets - 1 + self.num_buckets # 与相对位置编码区间不重叠
# b*q*k
inner_segment_bucket = self._position_bucket(
key_pos[..., None, :] - query_pos[..., :, None],
num_buckets=self.num_buckets,
max_distance=self.max_distance,
)
relative_position_bucket = torch.where(
rel_buckets == 0,
inner_segment_bucket,
relative_position_bucket,
)
# (batch, len_q, len_k)
# (batch, len_q, len_k, num_heads)
embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
# (batch, num_heads, len_q, len_k)
embeds = embeds.permute(0, 3, 1, 2).contiguous()
return embeds
def _position_bucket(self, relative_position, num_buckets=32, max_distance=128):
relative_buckets = 0
num_buckets //= 2
relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
relative_position = torch.abs(relative_position)
max_exact = num_buckets // 2
is_small = relative_position < max_exact
relative_postion_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.int32)
relative_postion_if_large = torch.min(
relative_postion_if_large,
torch.full_like(relative_postion_if_large, num_buckets - 1),
)
relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_postion_if_large)
return relative_buckets
class RotaryEmbedding(torch.nn.Module):
def __init__(
self,
dim,
base=10000,
distance_scale: Union[int, float] = 1,
dtype: torch.dtype = torch.half,
):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device="cuda", dtype=torch.float32) / dim))
inv_freq = inv_freq.to(dtype)
self.distance_scale = distance_scale
self.dtype = dtype
self.inv_freq = inv_freq
def forward(self, x: torch.Tensor, x_pos: torch.Tensor):
"""
Args:
x (:obj:`torch.Tensor` of shape ``(..., dim)``): Inputs.
x_pos (:obj:`torch.Tensor` of shape ``(...)``): Positions of inputs.
"""
x_pos = x_pos * self.distance_scale
freqs = x_pos[..., None].to(self.dtype) * self.inv_freq[None, :] # (..., dim/2)
# the same implementation as sat
emb = torch.cat((freqs, freqs), dim=-1) # (..., dim)
emb_cos = emb.cos() # (..., dim)
emb_sin = emb.sin() # (..., dim)
rotate_x = torch.cat([-x[..., x.size(-1) // 2 :], x[..., : x.size(-1) // 2]], dim=-1) # (..., dim)
return x * emb_cos + rotate_x * emb_sin
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(x, cos, sin, seq_dim, offset):
if x.size(seq_dim) < cos.size(seq_dim):
cos = cos.narrow(seq_dim, offset, x.size(seq_dim))
sin = sin.narrow(seq_dim, offset, x.size(seq_dim))
return (x * cos) + (rotate_half(x) * sin)
class RotaryEmbeddingESM(torch.nn.Module):
"""
Rotary position embeddings based on those in
[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
matrices which depend on their relative positions.
"""
def __init__(
self,
dim: int,
base: Union[int, float] = 10000,
distance_scale: Union[int, float] = 1,
dtype=torch.half,
persistent=True,
mixed_precision=False,
):
super().__init__()
self.base = base
self.distance_scale = distance_scale
self.dtype = dtype
# Generate and save the inverse frequency buffer (non trainable)
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device="cuda", dtype=torch.float32) / dim))
if mixed_precision:
self.register_buffer("inv_freq", inv_freq, persistent=persistent)
else:
self.register_buffer("inv_freq", inv_freq.to(self.dtype), persistent=persistent)
self._seq_len_cached = -1
self._cos_cached = None
self._sin_cached = None
self.mixed_precision = mixed_precision
self.apply_rotary_pos_emb = apply_rotary_pos_emb
def _update_cos_sin_tables(self, x, seq_dim, offset):
seq_len = x.size(seq_dim) + offset
if seq_len > self._seq_len_cached or self._cos_cached.device != x.device:
self._seq_len_cached = seq_len
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
freqs = torch.outer(t * self.distance_scale, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
for i in range(x.dim() - 1):
if i != seq_dim:
emb = emb.unsqueeze_(i)
if self.mixed_precision:
self._cos_cached = emb.cos().to(self.dtype)
self._sin_cached = emb.sin().to(self.dtype)
else:
self._cos_cached = emb.cos()
self._sin_cached = emb.sin()
return self._cos_cached, self._sin_cached
def forward(self, q: torch.Tensor, k: torch.Tensor, seq_dim, offset=0) -> Tuple[torch.Tensor, torch.Tensor]:
seq_dim = (seq_dim + k.dim()) % k.dim()
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dim, offset)
return (
self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, seq_dim, offset),
self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, seq_dim, offset),
)

View File

@ -1,132 +0,0 @@
from typing import List
from typing import Optional
from typing import Tuple
import torch
from .blocks import TransformerBlock
from .layernorm import LayerNorm
class Encoder(torch.nn.Module):
"""Layers of encoder transformer blocks plus an final layernorm.
Args:
num_layers (int): number of layers.
dim_model (int): main dimension of modules in transformer blocks.
dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
dtype (optional): Defaults to torch.half.
eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
dropout_p (float, optional): Defaults to 0.
"""
def __init__(
self,
num_layers: int,
dim_model: int,
dim_ff: int,
num_heads: int,
dim_head: int,
num_kv_heads: int = -1,
activate_fn: str = "gelu",
dtype: torch.dtype = torch.half,
eps: float = 1e-6,
dropout_p: Optional[float] = None,
scale: bool = True,
mask_modules: Optional[List[Tuple[bool, bool]]] = None,
use_flash_attn: bool = False,
):
super().__init__()
if num_kv_heads == -1:
num_kv_heads = num_heads
self.num_layers = num_layers
if mask_modules is not None:
assert len(mask_modules) == num_layers, "The total number of masks should equal to num_layers"
for mask_module in mask_modules:
assert len(mask_module) == 2, "For encoder, each mask should be (mask_att, mask_ffn)"
else:
mask_modules = [(False, False)] * num_layers
self.layers = torch.nn.ModuleList(
[
TransformerBlock(
dim_model=dim_model,
dim_ff=dim_ff,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
dim_head=dim_head,
activate_fn=activate_fn,
dtype=dtype,
eps=eps,
dropout_p=dropout_p,
scale=scale,
mask_att=mask_modules[ith][0],
mask_ffn=mask_modules[ith][1],
use_flash_attn=use_flash_attn,
)
for ith in range(num_layers)
]
)
self.output_layernorm = LayerNorm(dim_norm=dim_model, dtype=dtype, eps=eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_bias: torch.Tensor,
use_cache: bool = False,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
pos_bias_type: Optional[str] = "relative",
length_mask: Optional[torch.Tensor] = None,
context_mask: Optional[torch.Tensor] = None,
):
"""
Args:
hidden-states (:obj:`torch.Tensor` of shape ``(batch, seq_enc, dim_model)``): Input of encoder, might be the embedding of a batch of sequences.
attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_enc, seq_enc)``): Avoid invalid areas to participate in the calculation
position_bias(:obj:`torch.Tensor` of shape ``(num_heads, seq_enc, seq_enc)``) Provides position information to attention mechanism.
Return:
:obj:`torch.Tensor` of shape ``(batch, seq_enc, dim_model)``: The encoder output.
"""
if not use_cache:
for layer in self.layers:
hidden_states = layer(
hidden_states,
attention_mask,
position_bias,
pos_bias_type=pos_bias_type,
length_mask=length_mask,
context_mask=context_mask,
)
hidden_states = self.output_layernorm(hidden_states)
return hidden_states
else:
with torch.no_grad():
current_key_values = []
current_hidden_states = []
for i, module in enumerate(self.layers):
hidden_states = module(
hidden_states,
attention_mask,
position_bias,
past_key_value=past_key_values[i] if past_key_values else None,
use_cache=use_cache,
pos_bias_type=pos_bias_type,
length_mask=length_mask,
context_mask=context_mask,
)
if use_cache:
current_key_values.append(hidden_states[1])
current_hidden_states.append(hidden_states[0])
hidden_states = hidden_states[0]
hidden_states = self.output_layernorm(hidden_states)
if use_cache:
return hidden_states, current_key_values, current_hidden_states
else:
return hidden_states

View File

@ -1,62 +0,0 @@
import torch
def pad(orig_items, key, padding_value=0, padding_side="left"):
items = []
if isinstance(orig_items[0][key], list):
assert isinstance(orig_items[0][key][0], torch.Tensor)
for it in orig_items:
for tr in it[key]:
items.append({key: tr})
else:
assert isinstance(orig_items[0][key], torch.Tensor)
items = orig_items
batch_size = len(items)
shape = items[0][key].shape
dim = len(shape)
assert dim <= 3
max_length = max(item[key].shape[-1] for item in items)
min_length = min(item[key].shape[-1] for item in items)
dtype = items[0][key].dtype
if dim == 1:
return torch.cat([item[key] for item in items], dim=0)
elif dim == 2:
if max_length == min_length:
return torch.cat([item[key] for item in items], dim=0)
tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
else:
tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
for i, item in enumerate(items):
if dim == 2:
if padding_side == "left":
tensor[i, -len(item[key][0]) :] = item[key][0].clone()
else:
tensor[i, : len(item[key][0])] = item[key][0].clone()
elif dim == 3:
if padding_side == "left":
tensor[i, -len(item[key][0]) :, :] = item[key][0].clone()
else:
tensor[i, : len(item[key][0]), :] = item[key][0].clone()
return tensor
def pad_raw(orig_items, max_length=1024, padding_value=0, padding_side="left"):
max_cols = max(tensor.size(1) for tensor in orig_items)
padded_arrays = []
for tensor in orig_items:
pad_cols = max_cols - tensor.size(1)
if padding_side == "left":
padded_tensor = torch.cat([torch.zeros(tensor.size(0), pad_cols), tensor], dim=1)
elif padding_side == "right":
padded_tensor = torch.cat([tensor, torch.zeros(tensor.size(0), pad_cols)], dim=1)
else:
raise ValueError("Invalid 'side' parameter. Must be 'left' or 'right'.")
padded_arrays.append(padded_tensor)
padded_tensor = torch.cat(padded_arrays, dim=0).to(dtype=torch.int32)
return padded_tensor

View File

@ -1,130 +0,0 @@
import functools
import json
import os
import shutil
import time
from typing import List
import bmtrain as bmt
import torch
from .log import logger
def rename_if_exists(file_path):
if not os.path.exists(file_path):
return
timestamp = time.strftime("%Y%m%d%H%M%S")
file_dir, file_name = os.path.split(file_path)
file_root, file_ext = os.path.splitext(file_name)
new_file_name = f"{file_root}_bak_{timestamp}{file_ext}"
new_file_path = os.path.join(file_dir, new_file_name)
try:
os.rename(file_path, new_file_path)
logger.info(f"File '{file_name}' already exists. Renamed to '{new_file_name}'")
except Exception as e:
logger.warn(
"rename file failed,file_path={file_path}, new_file_path={new_file_path},err={err}".format(
file_path=file_path, new_file_path=new_file_path, err=str(e)
)
)
def rename_if_exists_decorator(func):
@functools.wraps(func)
def wrapper(file_path, *args, **kwargs):
rename_if_exists(file_path)
return func(file_path, *args, **kwargs)
return wrapper
@rename_if_exists_decorator
def bmt_save(file_path: str, model: torch.nn.Module, export_files: List[str] = None):
bmt.save(model, file_path)
if export_files is not None:
export_files.append(file_path)
@rename_if_exists_decorator
def torch_save(file_path: str, obj: object, export_files: List[str] = None):
torch.save(obj, file_path)
if export_files is not None:
export_files.append(file_path)
@rename_if_exists_decorator
def json_save(file_path: str, obj: object, export_files: List[str] = None):
with open(file_path, "w") as data_f:
json.dump(obj, data_f)
if export_files is not None:
export_files.append(file_path)
def export(
model: torch.nn.Module, dataloader, optimizer: bmt.optim.AdamOffloadOptimizer, global_step, args, final_save=False
):
"""
一次 ckpt 保存
/{args.save}/
{save_name}-{global_step}.rank-0.opt
{save_name}-{global_step}.rank-n.opt
job_{job_id}_ckpt_{global_step}/ # checkpoint 导出为模型版本时job_{job_id}_ckpt_{global_step}/ 路径下文件会一起导出,创建一个模型组版本
config.json
vocabs.txt
{args.save_name}-{global_step}.pt
{args.save_name}-{global_step}.data
{args.save_name}-{global_step}.data.json
{args.save_name}-{global_step}.success
"""
export_model_dir = os.path.join(args.save, f"l_{global_step}")
os.makedirs(export_model_dir, exist_ok=True)
base_file_name = f"{args.save_name}-{global_step}" if global_step > -1 else args.save_name
logger.info(f"start to export ckpt, save_dir={export_model_dir}, file prefix={base_file_name}")
export_files = []
# model checkpoint
bmt_save(
file_path=os.path.join(export_model_dir, base_file_name + ".pt"),
model=model,
export_files=export_files,
)
# opt is only used for continual pre-training, not the final opt
if not final_save:
grad_path = os.path.join(
args.save,
args.save_name + ("-%d.rank-%d.opt" % (global_step % (args.save_iters * 5), bmt.rank())),
)
torch.save(optimizer.state_dict(), grad_path)
logger.info(f"Successfully save grad file: {grad_path}")
all_states = dataloader.state_dict()
if bmt.rank() == 0:
# data checkpoint
# rank 0 writes the dataloader state
torch_save(
file_path=os.path.join(export_model_dir, base_file_name + ".data"),
obj=all_states,
export_files=export_files,
)
# data checkpoint json
# rank 0 writes the dataloader state into the json file
data_p_json = {k: v for k, v in all_states.items()}
for k in data_p_json:
data_p_json[k] = {k_of_v: data_p_json[k][k_of_v].tolist() for k_of_v in data_p_json[k]}
json_save(
file_path=os.path.join(export_model_dir, base_file_name + ".data.json"),
obj=data_p_json,
export_files=export_files,
)
# config 和 vocabs 和模型文件一起存储
model_cfg_path = os.path.join(export_model_dir, "config.json")
model_vocab_path = os.path.join(export_model_dir, "vocabs.txt")
export_files.extend([model_cfg_path, model_vocab_path])
shutil.copy(args.model_config, model_cfg_path)
shutil.copy(args.vocab, model_vocab_path)
logger.info(f"Successfully save model files! {export_files}")
del all_states
return export_model_dir

4
FM_9G/apps/__init__.py Normal file
View File

@ -0,0 +1,4 @@
# !/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright @2024, QiYuan Inc

View File

@ -0,0 +1,20 @@
import random
def rand(n: int, r: random.Random):
return int(r.random() * n)
def transform(data, num_sample: int, r: random.Random):
if 'input' in data:
_input = "<用户>"+data['input']+"<AI>"
else:
_input = ""
if 'output' in data:
_output = data['output']
else:
_output = ""
return {"input": _input,
"output": _output,
}

View File

@ -0,0 +1,20 @@
import random
def rand(n: int, r: random.Random):
return int(r.random() * n)
def transform(data, num_sample: int, r: random.Random):
if 'input' in data:
_input = data['input']
else:
_input = ""
if 'output' in data:
_output = data['output']
else:
_output = ""
return {"input": _input,
"output": _output,
}

View File

@ -0,0 +1,134 @@
[
{
"dataset_name": "humanevallike_clean_dedup",
"task_name": "humanevallike_clean_dedup",
"abs_weight": 0.2,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/humanevallike_clean_dedup",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 995339,
"ave_tokens_per_line": 100,
"total_tokens": 0.1
},
{
"dataset_name": "leetcode_pass_code_0125",
"task_name": "leetcode_pass_code_0125",
"abs_weight": 0.006,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/leetcode_pass_code_0125",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 10724,
"ave_tokens_per_line": 200,
"total_tokens": 0.002
},
{
"dataset_name": "logiv2Annotate",
"task_name": "logiv2Annotate",
"abs_weight": 0.004,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logiv2Annotate",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 12566,
"ave_tokens_per_line": 512,
"total_tokens": 0.006
},
{
"dataset_name": "mmlu_enhance",
"task_name": "mmlu_enhance",
"abs_weight": 0.1,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mmlu_enhance",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 169771,
"ave_tokens_per_line": 300,
"total_tokens": 0.05
},
{
"dataset_name": "mtbench_like",
"task_name": "mtbench_like",
"abs_weight": 0.2,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mtbench_like",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 319080,
"ave_tokens_per_line": 500,
"total_tokens": 0.15
},
{
"dataset_name": "ultra_dataset_new",
"task_name": "ultra_dataset_new",
"abs_weight": 2.0,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/ultra_dataset_new",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 385045,
"ave_tokens_per_line": 200.296266559615,
"total_tokens": 2.0
},
{
"dataset_name": "sft_data_zh_wowru",
"task_name": "sft_data_zh_wowru",
"abs_weight": 1.0,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/sft_data_zh_wowru",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 2963260,
"ave_tokens_per_line": 200.296266559615,
"total_tokens": 1
},
{
"dataset_name": "math_data",
"task_name": "math_data",
"abs_weight": 0.003,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/math_data",
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
"allow_repeat": true,
"nlines": 2963260,
"ave_tokens_per_line": 200.296266559615,
"total_tokens": 0.005
},
{
"dataset_name": "t0",
"task_name": "t0",
"abs_weight": 0.1,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/t0",
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
"allow_repeat": true,
"nlines": 1650309,
"ave_tokens_per_line": 500.296266559615,
"total_tokens": 0.82
},
{
"dataset_name": "wikihow",
"task_name": "wikihow",
"abs_weight": 0.1,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/wikihow",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 180128,
"ave_tokens_per_line": 900.296266559615,
"total_tokens": 0.16
},
{
"dataset_name": "reclor",
"task_name": "reclor",
"abs_weight": 0.002,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/reclor",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 4174,
"ave_tokens_per_line": 700.296266559615,
"total_tokens": 0.003
},
{
"dataset_name": "logic_test_lx_0127",
"task_name": "logic_test_lx_0127",
"abs_weight": 0.001,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logic_test_lx_0127",
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
"allow_repeat": true,
"nlines": 2800,
"ave_tokens_per_line": 200.96266559615,
"total_tokens": 0.0004
}
]

View File

@ -0,0 +1,28 @@
{
"vocab_size": 122753,
"dropout_p": 0.0,
"eps": 1e-05,
"half": true,
"half_type": "bf16",
"use_flash_attn": true,
"flash_attn_mask_shape": "2d",
"dim_model": 2304,
"dim_ff": 5760,
"dim_head": 64,
"num_heads": 36,
"num_kv_heads": 36,
"num_layers": 40,
"activate_fn": "silu",
"init_std": 0.10,
"scale": true,
"scale_emb": 12,
"scale_depth": 1.4,
"dim_model_base": 256,
"model_type": "fm9g",
"architectures": [
"FM9GForCausalLM"
],
"qk_norm": false,
"tie_lm_head": true,
"ffn_gated": true
}

View File

@ -0,0 +1,548 @@
# coding=utf-8
# Copyright 2024 QiYuan Inc.
import inspect
import json
import math
import os
import re
import sys
import time
from collections import defaultdict
from itertools import chain
from typing import Any
from typing import Dict
from typing import List
from typing import Union
import bmtrain as bmt
import numpy as np
import torch
from bmtrain import nccl
from bmtrain.global_var import config as bmt_config
sys.path.append("../../")
from fm9g.arguments import get_args
from fm9g.dragonfly.modeling_dragonfly import Dragonfly
from fm9g.dragonfly.modeling_dragonfly import DragonflyConfig
from fm9g.dragonfly.training_tasks.pretrain_indexed import CudaPrefetcher
from fm9g.dragonfly.training_tasks.pretrain_indexed import MixedIndexedDataset
from fm9g.dragonfly.training_tasks.pretrain_indexed import UnpadBatchedMixedDataset
from fm9g.utils import exporter
from fm9g.utils import logger
from fm9g.utils.exporter import save_every_step_stats
from fm9g.utils.training_stats import num_non_embedding_parameters
from fm9g.utils.training_stats import num_parameters
def get_tokenizer(args):
from transformers import LlamaTokenizerFast
tokenizer = LlamaTokenizerFast(vocab_file=args.tokenizer_path)
return tokenizer
def get_model(args):
config = DragonflyConfig.from_json_file(args.model_config)
config.tp = 1 if args.tp_size != 1 else 0 # TODO
config.pose_prob = args.pose_prob
config.pose_scaling_factor = args.pose_scaling_factor
config.rope_scaling_type = args.rope_scaling_type
config.rope_scaling_factor = args.rope_scaling_factor
config.orig_max_length = args.orig_max_length
bmt.print_rank("model config: {}".format(config))
bmt.print_rank("bmt config: {}".format(bmt.config))
model = Dragonfly(config)
if args.load is not None:
bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
exporter.load_model_ckpt(args, model)
else:
bmt.print_rank("args.load is None, start to initialize parameters")
bmt.init_parameters(model)
return model
def get_optimizer(args, model):
scale_lr_group = []
normal_group = []
scale_lr_group_name, normal_group_name = [], []
for n, p in model.named_parameters():
if n.endswith(".weight") and "layernorm" not in n and "embedding" not in n and "lm_head" not in n:
scale_lr_group.append(p)
scale_lr_group_name.append(n)
else:
normal_group.append(p)
normal_group_name.append(n)
bmt.print_rank(scale_lr_group_name, normal_group_name)
param_groups = [
{"params": scale_lr_group, "lr": args.lr / model.config.scale_width},
{"params": normal_group, "lr": args.lr},
]
if args.offload:
optimizer = bmt.optim.AdamOffloadOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
else:
optimizer = bmt.optim.AdamOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
if args.load is not None and args.load_grad:
exporter.load_optimizer_ckpt(args, optimizer)
bmt.print_rank("optimizer is loaded!")
return optimizer
def get_learning_rate_scheduler(args, optimizer):
from fm9g.training_utils.lr_scheduler import Cosine
from fm9g.training_utils.lr_scheduler import WarmupStableDrop
end_iter = args.train_iters
if 0 < args.warmup_iters < 1: # 需要支持按固定比例step用来做warmup的
warmup_iters = int(end_iter * args.warmup_iters)
else:
warmup_iters = int(args.warmup_iters)
if 0 < args.drop_iters < 1: # 需要支持按固定比例step用来做drop的
drop_iters = int(end_iter * args.drop_iters)
else:
drop_iters = int(args.drop_iters)
if args.lr_scheduler == "cosine":
lr_scheduler = Cosine(
optimizer,
start_lr=args.lr,
warmup_iter=warmup_iters,
end_iter=end_iter, # 原来是lr_decay_iter
num_iter=args.start_step,
#lr_end_restart=args.lr_end_restart,
#resume_no_optimze=args.resume_no_optimze,
)
elif args.lr_scheduler == "warmupstabledrop":
lr_scheduler = WarmupStableDrop(
optimizer,
start_lr=args.lr,
warmup_iter=warmup_iters,
end_iter=end_iter, # 原来是lr_decay_iter
drop_iter=drop_iters,
num_iter=args.start_step,
resume_no_optimze=args.resume_no_optimze,
)
return lr_scheduler
def setup_model_and_optimizer(args):
start = time.time()
tokenizer = get_tokenizer(args)
bmt.synchronize()
logger.info("load tokenizer in {:.2f}s".format(time.time() - start))
start = time.time()
model = get_model(args)
logger.info("load model in {:.2f}s".format(time.time() - start))
start = time.time()
optimizer = get_optimizer(args, model)
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
bmt.synchronize()
logger.info("load lr_scheduler in {:.2f}s".format(time.time() - start))
return tokenizer, model, optimizer, lr_scheduler
def resume_training(args):
ckpts = sorted(
[z for z in chain(*[[os.path.join(x[0], y) for y in x[2]] for x in os.walk(args.save)]) if z.endswith(".pt")],
reverse=True,
key=lambda x: (int)(re.search("(\d+).pt", x)[1]),
)
# find newest job
ckpts = sorted(
ckpts,
reverse=True,
key=lambda x: (int)(re.search("job_(\d+)_ckpt", x)[1]),
)
if len(ckpts) > 0:
bmt.print_rank(f"resuming with last checkpoint: {ckpts[0]}")
args.load = ckpts[0]
# by default, do not load grad file
args.load_grad = False
args.start_step = 0
else:
# no ckpts, nothing we can do
os._exit(1)
def initialize():
args = get_args(pretrain=True)
bmt.init_distributed(seed=args.seed, tp_size=args.tp_size)
if args.save is not None:
os.makedirs(args.save, exist_ok=True)
if args.load is not None:
if args.only_load_model == 0:
if args.start_step == 0:
log_ckpt = exporter.load_log_ckpt(args)
if "iteration" in log_ckpt:
args.start_step = log_ckpt["iteration"]
else:
args.start_step = (int)(re.findall("(\d+)", args.load)[-1])
logger.info("Start from step {}".format(args.start_step))
elif args.only_load_model == 1:
logger.info("You load model ckpt, and choose to completely start the 0 step.")
else:
raise NotImplementedError
else:
logger.info("You do not load model")
return args
def see_memory(detail=False):
if detail:
res = torch.cuda.memory_summary()
else:
res = (
round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2),
round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024), 2),
round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024), 2),
)
torch.cuda.reset_peak_memory_stats()
return res
def add_mem_time(info, mem_usage, tim_usage):
torch.cuda.synchronize()
bmt.synchronize()
mem_usage[info] = see_memory()
tim_usage[info] = time.time()
return mem_usage, tim_usage
def get_task_loss_and_token(loss, task_ids, task_num, targets):
# task_ids 可能有-1 来代表无效token
_task_num = task_num + 1
_task_ids = (task_ids.clone() + 1).to(torch.int64) # [batch_size, seq_len]
# gen masks
_task_mask = torch.zeros((_task_num, *_task_ids.shape), device=_task_ids.device)
_task_mask.scatter_(0, _task_ids.unsqueeze(0), 1) # [task_num, batch_size, seq_len]
_loss_mask = torch.ne(targets, -100).to(torch.int32)
_mask = _task_mask * _loss_mask.unsqueeze(0) # [task_num, batch_size, seq_len]
# calc loss and tokens
_task_losses = (loss.unsqueeze(0) * _mask).view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
_task_tokens = _mask.view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
# return token-wise avg losses and tokens
return torch.nan_to_num(_task_losses / _task_tokens, nan=0.0), _task_tokens
class ChunkAve:
def __init__(self, chunk_size=100):
self.ave_list = []
self.chunk_size = chunk_size
def record(self, time):
self.ave_list.append(time)
self.ave_list = self.ave_list[-self.chunk_size :]
def get(self):
return sum(self.ave_list) / len(self.ave_list)
def pretrain(
args,
tokenizer,
model: Dragonfly,
optimizer,
lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
):
ave_model_time = ChunkAve(chunk_size=100)
ave_iter_time = ChunkAve(chunk_size=100)
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, reduction="none")
optim_manager = bmt.optim.OptimManager(
loss_scale=None,
loss_scale_steps=args.loss_scale_steps,
loss_scale_factor=2,
max_loss_scale=args.max_loss_scale,
min_loss_scale=args.min_loss_scale,
)
optim_manager.add_optimizer(optimizer, lr_scheduler)
start_step = args.start_step
if args.tensorboard is not None and bmt.rank() == 0:
import distutils.version # noqa: F401
from tensorboardX import SummaryWriter
if not os.path.exists(args.tensorboard):
os.makedirs(args.tensorboard)
writer = SummaryWriter(log_dir=args.tensorboard)
if args.load is not None:
log_ckpt = exporter.load_log_ckpt(args)
else:
log_ckpt = {}
global_token_pass = log_ckpt.get("global_token_pass", 0.0)
global_total_task_token = defaultdict(int, log_ckpt.get("global_total_task_token", {})) # token by task
global_world_size = bmt.world_size()
bmt.print_rank("Begin preparing dataset")
if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
mixed_indexed_dataset = MixedIndexedDataset(
cfg_path=args.dataset,
cfg_json_str=None,
tokenizer=tokenizer,
max_length=args.max_length,
nthreads=args.dataloader_num_threads,
prefetch_slice=args.dataloader_prefetch,
weight_by_size=True,
)
if args.load is not None and args.only_load_model == 0 and args.load_dataloader_ckpt == 1:
exporter.load_dataloader_ckpt(args, mixed_indexed_dataset)
batched_dataset = UnpadBatchedMixedDataset(mixed_indexed_dataset, args.batch_size, args.max_length)
dataloader = torch.utils.data.DataLoader(
batched_dataset,
batch_size=None,
collate_fn=lambda x: x,
num_workers=args.dataloader_num_workers,
prefetch_factor=args.dataloader_prefetch_factor,
)
else:
def dummy_generator():
while True:
yield None
mixed_indexed_dataset = dummy_generator()
dataloader = mixed_indexed_dataset
DataIterator = CudaPrefetcher(dataloader, tp_size=args.tp_size, tp_rank=bmt.config["tp_rank"])
bmt.print_rank("Preparing dataset done.")
# inspect at init
model_inspect = bmt.inspect.inspect_model(model, "*")
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
try:
mem_usage, tim_usage = {}, {}
mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)
for iteration, data in enumerate(DataIterator, start=start_step + 1):
if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
mixed_indexed_dataset.update_states(data["task_ids"], data["indexes"])
mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)
logits = model(
input=data["inputs"],
cu_seqlens=data["cu_seqlens"],
max_seqlen=data["max_seqlen"],
position_ids=data["position_ids"],
)
# chunk targets and task_ids
data["targets"] = (
data["targets"]
.view(-1)
.chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
.view(data["targets"].shape[0], -1)
)
data["task_ids"] = (
data["task_ids"]
.view(-1)
.chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
.view(data["task_ids"].shape[0], -1)
)
_target = data["targets"].view(-1)
non_reduced_loss = loss_func(logits.view(-1, logits.size(-1)), _target)
_w = (_target != -100).int()
loss = non_reduced_loss.sum() / _w.sum().float()
global_loss = bmt.sum_loss(loss).item()
mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)
optim_manager.backward(loss)
mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)
if iteration % args.grad_accum == 0 or iteration == args.train_iters:
grad_accum_init_time = tim_usage["init"]
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
optim_manager.step()
optim_manager.zero_grad()
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
model_time = tim_usage["optim"] - grad_accum_init_time
ave_model_time.record(model_time)
else:
# dummy optim step
grad_norm = torch.Tensor([0.0]).cuda()
tim_usage["optim"] = tim_usage["backward"]
mem_usage["optim"] = mem_usage["backward"]
with torch.no_grad():
task_num = len(data["task_names"])
task_loss, task_token = get_task_loss_and_token(
non_reduced_loss, data["task_ids"], task_num, data["targets"]
)
task_loss_map: Dict[str, float] = {}
gatherd_task_loss_map = bmt.distributed.all_gather(task_loss)
gatherd_task_token_map = bmt.distributed.all_gather(task_token)
gatherd_task_loss_token_map = gatherd_task_loss_map * gatherd_task_token_map
sum_task_loss = gatherd_task_loss_token_map.sum(dim=0)
tot_task_token = gatherd_task_token_map.sum(dim=0)
ave_task_loss = sum_task_loss / tot_task_token
for i in range(task_num):
task_loss_map[data["task_names"][i]] = ave_task_loss[i].item()
global_total_task_token[data["task_names"][i]] += tot_task_token[i].item()
local_total_rate = torch.Tensor([data["lengths"].float().mean() / args.max_length]).cuda()
local_total_rate = bmt.sum_loss(local_total_rate).item()
global_token_pass += (
(global_world_size // args.tp_size) * local_total_rate * args.max_length * args.batch_size
)
bmt.print_rank(
"=========================================" + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
)
last_before_log_time = tim_usage["before_log"]
mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)
iter_time = tim_usage["before_log"] - last_before_log_time
ave_iter_time.record(iter_time)
train_info = {
"time": iter_time,
"iteration": iteration,
"loss": global_loss,
"lr": lr_scheduler.current_lr,
"token_max": local_total_rate,
"token_pass": global_token_pass,
"throughout": args.max_length * args.batch_size * local_total_rate / ave_iter_time.get() / args.tp_size,
"grad_norm": grad_norm.item(),
"mask_max": ((data["targets"] >= 0).sum(-1).float().mean() / args.max_length).item(),
"task_loss": task_loss_map,
"total_task_token": global_total_task_token,
}
global_token_pass_str = convert_to_k_and_b(global_token_pass)
bmt.print_rank(
(
"| Iter: {iteration:6d} | loss: {loss:.4f} | lr: {lr:.4e} | model_time: {model_time:.2f} | iter_time: {iter_time:.2f}| chunk_ave_time: {chunk_ave_time:.2f}"
+ " token/max: {tokenrate:.4f} | mask/max: {maskrate:.4f} | grad_norm: {grad_norm:.4f} | global_token_pass (B):"
+ "{global_token_pass} | mem_usage {mem_usage} | "
).format(
iteration=iteration,
loss=global_loss,
lr=lr_scheduler.current_lr,
model_time=model_time,
iter_time=iter_time,
chunk_ave_time=ave_iter_time.get(),
tokenrate=data["lengths"].float().mean() / args.max_length / args.batch_size,
maskrate=(data["targets"] >= 0).sum(-1).float().mean() / args.max_length / args.batch_size,
grad_norm=grad_norm.item(),
global_token_pass=global_token_pass_str,
mem_usage=max([value for key, value in mem_usage.items()]),
)
)
bmt.print_rank(
"task_loss:\t| "
+ " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
+ " |"
)
if iteration % 10 == 0:
bmt.print_rank(
"task_tokens (B):\t| "
+ " | ".join(
[
"{}: {:.4f}".format(task_name, task_token / 10**9)
for task_name, task_token in global_total_task_token.items()
]
)
+ " |"
)
if iteration % args.inspect_iters == 0:
model_inspect = bmt.inspect.inspect_model(model, "*")
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
if args.log_dir is not None and bmt.rank() == 0:
if args.save is not None:
save_every_step_stats(train_info, args.save)
if args.tensorboard is not None and bmt.rank() == 0:
writer.add_scalar("Loss/train", global_loss, iteration)
writer.add_scalar("Optimizer/lr", lr_scheduler.current_lr, iteration)
writer.add_scalar("Optimizer/scale", optim_manager.loss_scale, iteration)
writer.add_scalar("Optimizer/grad_norm", grad_norm.item(), iteration)
for task_name, loss in task_loss_map.items():
if not math.isnan(loss):
writer.add_scalar("Loss/train/{}".format(task_name), loss, iteration)
# -------- save file. If need to backup by Klara platform, use export.xx_save --------
log_ckpt = {
"global_total_task_token": global_total_task_token,
"global_token_pass": global_token_pass,
"iteration": iteration,
}
if args.save is not None and iteration % args.save_iters == 0:
exporter.export(
model,
mixed_indexed_dataset,
tokenizer,
optimizer,
iteration,
args,
log_ckpt=log_ckpt,
final_save=False,
)
if iteration == args.train_iters and args.stop_when_end == 1:
break
except Exception as e:
print(f"train loop err: {e}")
raise e
finally:
pass
exporter.export(model, mixed_indexed_dataset, tokenizer, optimizer, -1, args, final_save=False)
def convert_to_k_and_b(number):
if number >= 1e9: # 大于或等于10亿
b_number = number / 1e9
return f"{b_number:.2f}B"
elif number >= 1e6: # 大于或等于1百万
k_number = number / 1e6
return f"{k_number:.2f}M"
elif number >= 1e3:
k_number = number / 1e3
return f"{k_number:.2f}K"
else:
return str(number)
def main():
args = initialize()
bmt.synchronize()
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
bmt.print_rank("finish loading")
bmt.print_rank(
"Number of parameter {}, Number of non-e parameter {}".format(
num_parameters(model), num_non_embedding_parameters(model)
)
)
bmt.print_rank("args: {}".format(args))
pretrain(args, tokenizer, model, optimizer, lr_scheduler)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,234 @@
#!/bin/bash
#export OMP_NUM_THREADS=16
declare -A args # Declare an associative array to store arguments and values
args["model_unique"]="2b_0701"
args["resume_ckpt"]=""
args["config"]="2.4b"
args["flash"]="cuda"
args["batch_size"]="1"
args["max_length"]="4096"
args["save_iters"]="500"
args["train_iters"]="10"
args["dataset_config"]="fm9g_sft"
args["local"]="False"
args["dataloader"]="indexed"
args["save"]="True"
args["dataloader_num_threads"]=1
args["dataloader_prefetch"]=1
args["dataloader_prefetch_factor"]=1
args["dataloader_num_workers"]=1
args["lr"]="1e-5"
args["warmup_iters"]="20"
args["drop_iters"]="0.1"
args["tokenizer_path"]="./tokenizer/tokenizer.model" # /user/tc_agi/klara/baichuan2/baichuan2.tokenizer.model
args["load_grad"]="False"
args["grad_ckpt_num"]="160"
args["exp_group"]=""
args["ignore_cuda_oom"]="1"
args["tensorboard_all_tasks"]="0"
args["stop_when_end"]="0"
args["only_run_dataloader"]="0"
args["eps"]="1e-6"
args["inspect_iters"]="100"
args["strict_state_dict"]="1"
args["only_load_model"]="1"
args["lr_scheduler"]="cosine"
args["resume_no_optimze"]="0"
args["tp_size"]="1"
args["parallel_load_datastate"]="8"
args["async_save"]="False"
args["load_dataloader_ckpt"]="0"
args["drop_begin"]="-1"
args["drop_rate"]="0.5"
args["use_checkpoint"]="0"
# Loop through the arguments
for ((i=1; i<=$#; i++)); do
arg="${!i}"
# Check if the argument starts with "--"
if [[ "$arg" == --* ]]; then
arg_name="${arg:2}" # Remove leading "--"
valueid=$((i+1))
# Get the value of the argument if it exists
if ((i+1 <= $#)); then
args["$arg_name"]="${!valueid}"
i=$((i+1)) # Skip the next argument (its value)
else
args["$arg_name"]="" # Set empty value if no value provided
fi
fi
done
# 使用 Python 读取 JSON 文件并更新 Bash 字典
while read -r key value; do
args["$key"]="$value"
done < <(python -c 'import json, sys; obj = json.load(open("train_configs/'${args['config']}'.json"))["pretrain"]; print("\n".join(["{} {}".format(k, v) for k, v in obj.items()]))')
# 用cmd arg 再更新一次
# Loop through the arguments
for ((i=1; i<=$#; i++)); do
arg="${!i}"
# Check if the argument starts with "--"
if [[ "$arg" == --* ]]; then
arg_name="${arg:2}" # Remove leading "--"
valueid=$((i+1))
# Get the value of the argument if it exists
if ((i+1 <= $#)); then
args["$arg_name"]="${!valueid}"
i=$((i+1)) # Skip the next argument (its value)
else
args["$arg_name"]="" # Set empty value if no value provided
fi
fi
done
# Print the values of the arguments
echo "----------- CMD args ----------"
for key in "${!args[@]}"; do
echo "$key: ${args[$key]}"
done
echo "--------- END CMD args --------"
if [[ ${args["flash"]} == "triton" ]]; then
sudo cp /usr/local/cuda-11.6/compat/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03
sudo ln /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so
echo "triton flash"
fi
GPUS_PER_NODE=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader | wc -l)
# GPUS_PER_NODE=1
echo "Using ${GPUS_PER_NODE} GPU each machine"
if [[ ${args["model_unique"]} == "" ]]; then
MODEL_UNIQUE=${JEEVES_JOB_ID} # 写入的位置,没传的话自动构造
# JOBID+CreateTime, 本次run的唯一标识符。在白箱里可以通过/projects/${PROJECTID}-${PROJECTNAME}/checkpoints/${MODEL_UNIQUE} 拿到 checkpoint
# 通过/projects/${PROJECTID}-${PROJECTNAME}/tensorboard/${MODEL_UNIQUE} 拿到 tensorboard
else
MODEL_UNIQUE=${args["model_unique"]} # 给了写入的位置
fi
echo "model_unique: "$MODEL_UNIQUE
# --------------- 运行参数 ---------------
OPTS+=" --model-config model_configs/"${args['config']}".json" # [CHANGE]
OPTS+=" --batch-size ${args["batch_size"]}"
OPTS+=" --train-iters ${args["train_iters"]}"
OPTS+=" --save-iters ${args["save_iters"]}"
OPTS+=" --save-name fm9g_live_checkpoint"
OPTS+=" --max-length ${args["max_length"]}"
OPTS+=" --lr ${args["lr"]}"
OPTS+=" --inspect-iters ${args["inspect_iters"]}"
OPTS+=" --warmup-iters ${args["warmup_iters"]}"
OPTS+=" --drop-iters ${args["drop_iters"]}"
OPTS+=" --lr_scheduler ${args["lr_scheduler"]}"
OPTS+=" --offload"
#OPTS+=" --vocab ./tokenizer/vocab.txt"
OPTS+=" --flash ${args["flash"]}"
OPTS+=" --tensorboard_all_tasks ${args["tensorboard_all_tasks"]}"
OPTS+=" --ignore_cuda_oom ${args["ignore_cuda_oom"]}"
OPTS+=" --stop_when_end ${args["stop_when_end"]}"
OPTS+=" --only_run_dataloader ${args["only_run_dataloader"]}"
OPTS+=" --eps ${args["eps"]}"
OPTS+=" --strict_state_dict ${args["strict_state_dict"]}"
OPTS+=" --only_load_model ${args["only_load_model"]}"
OPTS+=" --resume_no_optimze ${args["resume_no_optimze"]}"
OPTS+=" --tokenizer_path ${args["tokenizer_path"]}"
OPTS+=" --weight-decay 0.1"
OPTS+=" --tp-size ${args["tp_size"]}"
OPTS+=" --parallel_load_datastate ${args["parallel_load_datastate"]}"
OPTS+=" --load_dataloader_ckpt ${args["load_dataloader_ckpt"]}"
OPTS+=" --drop_begin ${args["drop_begin"]}"
OPTS+=" --drop_rate ${args["drop_rate"]}"
OPTS+=" --use_checkpoint ${args["use_checkpoint"]}"
if [[ ${args["load_grad"]} == "True" ]]; then
OPTS+=" --load-grad"
OPTS+=" --grad-ckpt-num ${args["grad_ckpt_num"]}"
fi
if [[ ${args["async_save"]} == "True" ]]; then
OPTS+=" --async_save"
fi
if [[ ${args["dataloader"]} == "indexed" ]]; then
OPTS+=" --dataloader_num_threads ${args["dataloader_num_threads"]}"
OPTS+=" --dataloader_prefetch ${args["dataloader_prefetch"]}"
OPTS+=" --dataloader_num_workers ${args["dataloader_num_workers"]}"
OPTS+=" --dataloader_prefetch_factor ${args["dataloader_prefetch_factor"]}"
fi
# --------------- 写文件路径 ---------------
## checkpoint
if [[ ${args["save"]} == "True" ]]; then
OPTS+=" --save ./data/checkpoints/${MODEL_UNIQUE}/"
OPTS+=" --save-model ./not_exist/${MODEL_UNIQUE}/"
else
echo "won't save model"
fi
## logs/local/logs 等价于 ./datalogs软链
mkdir -p ./data/checkpoints/logs/${MODEL_UNIQUE}
OPTS+=" --log-dir ./data/checkpoints/logs/${MODEL_UNIQUE}"
OPTS+=" --tensorboard ./data/tensorboard/${args["exp_group"]}${MODEL_UNIQUE}/"
if [[ ${args["local"]} == "True" ]]; then
current_dir=$(pwd)
OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
else
current_dir=$(pwd)
OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
echo "Platform config:"${PLATFORM_CONFIG_PATH}
fi
## checkpoint兼容 CHECKPOINT 和 LATEST_CHECKPOINT。debug 时建议不加载 checkpoint启动会比较快
if [ "${args["resume_ckpt"]}" != "" ]; then
OPTS+=" --load ./data/checkpoints/${MODEL_UNIQUE}/${args["resume_ckpt"]}"
else
echo "No checkpoint to load"
fi
filename="pretrain_dragonfly"
if [[ ${args["local"]} == "True" ]]; then
PRETRAIN_ENTRY="$filename.py"
else
PRETRAIN_ENTRY="$filename.py"
fi
GPUS_PER_NODE=1
NNODES=1
RANK=0
MASTER_ENDPOINT=g3006
MASTER_PORT=23456
#CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --master_addr=${MASTER_ENDPOINT} --master_port=${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ENDPOINT}:${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
echo "-------final CMD is------"
echo "${CMD}"
echo "-------final CMD end------"
$CMD

File diff suppressed because it is too large Load Diff

Binary file not shown.

View File

@ -0,0 +1,9 @@
{
"pretrain": {
"train_iters": 1000000000,
"batch_size": 1,
"max_length": 4096,
"n_gpus": 8,
"lr": 0.01
}
}

View File

@ -0,0 +1,20 @@
import random
def rand(n: int, r: random.Random):
return int(r.random() * n)
def transform(data, num_sample: int, r: random.Random):
if 'input' in data:
_input = "<用户>"+data['input']+"<AI>"
else:
_input = ""
if 'output' in data:
_output = data['output']
else:
_output = ""
return {"input": _input,
"output": _output,
}

View File

@ -0,0 +1,20 @@
import random
def rand(n: int, r: random.Random):
return int(r.random() * n)
def transform(data, num_sample: int, r: random.Random):
if 'input' in data:
_input = data['input']
else:
_input = ""
if 'output' in data:
_output = data['output']
else:
_output = ""
return {"input": _input,
"output": _output,
}

View File

@ -0,0 +1,134 @@
[
{
"dataset_name": "humanevallike_clean_dedup",
"task_name": "humanevallike_clean_dedup",
"abs_weight": 0.2,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/humanevallike_clean_dedup",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 995339,
"ave_tokens_per_line": 100,
"total_tokens": 0.1
},
{
"dataset_name": "leetcode_pass_code_0125",
"task_name": "leetcode_pass_code_0125",
"abs_weight": 0.006,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/leetcode_pass_code_0125",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 10724,
"ave_tokens_per_line": 200,
"total_tokens": 0.002
},
{
"dataset_name": "logiv2Annotate",
"task_name": "logiv2Annotate",
"abs_weight": 0.004,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logiv2Annotate",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 12566,
"ave_tokens_per_line": 512,
"total_tokens": 0.006
},
{
"dataset_name": "mmlu_enhance",
"task_name": "mmlu_enhance",
"abs_weight": 0.1,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mmlu_enhance",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 169771,
"ave_tokens_per_line": 300,
"total_tokens": 0.05
},
{
"dataset_name": "mtbench_like",
"task_name": "mtbench_like",
"abs_weight": 0.2,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mtbench_like",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 319080,
"ave_tokens_per_line": 500,
"total_tokens": 0.15
},
{
"dataset_name": "ultra_dataset_new",
"task_name": "ultra_dataset_new",
"abs_weight": 2.0,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/ultra_dataset_new",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 385045,
"ave_tokens_per_line": 200.296266559615,
"total_tokens": 2.0
},
{
"dataset_name": "sft_data_zh_wowru",
"task_name": "sft_data_zh_wowru",
"abs_weight": 1.0,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/sft_data_zh_wowru",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 2963260,
"ave_tokens_per_line": 200.296266559615,
"total_tokens": 1
},
{
"dataset_name": "math_data",
"task_name": "math_data",
"abs_weight": 0.003,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/math_data",
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
"allow_repeat": true,
"nlines": 2963260,
"ave_tokens_per_line": 200.296266559615,
"total_tokens": 0.005
},
{
"dataset_name": "t0",
"task_name": "t0",
"abs_weight": 0.1,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/t0",
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
"allow_repeat": true,
"nlines": 1650309,
"ave_tokens_per_line": 500.296266559615,
"total_tokens": 0.82
},
{
"dataset_name": "wikihow",
"task_name": "wikihow",
"abs_weight": 0.1,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/wikihow",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 180128,
"ave_tokens_per_line": 900.296266559615,
"total_tokens": 0.16
},
{
"dataset_name": "reclor",
"task_name": "reclor",
"abs_weight": 0.002,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/reclor",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 4174,
"ave_tokens_per_line": 700.296266559615,
"total_tokens": 0.003
},
{
"dataset_name": "logic_test_lx_0127",
"task_name": "logic_test_lx_0127",
"abs_weight": 0.001,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logic_test_lx_0127",
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
"allow_repeat": true,
"nlines": 2800,
"ave_tokens_per_line": 200.96266559615,
"total_tokens": 0.0004
}
]

View File

@ -0,0 +1,27 @@
{
"vocab_size": 119696,
"dropout_p": 0.0,
"eps": 1e-05,
"half": true,
"half_type": "bf16",
"use_flash_attn": true,
"flash_attn_mask_shape": "2d",
"dim_model": 4096,
"dim_ff": 14336,
"dim_head": 128,
"num_heads": 32,
"num_kv_heads": 32,
"num_layers": 32,
"activate_fn": "silu",
"init_std": 0.10,
"scale": false,
"scale_emb": 12,
"scale_depth": -1,
"model_type": "fm9g",
"architectures": [
"FM9GForCausalLM"
],
"qk_norm": false,
"tie_lm_head": false,
"ffn_gated": true
}

View File

@ -0,0 +1,568 @@
# coding=utf-8
# Copyright 2022 ModelBest Inc.
import inspect
import json
import math
import os
import re
import sys
import time
from collections import defaultdict
from itertools import chain
from typing import Any
from typing import Dict
from typing import List
from typing import Union
import bmtrain as bmt
import numpy as np
import torch
from bmtrain import nccl
from bmtrain.global_var import config as bmt_config
sys.path.append("../../")
from fm9g.arguments import get_args
from fm9g.dragonfly.modeling_dragonfly import Dragonfly
from fm9g.dragonfly.modeling_dragonfly import DragonflyConfig
from fm9g.dragonfly.training_tasks.pretrain_indexed_9g import CudaPrefetcher
from fm9g.dragonfly.training_tasks.pretrain_indexed_9g import MixedIndexedDataset
from fm9g.dragonfly.training_tasks.pretrain_indexed_9g import UnpadBatchedMixedDataset
from fm9g.utils import exporter
from fm9g.utils import logger
from fm9g.utils.exporter import save_every_step_stats
from fm9g.utils.training_stats import num_non_embedding_parameters
from fm9g.utils.training_stats import num_parameters
def get_tokenizer(args):
from fm9g.tokenizer import FM9GTokenizer
tokenizer = FM9GTokenizer(path=args.vocab)
return tokenizer
def get_model(args):
config = DragonflyConfig.from_json_file(args.model_config)
config.tp = 1 if args.tp_size != 1 else 0 # TODO
config.pose_prob = args.pose_prob
config.pose_scaling_factor = args.pose_scaling_factor
config.rope_scaling_type = args.rope_scaling_type
config.rope_scaling_factor = args.rope_scaling_factor
config.orig_max_length = args.orig_max_length
config.use_checkpoint = True if args.use_checkpoint == 1 else False
bmt.print_rank("model config: {}".format(config))
bmt.print_rank("bmt config: {}".format(bmt.config))
model = Dragonfly(config)
if args.load is not None:
bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
exporter.load_model_ckpt(args, model)
else:
bmt.print_rank("args.load is None, start to initialize parameters")
bmt.init_parameters(model)
return model
def get_optimizer(args, model):
scale_lr_group = []
normal_group = []
scale_lr_group_name, normal_group_name = [], []
for n, p in model.named_parameters():
if n.endswith(".weight") and "layernorm" not in n and "embedding" not in n and "lm_head" not in n:
scale_lr_group.append(p)
scale_lr_group_name.append(n)
else:
normal_group.append(p)
normal_group_name.append(n)
bmt.print_rank(scale_lr_group_name, normal_group_name)
param_groups = [
{"params": scale_lr_group, "lr": args.lr / model.config.scale_width},
{"params": normal_group, "lr": args.lr},
]
if args.offload:
optimizer = bmt.optim.AdamOffloadOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
else:
optimizer = bmt.optim.AdamOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
if args.load is not None and args.load_grad:
exporter.load_optimizer_ckpt(args, optimizer)
bmt.print_rank("optimizer is loaded!")
return optimizer
def get_learning_rate_scheduler(args, optimizer):
from fm9g.training_utils.lr_scheduler import Cosine
from fm9g.training_utils.lr_scheduler import WarmupStableDrop
from fm9g.training_utils.lr_scheduler import WarmupStableExp
end_iter = args.train_iters
if 0 < args.warmup_iters < 1: # 需要支持按固定比例step用来做warmup的
warmup_iters = int(end_iter * args.warmup_iters)
else:
warmup_iters = int(args.warmup_iters)
if 0 < args.drop_iters < 1: # 需要支持按固定比例step用来做drop的
drop_iters = int(end_iter * args.drop_iters)
else:
drop_iters = int(args.drop_iters)
if args.lr_scheduler == "cosine":
lr_scheduler = Cosine(
optimizer,
start_lr=args.lr,
warmup_iter=warmup_iters,
end_iter=end_iter, # 原来是lr_decay_iter
num_iter=args.start_step,
)
# lr_end_restart=args.lr_end_restart,
# resume_no_optimze=args.resume_no_optimze,
#)
elif args.lr_scheduler == "warmupstabledrop":
lr_scheduler = WarmupStableDrop(
optimizer,
start_lr=args.lr,
warmup_iter=warmup_iters,
end_iter=end_iter, # 原来是lr_decay_iter
drop_iter=drop_iters,
num_iter=args.start_step,
resume_no_optimze=args.resume_no_optimze,
)
elif args.lr_scheduler == "warmupstableexp":
lr_scheduler = WarmupStableExp(
optimizer,
start_lr=args.lr,
warmup_iter=warmup_iters,
drop_begin=args.drop_begin, # 原来是lr_decay_iter
drop_iter=drop_iters,
drop_rate=args.drop_rate,
num_iter=args.start_step,
resume_no_optimze=args.resume_no_optimze,
)
return lr_scheduler
def setup_model_and_optimizer(args):
start = time.time()
tokenizer = get_tokenizer(args)
bmt.synchronize()
logger.info("load tokenizer in {:.2f}s".format(time.time() - start))
start = time.time()
model = get_model(args)
logger.info("load model in {:.2f}s".format(time.time() - start))
start = time.time()
optimizer = get_optimizer(args, model)
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
bmt.synchronize()
logger.info("load lr_scheduler in {:.2f}s".format(time.time() - start))
return tokenizer, model, optimizer, lr_scheduler
def resume_training(args):
ckpts = sorted(
[z for z in chain(*[[os.path.join(x[0], y) for y in x[2]] for x in os.walk(args.save)]) if z.endswith(".pt")],
reverse=True,
key=lambda x: (int)(re.search("(\d+).pt", x)[1]),
)
# find newest job
ckpts = sorted(
ckpts,
reverse=True,
key=lambda x: (int)(re.search("job_(\d+)_ckpt", x)[1]),
)
if len(ckpts) > 0:
bmt.print_rank(f"resuming with last checkpoint: {ckpts[0]}")
args.load = ckpts[0]
# by default, do not load grad file
args.load_grad = False
args.start_step = 0
else:
# no ckpts, nothing we can do
os._exit(1)
def initialize():
args = get_args(pretrain=True)
bmt.init_distributed(seed=args.seed, tp_size=args.tp_size)
if args.save is not None:
os.makedirs(args.save, exist_ok=True)
if args.load is not None:
if args.only_load_model == 0:
if args.start_step == 0:
log_ckpt = exporter.load_log_ckpt(args)
if "iteration" in log_ckpt:
args.start_step = log_ckpt["iteration"]
else:
args.start_step = (int)(re.findall("(\d+)", args.load)[-1])
logger.info("Start from step {}".format(args.start_step))
elif args.only_load_model == 1:
logger.info("You load model ckpt, and choose to completely start the 0 step.")
else:
raise NotImplementedError
else:
logger.info("You do not load model")
return args
def see_memory(detail=False):
if detail:
res = torch.cuda.memory_summary()
else:
res = (
round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2),
round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024), 2),
round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024), 2),
)
torch.cuda.reset_peak_memory_stats()
return res
def add_mem_time(info, mem_usage, tim_usage):
torch.cuda.synchronize()
bmt.synchronize()
mem_usage[info] = see_memory()
tim_usage[info] = time.time()
return mem_usage, tim_usage
def get_task_loss_and_token(loss, task_ids, task_num, targets):
# task_ids 可能有-1 来代表无效token
_task_num = task_num + 1
_task_ids = (task_ids.clone() + 1).to(torch.int64) # [batch_size, seq_len]
# gen masks
_task_mask = torch.zeros((_task_num, *_task_ids.shape), device=_task_ids.device)
_task_mask.scatter_(0, _task_ids.unsqueeze(0), 1) # [task_num, batch_size, seq_len]
_loss_mask = torch.ne(targets, -100).to(torch.int32)
_mask = _task_mask * _loss_mask.unsqueeze(0) # [task_num, batch_size, seq_len]
# calc loss and tokens
_task_losses = (loss.unsqueeze(0) * _mask).view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
_task_tokens = _mask.view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
# return token-wise avg losses and tokens
return torch.nan_to_num(_task_losses / _task_tokens, nan=0.0), _task_tokens
class ChunkAve:
def __init__(self, chunk_size=100):
self.ave_list = []
self.chunk_size = chunk_size
def record(self, time):
self.ave_list.append(time)
self.ave_list = self.ave_list[-self.chunk_size :]
def get(self):
return sum(self.ave_list) / len(self.ave_list)
def pretrain(
args,
tokenizer,
model: Dragonfly,
optimizer,
lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
):
ave_model_time = ChunkAve(chunk_size=100)
ave_iter_time = ChunkAve(chunk_size=100)
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, reduction="none")
optim_manager = bmt.optim.OptimManager(
loss_scale=bmt.world_size(),
loss_scale_steps=args.loss_scale_steps,
loss_scale_factor=2,
max_loss_scale=bmt.world_size(),
min_loss_scale=bmt.world_size(),
)
optim_manager.add_optimizer(optimizer, lr_scheduler)
start_step = args.start_step
if args.tensorboard is not None and bmt.rank() == 0:
import distutils.version # noqa: F401
from tensorboardX import SummaryWriter
if not os.path.exists(args.tensorboard):
os.makedirs(args.tensorboard)
writer = SummaryWriter(log_dir=args.tensorboard)
if args.load is not None:
log_ckpt = exporter.load_log_ckpt(args)
else:
log_ckpt = {}
global_token_pass = log_ckpt.get("global_token_pass", 0.0)
global_total_task_token = defaultdict(int, log_ckpt.get("global_total_task_token", {})) # token by task
global_world_size = bmt.world_size()
if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
mixed_indexed_dataset = MixedIndexedDataset(
cfg_path=args.dataset,
cfg_json_str=None,
tokenizer=tokenizer,
max_length=args.max_length,
nthreads=args.dataloader_num_threads,
prefetch_slice=args.dataloader_prefetch,
weight_by_size=True,
)
if args.load is not None and args.only_load_model == 0 and args.load_dataloader_ckpt == 1:
exporter.load_dataloader_ckpt(args, mixed_indexed_dataset)
batched_dataset = UnpadBatchedMixedDataset(mixed_indexed_dataset, args.batch_size, args.max_length)
dataloader = torch.utils.data.DataLoader(
batched_dataset,
batch_size=None,
collate_fn=lambda x: x,
num_workers=args.dataloader_num_workers,
prefetch_factor=args.dataloader_prefetch_factor,
)
else:
def dummy_generator():
while True:
yield None
mixed_indexed_dataset = dummy_generator()
dataloader = mixed_indexed_dataset
DataIterator = CudaPrefetcher(dataloader, tp_size=args.tp_size, tp_rank=bmt.config["tp_rank"])
bmt.print_rank("Preparing dataset done.")
# inspect at init
model_inspect = bmt.inspect.inspect_model(model, "*")
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
try:
mem_usage, tim_usage = {}, {}
mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)
for iteration, data in enumerate(DataIterator, start=start_step + 1):
if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
mixed_indexed_dataset.update_states(data["task_ids"], data["indexes"])
mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)
logits = model(
input=data["inputs"],
cu_seqlens=data["cu_seqlens"],
max_seqlen=data["max_seqlen"],
position_ids=data["position_ids"],
)
#print("logits: ", logits)
# chunk targets and task_ids
data["targets"] = (
data["targets"]
.view(-1)
.chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
.view(data["targets"].shape[0], -1)
)
data["task_ids"] = (
data["task_ids"]
.view(-1)
.chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
.view(data["task_ids"].shape[0], -1)
)
_target = data["targets"].view(-1)
non_reduced_loss = loss_func(logits.view(-1, logits.size(-1)), _target)
_w = (_target != -100).int()
loss = non_reduced_loss.sum() / _w.sum().float()
global_loss = bmt.sum_loss(loss).item()
mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)
optim_manager.backward(loss)
mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)
if iteration % args.grad_accum == 0 or iteration == args.train_iters:
grad_accum_init_time = tim_usage["init"]
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
optim_manager.step()
optim_manager.zero_grad()
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
model_time = tim_usage["optim"] - grad_accum_init_time
ave_model_time.record(model_time)
else:
# dummy optim step
grad_norm = torch.Tensor([0.0]).cuda()
tim_usage["optim"] = tim_usage["backward"]
mem_usage["optim"] = mem_usage["backward"]
model_time = tim_usage["optim"] - tim_usage['init']
with torch.no_grad():
task_num = len(data["task_names"])
task_loss, task_token = get_task_loss_and_token(
non_reduced_loss, data["task_ids"], task_num, data["targets"]
)
task_loss_map: Dict[str, float] = {}
gatherd_task_loss_map = bmt.distributed.all_gather(task_loss)
gatherd_task_token_map = bmt.distributed.all_gather(task_token)
gatherd_task_loss_token_map = gatherd_task_loss_map * gatherd_task_token_map
sum_task_loss = gatherd_task_loss_token_map.sum(dim=0)
tot_task_token = gatherd_task_token_map.sum(dim=0)
ave_task_loss = sum_task_loss / tot_task_token
for i in range(task_num):
task_loss_map[data["task_names"][i]] = ave_task_loss[i].item()
global_total_task_token[data["task_names"][i]] += tot_task_token[i].item()
local_total_rate = torch.Tensor(
[data["lengths"].float().mean() / (args.max_length * args.batch_size)]
).cuda()
local_total_rate = bmt.sum_loss(local_total_rate).item()
global_token_pass += (
(global_world_size // args.tp_size) * local_total_rate * args.max_length * args.batch_size
)
bmt.print_rank(
"=========================================" + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
)
last_before_log_time = tim_usage["before_log"]
mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)
iter_time = tim_usage["before_log"] - last_before_log_time
ave_iter_time.record(iter_time)
train_info = {
"time": iter_time,
"iteration": iteration,
"loss": global_loss,
"lr": lr_scheduler.current_lr,
"token_max": local_total_rate,
"token_pass": global_token_pass,
"throughout": args.max_length * args.batch_size * local_total_rate / ave_iter_time.get() / args.tp_size,
"grad_norm": grad_norm.item(),
"mask_max": ((data["targets"] >= 0).sum(-1).float().mean() / args.max_length).item(),
"task_loss": task_loss_map,
"total_task_token": global_total_task_token,
}
global_token_pass_str = convert_to_k_and_b(global_token_pass)
time_report_str = "{model_time:.2f}={forward_time:.2f}+{backward_time:.2f}+{optim_time:.2f}".format(model_time=model_time, forward_time=tim_usage['forward']-tim_usage['init'], backward_time=tim_usage['backward']-tim_usage['forward'], optim_time=tim_usage['optim'] - tim_usage['backward'])
bmt.print_rank(
(
"| Iter: {iteration:6d} | loss: {loss:.4f} | lr: {lr:.4e} | model_time: {model_time} | iter_time: {iter_time:.2f}| chunk_ave_time: {chunk_ave_time:.2f}"
+ " token/max: {tokenrate:.4f} | mask/max: {maskrate:.4f} | grad_norm: {grad_norm:.4f} | global_token_pass (B):"
+ "{global_token_pass} | mem_usage {mem_usage} | "
).format(
iteration=iteration,
loss=global_loss,
lr=lr_scheduler.current_lr,
model_time=time_report_str,
iter_time=iter_time,
chunk_ave_time=ave_iter_time.get(),
tokenrate=data["lengths"].float().mean() / args.max_length / args.batch_size,
maskrate=(data["targets"] >= 0).sum(-1).float().mean() / args.max_length / args.batch_size,
grad_norm=grad_norm.item(),
global_token_pass=global_token_pass_str,
mem_usage=max([value for key, value in mem_usage.items()]),
)
)
bmt.print_rank(
"task_loss:\t| "
+ " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
+ " |"
)
if iteration % 10 == 0:
bmt.print_rank(
"task_tokens (B):\t| "
+ " | ".join(
[
"{}: {:.4f}".format(task_name, task_token / 10**9)
for task_name, task_token in global_total_task_token.items()
]
)
+ " |"
)
if iteration % args.inspect_iters == 0:
model_inspect = bmt.inspect.inspect_model(model, "*")
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
if args.log_dir is not None and bmt.rank() == 0:
if args.save is not None:
save_every_step_stats(train_info, args.save)
if args.tensorboard is not None and bmt.rank() == 0:
writer.add_scalar("Loss/train", global_loss, iteration)
writer.add_scalar("Optimizer/lr", lr_scheduler.current_lr, iteration)
writer.add_scalar("Optimizer/scale", optim_manager.loss_scale, iteration)
writer.add_scalar("Optimizer/grad_norm", grad_norm.item(), iteration)
for task_name, loss in task_loss_map.items():
if not math.isnan(loss):
writer.add_scalar("Loss/train/{}".format(task_name), loss, iteration)
# -------- save file. If need to backup by Klara platform, use export.xx_save --------
log_ckpt = {
"global_total_task_token": global_total_task_token,
"global_token_pass": global_token_pass,
"iteration": iteration,
}
if args.save is not None and iteration % args.save_iters == 0:
exporter.export(
model,
mixed_indexed_dataset,
tokenizer,
optimizer,
iteration,
args,
log_ckpt=log_ckpt,
final_save=False,
async_save=args.async_save,
)
if iteration == args.train_iters and args.stop_when_end == 1:
break
except Exception as e:
print(f"train loop err: {e}")
raise e
finally:
pass
exporter.export(model, mixed_indexed_dataset, tokenizer, optimizer, -1, args, final_save=False)
def convert_to_k_and_b(number):
if number >= 1e9: # 大于或等于10亿
b_number = number / 1e9
return f"{b_number:.2f}B"
elif number >= 1e6: # 大于或等于1百万
k_number = number / 1e6
return f"{k_number:.2f}M"
elif number >= 1e3:
k_number = number / 1e3
return f"{k_number:.2f}K"
else:
return str(number)
def main():
args = initialize()
bmt.synchronize()
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
bmt.print_rank("finish loading")
bmt.print_rank(
"Number of parameter {}, Number of non-e parameter {}".format(
num_parameters(model), num_non_embedding_parameters(model)
)
)
bmt.print_rank("args: {}".format(args))
print("begining training")
pretrain(args, tokenizer, model, optimizer, lr_scheduler)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,234 @@
#!/bin/bash
#export OMP_NUM_THREADS=16
declare -A args # Declare an associative array to store arguments and values
args["model_unique"]="8b_0702"
args["resume_ckpt"]=""
args["config"]="8b"
args["flash"]="cuda"
args["batch_size"]="1"
args["max_length"]="4096"
args["save_iters"]="500"
args["train_iters"]="10"
args["dataset_config"]="fm9g_sft"
args["local"]="False"
args["dataloader"]="indexed"
args["save"]="True"
args["dataloader_num_threads"]=1
args["dataloader_prefetch"]=2
args["dataloader_prefetch_factor"]=32
args["dataloader_num_workers"]=2
args["lr"]="1e-5"
args["warmup_iters"]="20"
args["drop_iters"]="0.1"
args["tokenizer_path"]="./tokenizer/tokenizer.model" # /user/tc_agi/klara/baichuan2/baichuan2.tokenizer.model
args["load_grad"]="False"
args["grad_ckpt_num"]="160"
args["exp_group"]=""
args["ignore_cuda_oom"]="1"
args["tensorboard_all_tasks"]="0"
args["stop_when_end"]="0"
args["only_run_dataloader"]="0"
args["eps"]="1e-6"
args["inspect_iters"]="100"
args["strict_state_dict"]="1"
args["only_load_model"]="1"
args["lr_scheduler"]="cosine"
args["resume_no_optimze"]="0"
args["tp_size"]="1"
args["parallel_load_datastate"]="16"
args["async_save"]="False"
args["load_dataloader_ckpt"]="0"
args["drop_begin"]="-1"
args["drop_rate"]="0.5"
args["use_checkpoint"]="1"
# Loop through the arguments
for ((i=1; i<=$#; i++)); do
arg="${!i}"
# Check if the argument starts with "--"
if [[ "$arg" == --* ]]; then
arg_name="${arg:2}" # Remove leading "--"
valueid=$((i+1))
# Get the value of the argument if it exists
if ((i+1 <= $#)); then
args["$arg_name"]="${!valueid}"
i=$((i+1)) # Skip the next argument (its value)
else
args["$arg_name"]="" # Set empty value if no value provided
fi
fi
done
# 使用 Python 读取 JSON 文件并更新 Bash 字典
while read -r key value; do
args["$key"]="$value"
done < <(python -c 'import json, sys; obj = json.load(open("train_configs/'${args['config']}'.json"))["pretrain"]; print("\n".join(["{} {}".format(k, v) for k, v in obj.items()]))')
# 用cmd arg 再更新一次
# Loop through the arguments
for ((i=1; i<=$#; i++)); do
arg="${!i}"
# Check if the argument starts with "--"
if [[ "$arg" == --* ]]; then
arg_name="${arg:2}" # Remove leading "--"
valueid=$((i+1))
# Get the value of the argument if it exists
if ((i+1 <= $#)); then
args["$arg_name"]="${!valueid}"
i=$((i+1)) # Skip the next argument (its value)
else
args["$arg_name"]="" # Set empty value if no value provided
fi
fi
done
# Print the values of the arguments
echo "----------- CMD args ----------"
for key in "${!args[@]}"; do
echo "$key: ${args[$key]}"
done
echo "--------- END CMD args --------"
if [[ ${args["flash"]} == "triton" ]]; then
sudo cp /usr/local/cuda-11.6/compat/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03
sudo ln /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so
echo "triton flash"
fi
GPUS_PER_NODE=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader | wc -l)
# GPUS_PER_NODE=1
echo "Using ${GPUS_PER_NODE} GPU each machine"
if [[ ${args["model_unique"]} == "" ]]; then
MODEL_UNIQUE=${JEEVES_JOB_ID} # 写入的位置,没传的话自动构造
# JOBID+CreateTime, 本次run的唯一标识符。在白箱里可以通过/projects/${PROJECTID}-${PROJECTNAME}/checkpoints/${MODEL_UNIQUE} 拿到 checkpoint
# 通过/projects/${PROJECTID}-${PROJECTNAME}/tensorboard/${MODEL_UNIQUE} 拿到 tensorboard
else
MODEL_UNIQUE=${args["model_unique"]} # 给了写入的位置
fi
echo "model_unique: "$MODEL_UNIQUE
# --------------- 运行参数 ---------------
OPTS+=" --model-config model_configs/"${args['config']}".json" # [CHANGE]
OPTS+=" --batch-size ${args["batch_size"]}"
OPTS+=" --train-iters ${args["train_iters"]}"
OPTS+=" --save-iters ${args["save_iters"]}"
OPTS+=" --save-name fm9g_live_checkpoint"
OPTS+=" --max-length ${args["max_length"]}"
OPTS+=" --lr ${args["lr"]}"
OPTS+=" --inspect-iters ${args["inspect_iters"]}"
OPTS+=" --warmup-iters ${args["warmup_iters"]}"
OPTS+=" --drop-iters ${args["drop_iters"]}"
OPTS+=" --lr_scheduler ${args["lr_scheduler"]}"
OPTS+=" --offload"
OPTS+=" --vocab ./tokenizer/vocab.txt"
OPTS+=" --flash ${args["flash"]}"
OPTS+=" --tensorboard_all_tasks ${args["tensorboard_all_tasks"]}"
OPTS+=" --ignore_cuda_oom ${args["ignore_cuda_oom"]}"
OPTS+=" --stop_when_end ${args["stop_when_end"]}"
OPTS+=" --only_run_dataloader ${args["only_run_dataloader"]}"
OPTS+=" --eps ${args["eps"]}"
OPTS+=" --strict_state_dict ${args["strict_state_dict"]}"
OPTS+=" --only_load_model ${args["only_load_model"]}"
OPTS+=" --resume_no_optimze ${args["resume_no_optimze"]}"
OPTS+=" --tokenizer_path ${args["tokenizer_path"]}"
OPTS+=" --weight-decay 0.1"
OPTS+=" --tp-size ${args["tp_size"]}"
OPTS+=" --parallel_load_datastate ${args["parallel_load_datastate"]}"
OPTS+=" --load_dataloader_ckpt ${args["load_dataloader_ckpt"]}"
OPTS+=" --drop_begin ${args["drop_begin"]}"
OPTS+=" --drop_rate ${args["drop_rate"]}"
OPTS+=" --use_checkpoint ${args["use_checkpoint"]}"
if [[ ${args["load_grad"]} == "True" ]]; then
OPTS+=" --load-grad"
OPTS+=" --grad-ckpt-num ${args["grad_ckpt_num"]}"
fi
if [[ ${args["async_save"]} == "True" ]]; then
OPTS+=" --async_save"
fi
if [[ ${args["dataloader"]} == "indexed" ]]; then
OPTS+=" --dataloader_num_threads ${args["dataloader_num_threads"]}"
OPTS+=" --dataloader_prefetch ${args["dataloader_prefetch"]}"
OPTS+=" --dataloader_num_workers ${args["dataloader_num_workers"]}"
OPTS+=" --dataloader_prefetch_factor ${args["dataloader_prefetch_factor"]}"
fi
# --------------- 写文件路径 ---------------
## checkpoint
if [[ ${args["save"]} == "True" ]]; then
OPTS+=" --save ./data/checkpoints/${MODEL_UNIQUE}/"
OPTS+=" --save-model ./not_exist/${MODEL_UNIQUE}/"
else
echo "won't save model"
fi
## logs/local/logs 等价于 ./datalogs软链
mkdir -p ./data/checkpoints/logs/${MODEL_UNIQUE}
OPTS+=" --log-dir ./data/checkpoints/logs/${MODEL_UNIQUE}"
OPTS+=" --tensorboard ./data/tensorboard/${args["exp_group"]}${MODEL_UNIQUE}/"
if [[ ${args["local"]} == "True" ]]; then
current_dir=$(pwd)
OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
else
current_dir=$(pwd)
OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
echo "Platform config:"${PLATFORM_CONFIG_PATH}
fi
## checkpoint兼容 CHECKPOINT 和 LATEST_CHECKPOINT。debug 时建议不加载 checkpoint启动会比较快
if [ "${args["resume_ckpt"]}" != "" ]; then
OPTS+=" --load ./data/checkpoints/${MODEL_UNIQUE}/${args["resume_ckpt"]}"
else
echo "No checkpoint to load"
fi
filename="pretrain_dragonfly"
if [[ ${args["local"]} == "True" ]]; then
PRETRAIN_ENTRY="$filename.py"
else
PRETRAIN_ENTRY="$filename.py"
fi
GPUS_PER_NODE=8
NNODES=1
RANK=0
MASTER_ENDPOINT=g3006
MASTER_PORT=12345
#CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --master_addr=${MASTER_ENDPOINT} --master_port=${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ENDPOINT}:${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
echo "-------final CMD is------"
echo "${CMD}"
echo "-------final CMD end------"
$CMD

View File

@ -119687,8 +119687,8 @@
"𠳐"
"𥻗"
"𬉼"
"<pad_0>"
"<pad_1>"
"<|im_start|>"
"<|im_end|>"
"<pad_2>"
"<pad_3>"
"<pad_4>"

View File

@ -0,0 +1,9 @@
{
"pretrain": {
"train_iters": 20000,
"batch_size": 1,
"max_length": 4096,
"n_gpus": 8,
"lr": 1e-5
}
}

View File

@ -1,3 +1,18 @@
# coding=utf-8
# Copyright 2020 The OpenBMB team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
@ -7,6 +22,8 @@ def add_model_config_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group("model", "model configuration")
group.add_argument("--model-config", type=str, help="model configuration file")
group.add_argument("--vocab", type=str, default=None, help="model vocabulary file")
group.add_argument("--eps", type=float, default=1e-5, help="eps in layernorm")
# group.add_argument("--qk_norm", action="store_true", default=False, help="qk layernorm")
return parser
@ -31,6 +48,13 @@ def add_training_args(parser: argparse.ArgumentParser):
help="Load the gradient states",
)
group.add_argument(
"--grad-ckpt-num",
type=int,
default=0,
help="grad file num (only work when --load-grad from files less than world-size )",
)
group.add_argument(
"--load-start-step",
action="store_true",
@ -66,7 +90,9 @@ def add_training_args(parser: argparse.ArgumentParser):
group.add_argument("--inspect-iters", type=int, default=1000, help="number of inspecting")
group.add_argument("--batch-size", type=int, default=32, help="Data Loader batch size")
group.add_argument("--num-micro-batches", type=int, default=16)
group.add_argument("--clip-grad", type=float, default=1.0, help="gradient clipping")
group.add_argument("--grad-accum", type=int, default=1, help="gradient accum steps")
group.add_argument(
"--train-iters",
type=int,
@ -74,11 +100,14 @@ def add_training_args(parser: argparse.ArgumentParser):
help="total number of iterations to train over all training runs",
)
group.add_argument("--max-length", type=int, default=512, help="max length of input")
group.add_argument("--min-length", type=int, default=None, help="only for speed test")
group.add_argument("--seed", type=int, default=1234, help="random seed for reproducibility")
# Learning rate.
group.add_argument("--lr", type=float, default=1.0e-4, help="initial learning rate")
group.add_argument("--lr_scheduler", type=str, default="cosine", help=" learning rate scheduler")
group.add_argument("--weight-decay", type=float, default=1.0e-2, help="weight decay rate")
group.add_argument("--loss-scale", type=float, default=65536, help="loss scale")
group.add_argument("--max-loss-scale", type=float, default=float("inf"), help="loss scale")
@ -92,21 +121,85 @@ def add_training_args(parser: argparse.ArgumentParser):
help="percentage of data to warmup on (.01 = 1% of all " "training iters). Default 0.01",
)
group.add_argument(
"--lr-decay-style",
type=str,
default="noam",
choices=["constant", "linear", "cosine", "exponential", "noam"],
help="learning rate decay function",
"--drop-iters",
type=float,
default=0.01,
help="percentage of data to warmup on (.01 = 1% of all " "training iters). Default 0.01",
)
group.add_argument("--lr-decay-iters", type=int, default=None, help="lr decay steps")
group.add_argument("--start-step", type=int, default=0, help="step to start or continue training")
group.add_argument("--concat-data", action="store_true", help="whether we concatenate the dialogues")
group.add_argument("--offload", action="store_true", help="whether we use offload_adam")
group.add_argument("--new-bmt", action="store_true", help="new bmt without ckpt")
group.add_argument("--flash", default="none", choices=["none", "1d", "triton", "cuda"])
group.add_argument("--tp", default=1, type=int, help="whether we use tensor parallelism")
group.add_argument("--use-jfs-data", action="store_true", help="whether we use juicefs dataset")
group.add_argument("--tp-size", default=1, type=int)
group.add_argument("--pp-size", default=1, type=int)
group.add_argument("--bf16", action="store_true", help="whether we use bf16")
group.add_argument("--gradient-accumulation-steps", type=int, default=1, help="gradient accumulation steps")
group.add_argument("--dataloader_num_threads", default=3, type=int, help="Only useful in indexed dataest.")
group.add_argument("--dataloader_prefetch", default=200, type=int, help="Only useful in indexed dataest.")
group.add_argument("--dataloader_num_workers", default=4, type=int, help="Only useful in indexed dataest.")
group.add_argument("--dataloader_prefetch_factor", default=50, type=int, help="Only useful in indexed dataest.")
group.add_argument(
"--dataloader",
default="indexed",
type=str,
help="dataloader type, 'indexed' for indexed dataset, 'normal' for normal dataset",
)
group.add_argument("--stop_when_end", default=0, type=int, help="Whether to stop training when we reach end_iter")
group.add_argument(
"--data_len_threshold",
default=512,
type=int,
help="If the average length of a sequence is less than this int, mean the sample is biased. ",
)
group.add_argument(
"--only_run_dataloader", default=0, type=int, help="Whether to only run dataloader to check data. "
)
group.add_argument(
"--only_load_model", default=0, type=int, help="Whether to only load a model ckpt, without anything else."
)
group.add_argument(
"--load_dataloader_ckpt", default=1, type=int, help="Whether to only load a model ckpt, without anything else."
)
group.add_argument(
"--resume_no_optimze",
default=0,
type=int,
help="The number of steps that does not add optimization after resume",
)
group.add_argument(
"--parallel_load_datastate",
default=256,
type=int,
help="The number of parallel workers to load dataset state",
)
group.add_argument(
"--async_save",
action="store_true",
help="whether to save artifacts asynchronously",
)
group.add_argument(
"--drop_begin",
default=-1,
type=int,
help="The number of steps that starts to drop lr"
)
group.add_argument(
"--drop_rate",
default=0.5,
type=float,
help="The number rate"
)
group.add_argument(
"--use_checkpoint",
default=1,
type=int,
help="Whether to use checkpointing."
)
return parser
@ -133,6 +226,17 @@ def add_pretrain_args(parser: argparse.ArgumentParser):
return parser
def add_tokenizer_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group("tokenizer", "tokenizer configurations")
group.add_argument(
"--tokenizer_path",
type=str,
default="",
help="tokenizer_path",
)
return parser
def add_finetune_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group("finetune", "finetune configurations")
group.add_argument("--epoch", type=int, default=1, help="number of training epochs")
@ -204,14 +308,53 @@ def add_feedback_learning_args(parser: argparse.ArgumentParser):
return parser
def add_delta_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group("LoRA","LoRA configurations")
group.add_argument("--delta-type", type=str, default=None, help="delta-tuning-type")
group.add_argument("--lora-r", type=int, default=8, help="lora-rank")
group.add_argument("--lora-alpha", type=int, default=8, help="lora-alpha")
group.add_argument("--lora-dropout", type=float, default=0.0, help="lora-dropout")
group.add_argument("--lora-layer", nargs='+', default=['project_q','project_k'], help="lora-layer")
group.add_argument("--save-origin-model", action="store_true", default=False)
def add_model_change_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group("model_change", "model change during pretraining")
group.add_argument("--strict_state_dict", type=int, default=1, help="strict_state_dict")
##
return parser
def add_log_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group("log", "log configurations")
group.add_argument("--tensorboard_all_tasks", type=int, default=0, help="log")
return parser
def add_error_handle_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group("error_handle", "error_handle configurations")
group.add_argument(
"--ignore_cuda_oom", type=int, default=1, help="continue training by ingore the batch that causes oom"
)
return parser
def add_runtime_eval_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group("runtime eval args", "runtime evaluation by submitting a job")
group.add_argument(
"--runtime_eval",
action="store_true",
help="whether to use runtime_eval. Only if this is set to True, the following variables will be useful",
)
group.add_argument("--eval_jeeves_auth", type=str, default="", help="auth, press f12 on jeeves platform to get")
group.add_argument("--eval_project_id", type=str, default=None, help="project id")
group.add_argument("--eval_run_cmd", type=str, default="", help="cmd for eval")
group.add_argument(
"--eval_git_path",
type=str,
default="git@git.in.zhihu.com:luca/llm-bench.git",
help="git path of evaluation code",
)
group.add_argument("--eval_git_branch", type=str, default="master", help="git branch of evaluation code")
group.add_argument("--eval_node_num", type=int, default=1, help="using 1 node to evaluate")
group.add_argument("--eval_gpu_num", type=int, default=1, help="using 1 gpu per node to evaluate")
group.add_argument("--eval_tasks_config", type=str, default="", help="evaluate tasks' config")
group.add_argument("--eval_model_backend", default="torch", type=str, help="model_backend")
group.add_argument(
"--eval_at_start", action="store_true", help="whether to eval at the first epoch, default to false"
)
return parser
@ -222,6 +365,30 @@ def add_reward_args(parser: argparse.ArgumentParser):
return parser
def add_long_context_extend_args(parser: argparse.ArgumentParser):
"""long context extending arguments."""
group = parser.add_argument_group("long_context_extend", "long context extend configurations")
group.add_argument("--pose-prob", default=0.0, type=float, help="Sample-level PoSE probability")
group.add_argument(
"--pose-scaling-factor",
default=1.0,
type=float,
help="PoSE scaling factor, simulate input length = max_length * pose_scaling_factor",
)
group.add_argument(
"--rope-scaling-type",
default="",
type=str,
choices=["Linear", "NTK-aware", "Dynamic NTK", "NTK-by-parts", "YaRN", ""],
help="Context scaling type",
)
group.add_argument("--rope-scaling-factor", default=1, type=int, help="Context scaling factor")
group.add_argument(
"--orig-max-length", default=8192, type=int, help="Original context length before context extending"
)
return parser
def get_args(
pretrain: bool = False,
finetune: bool = False,
@ -235,9 +402,14 @@ def get_args(
parser = add_training_args(parser)
if pretrain:
parser = add_pretrain_args(parser)
parser = add_runtime_eval_args(parser)
parser = add_tokenizer_args(parser)
parser = add_log_args(parser)
parser = add_error_handle_args(parser)
parser = add_model_change_args(parser)
if finetune:
parser = add_finetune_args(parser)
parser = add_delta_args(parser)
if rhlf:
parser = add_rhlf_args(parser)
if simple_rlhf:
@ -246,6 +418,7 @@ def get_args(
parser = add_feedback_learning_args(parser)
if reward:
parser = add_reward_args(parser)
parser = add_long_context_extend_args(parser)
args = parser.parse_args()

View File

@ -4,7 +4,8 @@ from .distributed_dataset import SimpleDataset
from .indexed_dataset import IndexedDataset
from .indexed_dataset import IndexedDatasetBuilder
from .indexed_dataset import PrefetchDecodeDataset
from .list_dataset import ListDataset
# from .list_dataset import ListDataset
from .utils import compact_dataset
from .utils import CudaPrefetcher
from .utils import mask_dataset

View File

@ -1,3 +1,18 @@
# coding=utf-8
# Copyright 2020 The OpenBMB team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import bisect
import io
import json
@ -281,7 +296,6 @@ class DistributedDataset:
info: List[FileInfo] = []
if os.path.exists(meta_path):
info = _read_info_list(meta_path)
old_len = len(self._file_info)
if old_len > len(info):
raise RuntimeError("Dataset meta file: changed unexpectly")
@ -443,7 +457,11 @@ class DistributedDataset:
with torch.no_grad():
if self._world_size > 1:
gpu_num_unused_block = torch.tensor([num_unused_block], dtype=torch.long).cuda()
max_unused_blocks = bmt.distributed.all_reduce(gpu_num_unused_block, op="max").cpu().item()
max_unused_blocks = (
bmt.distributed.all_reduce(gpu_num_unused_block, op="max", comm=bmt.config["tp_zero_comm"])
.cpu()
.item()
)
gpu_states = torch.full((max_unused_blocks,), -1, dtype=torch.long).cuda()
gpu_states[:num_unused_block] = torch.tensor(self._unused_block, dtype=torch.long).cuda()
gpu_offset = torch.full((max_unused_blocks,), 0, dtype=torch.long).cuda()
@ -452,9 +470,15 @@ class DistributedDataset:
[curr_block, inblock_offset, num_unused_block, self._repeat_times],
dtype=torch.long,
).cuda()
global_states = bmt.distributed.all_gather(gpu_states).cpu() # (world_size, max_unused_blocks)
global_offset = bmt.distributed.all_gather(gpu_offset).cpu() # (world_size, max_unused_blocks)
global_block = bmt.distributed.all_gather(gpu_block).cpu() # (world_size, 4)
global_states = bmt.distributed.all_gather(
gpu_states, comm=bmt.config["tp_zero_comm"]
).cpu() # (world_size, max_unused_blocks)
global_offset = bmt.distributed.all_gather(
gpu_offset, comm=bmt.config["tp_zero_comm"]
).cpu() # (world_size, max_unused_blocks)
global_block = bmt.distributed.all_gather(
gpu_block, comm=bmt.config["tp_zero_comm"]
).cpu() # (world_size, 4)
return {"states": global_states, "offset": global_offset, "block": global_block}
else:
return {

View File

@ -1,13 +1,41 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
#
# @author: ouzebin <ouzebin@zhihu.com>
# @date: 2023/09/27
"""
使用 IndexedDataset 前需按指定格式构建或者转换已有数据集
数据集文件结构
- <dataset name>
- data.jsonl # jsonl 格式的数据,每一行一条样本
- index # 记录每一行 json 数据的起始 byte-offset
从头构建直接使用 IndexedDatasetBuilder 这个 context manager
>>> with IndexedDatasetBuilder("swear", overwrite=True) as builder:
>>> for data in [{"input": f"screw it {i}", "output": f"for god's sake {i}"} for i in range(100)]:
>>> builder.put(data)
转换
fm9g distributed_dataset 转换使用 `fm9g.dataset.tools.distributed_to_indexed`
$ python -m fm9g.dataset.tools.distributed_to_indexed -i <原数据集文件夹> -o <新数据集文件夹>
已有 jsonl 数据使用 `fm9g.dataset.tools.jsonl_to_index` 构建 index 文件需提前先把 jsonl 文件命名为
$ python -m fm9g.dataset.tools.jsonl_to_index -p <数据集文件夹路径>
"""
import itertools
import math
import os
import pickle
import queue
import random
import threading
import time
import bmtrain as bmt
import h5py
import numpy
import numpy as np
import torch
try:
import msgspec
@ -22,13 +50,83 @@ except ModuleNotFoundError:
import torch
from torch.utils.data import Dataset
from typing_extensions import TypedDict
from .utils import Range
from fm9g.utils.bitset import BitSet
from fm9g.utils.bitset import bitset_diff
print_lock = threading.Lock()
def random_range(start, stop=None, step=None):
"""
Generator of non-repeated random permutation with the same inteface of python
`range`. Obtained from https://stackoverflow.com/a/53551417
The random.shuffle(list) and random.sample(list, len(list)) require
materialize the lists, which result in a long initalization period.
"""
if stop is None:
start, stop = 0, start
if step is None:
step = 1
# Use a mapping to convert a standard range into the desired range.
mapping = lambda i: (i * step) + start
# Compute the number of numbers in this range.
maximum = int(math.ceil((stop - start) / step))
if maximum == 0:
# early return with empty range
yield from ()
return
# Seed range with a random integer.
value = random.randint(0, maximum)
# Construct an offset, multiplier, and modulus for a linear
# congruential generator. These generators are cyclic and
# non-repeating when they maintain the properties:
#
# 1) "modulus" and "offset" are relatively prime.
# 2) ["multiplier" - 1] is divisible by all prime factors of "modulus".
# 3) ["multiplier" - 1] is divisible by 4 if "modulus" is divisible by 4.
# Pick a random odd-valued offset.
offset = random.randint(0, maximum) * 2 + 1
# Pick a multiplier 1 greater than a multiple of 4.
multiplier = 4 * (maximum // 4) + 1
# Pick a modulus just big enough to generate all numbers (power of 2).
modulus = int(2 ** math.ceil(math.log2(maximum)))
# Track how many random numbers have been returned.
found = 0
while found < maximum:
# If this is a valid value, yield it in generator fashion.
if value < maximum:
found += 1
yield mapping(value)
# Calculate the next value in the sequence.
value = (value * multiplier + offset) % modulus
class Range(object):
def __init__(self, start, stop, step):
self.start = start
self.stop = stop
self.step = step
def __repr__(self):
return f"Range({self.start}, {self.stop}, {self.step})"
def iterate(self):
yield from range(self.start, self.stop, self.step)
def list(self):
return list(range(self.start, self.stop, self.step))
def subrange(self, split, nsplits):
# strided spliting range params
# e.g., [0, 3, 5, 7, 9] can be split into [0, 5, 9] and [3, 7]
return Range(self.start + self.step * split, self.stop, self.step * nsplits)
def random_iterate(self):
yield from random_range(self.start, self.stop, self.step)
def safe_print(*args, **kargs):
if "flush" in kargs:
flush = kargs["flush"]
@ -40,12 +138,15 @@ def safe_print(*args, **kargs):
def concurrent_info():
world_size, rank = bmt.world_size(), bmt.rank()
# world_size, rank = bmt.world_size(), bmt.rank()
world_size = bmt.config["world_size"] // bmt.config["tp_size"]
rank = bmt.config["topology"].tp_idx
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
nworkers, worker_id = 1, 1
else:
nworkers, worker_id = worker_info.num_workers, worker_info.id
# print("concurrent_info: (world_size, rank, nworkers, worker_id): {}".format((world_size, rank, nworkers, worker_id)))
return world_size, rank, nworkers, worker_id
@ -56,14 +157,45 @@ class IndexedDataset(Dataset):
self.max_retry = max_retry
self.retry_sleep = retry_sleep
self.bounds = None
self.h5file = None
self.build_index()
def size(self):
return self.bounds[-1]
def _build_index_h5(self):
index_path = os.path.join(self.path, "index.h5")
if os.path.getsize(index_path) > 104857600:
self.h5file = h5py.File(os.path.join(self.path, "index.h5"), "r")
self.bounds = self.h5file["index"]
else:
# only load index into memory when it is small (< 100 Mb)
# to avoid keeping to many file handlers
self.h5file = None
with h5py.File(index_path, "r") as hf:
self.bounds = np.array(hf["index"])
def __del__(self):
if self.h5file is not None:
self.h5file.close()
def build_index(self):
s = time.time()
txt_size = os.path.getsize(os.path.join(self.path, "index"))
if txt_size > 0.5 * 1024**3 and os.path.exists(os.path.join(self.path, "index.h5")):
source = "h5"
self._build_index_h5()
else:
source = "txt"
self._build_index_txt()
e = time.time()
bmt.print_rank("build_index_{} from {}, using {:.2f}s".format(source, self.path, e - s))
def _build_index_txt(self):
with open(os.path.join(self.path, "index"), "r") as fin:
self.bounds = [int(line) for line in fin]
self.nlines = len(self.bounds)
def safe_read(self, i_or_s, offset, size):
for retry in itertools.count():
@ -138,39 +270,10 @@ class IndexedDataset(Dataset):
class PrefetchDecodeDataset(IndexedDataset):
# Add prefetched sampled iterator and state_dict tracking upon the simple IndexedDataset
# Add safe decoding in iterator
def __init__(self, *args, decode=json_decode, **kargs):
def __init__(self, *args, decode=json_decode, allow_repeat=False, **kargs):
super().__init__(*args, **kargs)
self.decode = decode
self.lock = threading.Lock()
self.prev_used = set() # store previously used index in the checkpoint
self.used = set() # track locally used index
def state_dict(self, gathered=True):
if not self.prev_used and not self.used:
return {"prev_used": set()}
if gathered:
used = torch.tensor(list(self.used)).cuda()
size = torch.tensor(used.numel()).cuda()
max_size = bmt.distributed.all_reduce(size, op="max")
# allgather requires tensors having the same size
used = torch.cat([used, torch.full((max_size - size,), -100, device=used.device)], dim=-1)
all_used = bmt.distributed.all_gather(used).unique()
all_used = set(all_used.tolist())
if -100 in all_used:
all_used.remove(-100) # remove the padding value
all_used.union(self.prev_used)
return {"prev_used": all_used}
else:
return {"prev_used": self.prev_used.union(self.used)}
def load_state_dict(self, state):
with self.lock:
self.used = state.get("prev_used", set())
def reset(self):
with self.lock:
self.used = set()
self.prev_used = set()
self.allow_repeat = allow_repeat
def safe_decode(self, i, raw):
if raw is None:
@ -191,19 +294,23 @@ class PrefetchDecodeDataset(IndexedDataset):
else:
return self.safe_decode(key, raw)
def loader(self, q, lid, keys, stop):
def loader(self, q, lid, keys, stop, used=None):
# concurrent prefetching worker
if used is None:
used = BitSet()
try:
for key in keys:
if stop.is_set():
break
# key is either a slice or an integer index
index = range(key.start, key.stop) if isinstance(key, slice) else [key]
with self.lock:
unused = set(index) - self.used - self.prev_used
unused = bitset_diff(set(index), used)
if not unused:
# skip used slice / item
continue
if not q.empty():
# avoid breaking the distributed file system with large io load
time.sleep(random.random() * 2)
# read raw data with IndexedDataset.__getitem__, suspend decoding util we really need it
raw = super().__getitem__(key)
if raw is None:
@ -217,14 +324,14 @@ class PrefetchDecodeDataset(IndexedDataset):
# signaling the end of iteration to the main thread
q.put(StopIteration(lid))
def _iterate(self, key_groups, nprefetch=1000):
def _iterate(self, key_groups, nprefetch=1000, used=None):
# helper function for concurrent prefetching
q = queue.Queue(maxsize=nprefetch)
stop = threading.Event()
alive = set()
try:
for lid, keys in enumerate(key_groups):
loader = threading.Thread(target=self.loader, args=(q, lid, keys, stop), daemon=True)
loader = threading.Thread(target=self.loader, args=(q, lid, keys, stop, used), daemon=True)
loader.start()
alive.add(lid)
while True:
@ -236,7 +343,7 @@ class PrefetchDecodeDataset(IndexedDataset):
break
else:
# new item will be put later, wait for a while
time.sleep(0.3)
time.sleep(0.1)
continue
if isinstance(item, StopIteration):
alive.remove(item.value)
@ -245,16 +352,13 @@ class PrefetchDecodeDataset(IndexedDataset):
data = self.safe_decode(i, raw)
if data is None:
continue
self.used.add(i)
yield data
# automatically reset states with graceful ends.
self.reset()
yield i, data
finally:
# ask daemon loaders to stop
stop.set()
def iterate(self, nthreads=3, prefetch_sample=100):
world_size, rank, nworkers, worker_id = concurrent_info()
def iterate(self, nthreads=3, prefetch_sample=100, used=None, process_group=None):
world_size, rank, nworkers, worker_id = concurrent_info(process_group)
nloaders = world_size * nworkers * nthreads
if len(self) < nloaders:
raise ValueError(
@ -269,18 +373,27 @@ class PrefetchDecodeDataset(IndexedDataset):
r = r.subrange(split=worker_id, nsplits=nworkers)
# split index among multi-threaded loaders
id_groups = [r.subrange(split=tid, nsplits=nthreads).random_iterate() for tid in range(nthreads)]
for data in self._iterate(id_groups, nprefetch=prefetch_sample):
yield data
return self._iterate(id_groups, nprefetch=prefetch_sample, used=used)
def sliced_iterate(self, nthreads=1, prefetch_slice=3, slice_size=1000):
def sliced_iterate(self, nthreads=1, prefetch_slice=3, slice_size=500, used=None):
world_size, rank, nworkers, worker_id = concurrent_info()
nloaders = world_size * nworkers * nthreads
if len(self) < nloaders:
if not self.allow_repeat:
raise ValueError(
f"more concurrent loaders ({nloaders}) than data entries ({len(self)}) in '{self.path}', "
f"please constrain either "
f"world_size={world_size}, num_workers={nworkers} or num_threads={nthreads}."
)
else:
duplicated_factor = math.ceil(nloaders / len(self))
# In this case, slice size is 1
r = Range(0, len(self), 1)
# split index among grouped multi-gpu workers
r = r.subrange(split=rank // duplicated_factor, nsplits=math.ceil(world_size / duplicated_factor))
# # split index among multi-threaded loaders
r = r.subrange(split=worker_id, nsplits=nworkers)
else:
nslices = int(math.ceil(len(self) / slice_size))
if nslices < nloaders:
@ -300,20 +413,21 @@ class PrefetchDecodeDataset(IndexedDataset):
slice_groups = [
(slice(s, s + slice_size) for s in r.subrange(tid, nthreads).random_iterate()) for tid in range(nthreads)
]
for data in self._iterate(slice_groups, nprefetch=prefetch_slice * slice_size):
yield data
return self._iterate(slice_groups, nprefetch=prefetch_slice * slice_size, used=used)
class IndexedDatasetBuilder:
def __init__(self, path, overwrite=False):
self.path = path
self.index_path = os.path.join(self.path, "index")
self.index_path = os.path.join(self.path, "index.h5")
self.index_path_txt = os.path.join(self.path, "index")
self.data_path = os.path.join(self.path, "data.jsonl")
if not overwrite:
assert not os.path.exists(self.data_path)
assert not os.path.exists(self.index_path)
assert not os.path.exists(self.index_path_txt)
self.fout = None
self.starts = []
self.bounds = []
self.offset = 0
def __enter__(self):
@ -322,15 +436,17 @@ class IndexedDatasetBuilder:
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.starts.append(self.offset)
with open(self.index_path, "w") as fout:
for s in self.starts:
fout.write(f"{s}\n")
self.bounds.append(self.offset)
with h5py.File(os.path.join(self.index_path), "w") as hf:
hf.create_dataset("index", data=self.bounds)
with open(self.index_path_txt, "w") as fout_txt:
for s in self.bounds:
fout_txt.write(f"{s}\n")
self.fout.close()
def put(self, data: dict):
s = json_encode(data) + b"\n"
self.starts.append(self.offset)
self.bounds.append(self.offset)
self.offset += len(s)
self.fout.write(s)

View File

@ -1,3 +1,18 @@
# coding=utf-8
# Copyright 2020 The OpenBMB team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import pickle

View File

@ -0,0 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
#
# @author: ouzebin <ouzebin@zhihu.com>
# @date: 2023/08/07

View File

@ -1,14 +1,23 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
#
# @author: ouzebin <ouzebin@zhihu.com>
# @date: 2023/07/27
import argparse
import torch
from tqdm import tqdm
from cpm.dataset import SimpleDataset
from cpm.dataset.indexed_dataset import IndexedDatasetBuilder
from fm9g.dataset import SimpleDataset
from fm9g.dataset.indexed_dataset import IndexedDatasetBuilder
def convert_cpm_data(cpm_path, out_path):
dataset = SimpleDataset(cpm_path, shuffle=False)
def convert_fm9g_data(fm9g_path, out_path):
dataset = SimpleDataset(fm9g_path, shuffle=False)
with IndexedDatasetBuilder(out_path, overwrite=True) as builder:
for _ in tqdm(range(dataset._nlines), total=dataset._nlines):
builder.put(dataset.read())
@ -16,7 +25,7 @@ def convert_cpm_data(cpm_path, out_path):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input", "-i", required=True, help="Data path in CPM format.")
parser.add_argument("--input", "-i", required=True, help="Data path in fm9g format.")
parser.add_argument("--output", "-o", required=True, help="Output data path in indexed jsonline format.")
args = parser.parse_args()
convert_cpm_data(args.input, args.output)
convert_fm9g_data(args.input, args.output)

View File

@ -1,3 +1,10 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
#
# @author: ouzebin <ouzebin@zhihu.com>
# @date: 2023/08/07
import argparse
import os

View File

@ -1,3 +1,18 @@
# coding=utf-8
# Copyright 2020 The OpenBMB team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
import os
@ -251,7 +266,7 @@ def merge_dataset(dst: str, src: str):
_write_info_list(meta_path_dst, nw_info)
def to_cpm(src_data, dst_path, dst_name):
def to_fm9g(src_data, dst_path, dst_name):
if not os.path.exists(dst_path):
os.makedirs(dst_path)

View File

@ -0,0 +1,8 @@
{
"folders": [
{
"path": "../.."
}
],
"settings": {}
}

View File

@ -0,0 +1,105 @@
import torch
from transformers.configuration_utils import PretrainedConfig
class DragonflyConfig(PretrainedConfig):
model_type = "fm9g_dragonfly"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {
"num_key_value_heads": "num_kv_heads",
"hidden_act": "activate_fn",
"hidden_size": "dim_model",
"num_attention_heads": "num_heads",
"intermediate_size": "dim_ff",
"num_hidden_layers": "num_layers",
"vocab_size": "vocab_size",
"rms_norm_eps": "eps",
"scale_emb": "scale_emb",
"scale_depth": "scale_depth",
"scale": "scale",
"attention_scale": "attention_scale",
"qk_norm": "qk_norm",
"ffn_gated": "ffn_gated",
} # model specific to common
def __init__(
self,
vocab_size=122753, # TODO: do we need to change to 122880 = 960 * 128?
dim_model=4096,
num_heads=32,
num_kv_heads=32,
dim_head=128,
dim_ff=11008,
num_layers=32,
dropout_p=0.0,
activate_fn="silu",
scale=False,
scale_emb: float = 1.0,
scale_depth: float = -1,
dim_model_base: int = 256,
eps=1e-5,
init_std=0.02,
dtype="bf16",
base=10000,
qk_norm=False,
tie_lm_head=False,
max_length=8192,
pose_prob=0.0,
pose_scaling_factor=1,
rope_scaling_type="",
rope_scaling_factor=1,
orig_max_length=8192,
tp=0,
use_checkpoint=True,
**kwargs,
):
self.vocab_size = vocab_size
self.dim_model = dim_model
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.dim_head = dim_head
self.dim_ff = dim_ff
self.num_layers = num_layers
self.dropout_p = dropout_p
self.activate_fn = activate_fn
self.scale = scale
self.scale_emb = scale_emb
self._dtype = dtype
self.dim_model_base = dim_model_base
self.scale_depth = scale_depth
self.eps = eps
self.init_std = init_std
self.base = base
self.qk_norm = qk_norm
self.tie_lm_head = tie_lm_head
self.use_bfloat16 = True if self._dtype == "bf16" else False
self.pose_prob = pose_prob
self.pose_scaling_factor = pose_scaling_factor
self.rope_scaling_type = rope_scaling_type
self.rope_scaling_factor = rope_scaling_factor
self.max_length = max_length
self.orig_max_length = orig_max_length
self.use_checkpoint = use_checkpoint
print("use_checkpoint", self.use_checkpoint)
self.tp = tp
super().__init__(architectures=["fm9gDragonflyForCausalLM"])
@property
def scale_width(
self,
):
if self.scale:
return self.dim_model / self.dim_model_base
else:
return 1.0
@property
def dtype(
self,
): # -> Any | None:
if self._dtype == "bf16":
return torch.bfloat16
elif self._dtype == "fp16":
return torch.half
elif self._dtype == "float32":
return torch.float

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
from .pretrain_indexed import MixedIndexedDataset

View File

@ -0,0 +1,74 @@
import logging
from multiprocessing import Lock
from flask import Flask
from flask import jsonify
from flask import request
app = Flask(__name__)
# 获取 Werkzeug 日志记录器并设置日志级别
log = logging.getLogger("werkzeug")
log.setLevel(logging.WARNING)
class GlobalAvgTokensStat(object):
def __init__(self, decay_factor: float = 0.98):
self._avg_tokens = {}
self.decay_factor = decay_factor
self.lock = Lock()
self.task_locks = {}
def set_avg_tokens(self, task_name, avg_tokens):
self._register_task_lock_helper(task_name)
with self.task_locks[task_name]:
self._avg_tokens[task_name] = avg_tokens
def update_avg_tokens_by_ema(self, task_name, length):
self._register_task_lock_helper(task_name)
with self.task_locks[task_name]:
if task_name in self._avg_tokens and self._avg_tokens[task_name] > 0:
self._avg_tokens[task_name] = self._avg_tokens[task_name] * self.decay_factor + length * (
1 - self.decay_factor
)
else:
self._avg_tokens[task_name] = length
def get_avg_tokens(self, task_name):
self._register_task_lock_helper(task_name)
with self.task_locks[task_name]:
return self._avg_tokens.get(task_name, -1)
def _register_task_lock_helper(self, task_name):
if task_name not in self.task_locks:
with self.lock:
if task_name not in self.task_locks:
self.task_locks[task_name] = Lock()
global_avg_tokens_stat = GlobalAvgTokensStat()
@app.route("/avg_tokens/<path:task_name>", methods=["GET"])
def get_avg_tokens(task_name):
global global_avg_tokens_stat
avg_tokens = global_avg_tokens_stat.get_avg_tokens(task_name)
return jsonify({"avg_tokens": avg_tokens})
@app.route("/avg_tokens/<path:task_name>", methods=["POST"])
def set_avg_tokens(task_name):
global global_avg_tokens_stat
action = request.args.get("action", "update", type=str)
length = request.args.get("length", -1, type=int)
if action == "set":
global_avg_tokens_stat.set_avg_tokens(task_name, length)
elif action == "update":
global_avg_tokens_stat.update_avg_tokens_by_ema(task_name, length)
else:
raise ValueError(f"Unknown action: {action}")
return jsonify({"status": "ok"})
if __name__ == "__main__":
app.run(port=5000, debug=True)

View File

@ -0,0 +1,826 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
#
# @author: ouzebin <ouzebin@zhihu.com>
# @date: 2023/09/27
import copy
import ctypes
import functools
import importlib
import json
import logging
import os
import random
from collections import defaultdict
from collections import OrderedDict
from multiprocessing import Lock
from multiprocessing import Process
from multiprocessing.shared_memory import SharedMemory
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Union
import bmtrain as bmt
import numpy as np
import torch
from numpy.typing import NDArray
from fm9g.dataset import PrefetchDecodeDataset
from fm9g.utils.bitset import BitSet
from fm9g.utils.vdc_sampling import van_der_corput
from fm9g.utils.vdc_sampling import van_der_corput_sampling_gen
logger = logging.getLogger(__name__)
IGNORE_TGT = -100
def load_dataset_cfgs(cfg_path, cfg_json_str=None):
if cfg_json_str is not None:
cfgs = json.loads(cfg_json_str)
else:
with open(cfg_path, "r", encoding="utf-8") as fin:
cfgs = json.load(fin)
transform_basedir = os.path.dirname(os.path.abspath(cfg_path))
path_dict = None
platform_config_path = os.getenv("PLATFORM_CONFIG_PATH")
try:
with open(platform_config_path, "r") as f:
platform_cfg = json.load(f)
path_dict = platform_cfg["dataset_map"]
if bmt.rank() == 0:
logger.info(f"Loaded jeeves platform config from '{platform_config_path}', update dataset paths...")
except Exception as e:
if bmt.rank() == 0:
logger.info(f"Failing to load jeeves platform config '{platform_config_path}', error message:\n{str(e)}")
task_name2dataset_name = dict()
for idx, cfg in enumerate(cfgs):
assert "dataset_name" in cfg and isinstance(cfg["dataset_name"], str)
assert "task_name" in cfg and isinstance(cfg["task_name"], str)
# to be delibrately annoying :)
if cfg["task_name"] in task_name2dataset_name:
raise ValueError(
f"task_name '{cfg['task_name']}' in dataset '{cfg['dataset_name']}'"
f"has already been used in '{task_name2dataset_name[cfg['task_name']]}'."
)
task_name2dataset_name[cfg["task_name"]] = cfg["dataset_name"]
assert "path" in cfg and isinstance(cfg["path"], str)
# if path_dict is not None:
# cfg["path"] = os.path.join(path_dict[cfg["dataset_name"]], cfg["path"])
# dealing with optional configs
if "weight" in cfg:
assert isinstance(cfg["weight"], (float, int))
else:
cfg["weight"] = 1.0
if "oversize_rule" in cfg:
assert cfg["oversize_rule"] in ("drop", "head", "segment")
else:
cfg["oversize_rule"] = "segment"
if "transforms" in cfg:
assert isinstance(cfg["transforms"], str)
# dealing with relative path
if not cfg["transforms"].startswith("/"):
cfg["transforms"] = os.path.join(transform_basedir, cfg["transforms"])
if not cfg["transforms"]:
cfg["transforms"] = None
else:
cfg["transforms"] = None
if "incontext_weight" in cfg:
assert isinstance(cfg["incontext_weight"], (list, tuple))
else:
cfg["incontext_weight"] = [1.0]
cfg["id"] = idx
# dataset and iterator will be built
return cfgs
def data2ids(data, tokenizer, max_length):
text = "\n".join(
[
data.get("title", "").strip(),
data.get("question", "").strip(),
data.get("answer", "").strip(),
data.get("abstract", "").strip(),
data.get("text", "").strip(),
data.get("code", "").strip(),
]
).strip()
if not text:
logger.warning(f"Warning: skip invalid sample without valid fields: {data}")
yield from ()
return
# suppress the annoying warning from tokenizer
ids = (
[tokenizer.bos_token_id]
+ tokenizer.encode(text, max_length=int(1e12), truncation=True)
+ [tokenizer.eos_token_id]
)
src_ids = ids[0:-1]
tgt_ids = ids[0:-1] # do not shift because it'll be shifted during loss calculation.
if len(src_ids) > max_length:
for st in range(0, len(src_ids), max_length):
yield src_ids[st : st + max_length], tgt_ids[st : st + max_length]
else:
yield src_ids, tgt_ids
def cricket_data2ids(data, tokenizer, max_length: int, oversize_rule="segment", do_compact=False):
assert oversize_rule in ("drop", "head", "segment")
if data is None:
yield from ()
return
if "output" not in data or not data["output"]:
yield from ()
return
if "input" not in data or data["input"] is None:
data["input"] = ""
src_ids = [tokenizer.bos_token_id]
tgt_ids = []
has_input = False
is_segment_reenter = False
# Use incremental tokenization to avoid waiting for a long document
MAX_CHUNK_LENGTH = max_length * 10
for part in ("input", "output"):
l, r = 0, min(MAX_CHUNK_LENGTH, len(data[part]))
while l < len(data[part]):
try:
current_slice = data[part][l:r]
if not current_slice:
break
token_ids = tokenizer.encode(current_slice, add_special_tokens=False)
except:
print("Error in data[part][l:r] {}".format(data))
yield from ()
return
if part == "input":
if len(token_ids) > 0:
has_input = True
if len(token_ids) >= max_length - 2: # input len must < max_length
yield from ()
return
src_ids.extend(token_ids)
tgt_ids.extend([IGNORE_TGT] * len(token_ids))
l = r
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
else:
if len(token_ids) + len(tgt_ids) >= max_length:
if oversize_rule == "drop":
yield from ()
return
elif oversize_rule == "head":
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
src_ids.extend(selected_token_ids[:-1])
tgt_ids.extend(selected_token_ids)
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids[:max_length], tgt_ids[:max_length]
return
elif oversize_rule == "segment":
instruction_rest_space = max_length - 1 - len(token_ids)
if has_input: # is instruction data
if (
do_compact
and len(src_ids) >= 128 # avoid too short instruction info lost
and instruction_rest_space / len(src_ids) > 0.8
): # can be squeezed into max length
inputs_len = len(src_ids)
keep_len = instruction_rest_space // 2
src_ids = src_ids[:keep_len] + src_ids[inputs_len - keep_len :]
tgt_ids = [IGNORE_TGT] * (len(src_ids) - 1)
src_ids.extend(token_ids)
tgt_ids.extend(token_ids)
tgt_ids.append(tokenizer.eos_token_id)
assert len(src_ids) < max_length, f"len src_ids: {len(src_ids)}"
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids, tgt_ids
else: # else use head rule
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
src_ids.extend(selected_token_ids[:-1])
tgt_ids.extend(selected_token_ids)
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids[:max_length], tgt_ids[:max_length]
return
else: # normal segment
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
src_ids.extend(selected_token_ids)
tgt_ids.extend(selected_token_ids)
assert len(src_ids) == max_length + 1, f"len src_ids: {len(src_ids)}"
assert len(tgt_ids) == max_length, f"len tgt_ids: {len(tgt_ids)}"
yield src_ids[:max_length], tgt_ids[:max_length]
src_ids = src_ids[max_length:]
tgt_ids = tgt_ids[max_length:]
# sliding input str window
consumed_str = tokenizer.decode(selected_token_ids)
l += len(consumed_str)
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
is_segment_reenter = True
else:
if (is_segment_reenter and len(token_ids) > 8) or (
not is_segment_reenter and len(token_ids) > 0
): # is segmented LM data
src_ids.extend(token_ids)
tgt_ids.extend(token_ids)
tgt_ids.append(tokenizer.eos_token_id)
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids, tgt_ids
else:
yield from ()
return
class SegmentedDataset(torch.utils.data.IterableDataset):
def __init__(
self,
cfg,
tokenizer,
max_length=1024,
transform_func=None,
nthreads=1,
prefetch_slice=3,
slice_size=500,
do_compact=False,
):
super(SegmentedDataset, self).__init__()
self.segment = functools.partial(
cricket_data2ids, tokenizer=tokenizer, max_length=max_length, do_compact=do_compact
)
self.cfg = cfg
self.max_length = max_length
self.nthreads = nthreads
self.transform_func = transform_func
self.prefetch_slice = prefetch_slice
self.slice_size = slice_size
self.abs_weight = cfg.get("abs_weight", None)
self.task_name = cfg["task_name"]
self.dataset_name = cfg["dataset_name"]
self.oversize_rule = cfg["oversize_rule"]
self.dataset = PrefetchDecodeDataset(path=cfg["path"], allow_repeat=cfg.get("allow_repeat", True))
self.exhausted = False
self.iterator = None
self.counter = 0
self.allow_repeat = cfg.get("allow_repeat", True)
self.used = BitSet()
self.init_ave_tokens()
def init_ave_tokens(
self,
):
try:
shm = SharedMemory(name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
except FileNotFoundError:
bmt.print_rank(
"Create Shared Memory {}".format(f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
)
shm = SharedMemory(
create=True,
size=ctypes.sizeof(ctypes.c_float),
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}',
)
# 使用共享内存
shared_value = ctypes.c_float.from_buffer(shm.buf)
_ave_tokens = self.cfg.get(
"avg_tokens", self.cfg.get("ave_tokens", self.cfg.get("ave_tokens_per_line", -1))
)
if _ave_tokens > self.max_length:
_ave_tokens = self.max_length
bmt.print_rank(
"Warning: avg_tokens {} is larger than max_length {}, set to max_length".format(
_ave_tokens, self.max_length
)
)
shared_value.value = _ave_tokens
# 不再需要 shared_value 时,删除引用
del shared_value
# 现在可以安全地关闭共享内存
shm.close()
bmt.print_rank("Init ave_tokens for task {}: {}".format(self.task_name, self.ave_tokens))
@property
def ave_tokens(
self,
):
existing_shm = SharedMemory(
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}'
) # -1 # default length
shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
tmp = shared_value.value
del shared_value
existing_shm.close()
return tmp
def ave_tokens_update(self, length):
existing_shm = SharedMemory(
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}'
) # -1 # default length
shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
if shared_value.value < 0:
shared_value.value = float(length)
else:
shared_value.value = 0.98 * shared_value.value + 0.02 * length
del shared_value
existing_shm.close()
def size(self):
return self.dataset.size()
def __iter__(self):
self.iterate()
return self
def reset(self):
self.exhausted = False
if self.iterator is not None:
self.iterator.close()
self.iterator = None
self.used = BitSet()
print("Rank {}, Reset dataset:{} done.".format(bmt.rank(), self.dataset_name))
def transform(self, data: dict) -> dict:
weight = np.array(self.cfg["incontext_weight"], dtype=np.float32)
weight = weight / weight.sum()
num_incontext = np.random.choice(weight.shape[0], p=weight)
return self.transform_func(data, num_incontext, random.Random())
def segment_iterate(self, sample_iter):
for index, data in self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used):
for src_ids, tgt_ids in self.segment(self.transform(data)):
self.ave_tokens_update(len(src_ids)) # 0 for input ids
yield src_ids, tgt_ids, index
def iterate(self):
# make the dataset itself an iterator
sample_iter = self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used)
self.iterator = self.segment_iterate(sample_iter)
def __next__(self):
# advance the task iterator
if self.iterator is None:
self.iterate()
try:
return next(self.iterator)
except StopIteration:
self.exhausted = True
return None
def load_state_dict(self, state_dict):
if state_dict.get("exhausted", False):
self.exhausted = True
self.used = BitSet()
else:
used = state_dict.get("used", BitSet())
if len(used) == len(self.dataset):
self.exhausted = True
self.used = BitSet()
else:
self.exhausted = False
self.used = used
self.ave_tokens_update(state_dict.get("ave_tokens", -1))
def state_dict(self):
if len(self.used) == len(self.dataset):
return dict(exhausted=True, used=BitSet(), ave_tokens=self.ave_tokens)
else:
return dict(exhausted=False, used=self.used, ave_tokens=self.ave_tokens)
def update_state(self, indice):
self.used.update(indice)
class MixedIndexedDataset(torch.utils.data.IterableDataset):
def __init__(
self,
cfg_path: str,
cfg_json_str,
tokenizer,
max_length,
weight_by_size: bool = True,
nthreads=5,
prefetch_slice=100,
parallel_loading=False,
vdc_sampling=False,
update_weights_frequency=1,
seed=42,
):
super(MixedIndexedDataset, self).__init__()
self.set_seed(seed + bmt.rank())
self.weight_by_size = weight_by_size
self.tokenizer = tokenizer
self.eos_token_id = self.tokenizer.eos_token_id
self.bos_token_id = self.tokenizer.bos_token_id
self.path2transform = dict()
self.task_dict = OrderedDict()
self.nthreads = nthreads
self.prefetch_slice = prefetch_slice
# useful for indexing
self.tasks = []
self.names = []
# ending of iteration
self.remain = 0
self.max_length = max_length
self.vdc_sampling = vdc_sampling
if self.vdc_sampling:
self._vdc_values = [van_der_corput(i) for i in range(10**6)] # 精度提高 10^{-6}
self.vdc_gen = van_der_corput_sampling_gen(self._vdc_values)
self.update_weights_frequency = update_weights_frequency
self.path2transform = dict()
cfgs = load_dataset_cfgs(cfg_path, cfg_json_str)
_sum_weight = sum([cfg["abs_weight"] for cfg in cfgs])
_weights = {cfg["task_name"]: cfg["abs_weight"] / _sum_weight for cfg in cfgs}
bmt.print_rank("Absolute Weight of DataSet {}".format(_weights))
if parallel_loading:
self.parallel_load(cfgs, max_workers=None)
else:
self.sequential_load(cfgs)
self.weights = None
self.update_weights()
def set_seed(self, seed):
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
def load_task(self, cfg):
logger.info(f"Loading {cfg['path']}")
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
task = SegmentedDataset(
cfg,
self.tokenizer,
self.max_length,
transform_func=transform_func,
nthreads=self.nthreads,
prefetch_slice=self.prefetch_slice,
do_compact=cfg.get("do_compact", False), # dataset level do_compact
)
return task
def sequential_load(self, cfgs):
self.cfgs = cfgs
for cfg in cfgs:
# python3.7 and later preserves insertion order to dictionary
logger.info(f"Loading {cfg['path']}")
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
task = SegmentedDataset(
cfg,
self.tokenizer,
self.max_length,
transform_func=transform_func,
nthreads=self.nthreads,
prefetch_slice=self.prefetch_slice,
do_compact=cfg.get("do_compact", False), # dataset level do_compact
)
self.task_dict[task.task_name] = task
self.tasks.append(task)
self.names.append(task.task_name)
self.remain += 1
self.weights = None
self.update_weights()
def load_state_dict(self, state_dict):
missing_keys = []
for name, task in self.task_dict.items():
if name in state_dict:
task.load_state_dict(state_dict[name])
else:
missing_keys.append(name)
self.update_weights()
return missing_keys
def save_state_dict(self, path):
state_dict = {}
for name, task in self.task_dict.items():
_state_dict = task.state_dict()
if isinstance(_state_dict["used"], BitSet):
bitset = _state_dict["used"]
_file_name = bitset.save(path)
_state_dict["used"] = _file_name
state_dict[name] = _state_dict
else:
state_dict[name] = task.state_dict()
torch.save(state_dict, path)
logger.info("Dataset state saved")
def update_states(self, task_ids, indice):
is_dict = isinstance(indice, dict)
uniq = torch.unique(task_ids)
for idx in uniq:
idx = idx.item()
indexes = indice[idx] if is_dict else indice[task_ids == idx].tolist()
self.tasks[idx].update_state(indexes)
def get_transform_func(self, module_name: str, transform_script_path):
if transform_script_path is None:
# allow null transform
return lambda data, num_incontext, rand: data
module_name = "fm9g_live.transforms.{}".format(module_name)
if transform_script_path not in self.path2transform:
loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
spec = importlib.util.spec_from_loader(loader.name, loader)
if spec is None:
raise RuntimeError("Spec is none! {}".format(module_name))
mod = importlib.util.module_from_spec(spec)
self.path2transform[transform_script_path] = {
"loader": loader,
"module": mod,
"last_mtime": 0,
}
transform_script_info = self.path2transform[transform_script_path]
curr_mtime = float(transform_script_info["loader"].path_stats(transform_script_path)["mtime"])
if curr_mtime > transform_script_info["last_mtime"]:
transform_script_info["last_mtime"] = curr_mtime
transform_script_info["loader"].exec_module(transform_script_info["module"])
transform_func = getattr(transform_script_info["module"], "transform", None)
if transform_func is None:
raise NotImplementedError("Find no transform funcion in script '{}'".format(transform_script_path))
return transform_func
def update_weights(self):
task0 = self.tasks[0]
if task0.abs_weight is not None: # 这一份config是指定绝对比例的
weights = []
for task in self.tasks:
if task.exhausted:
weights.append(0)
else:
if task.ave_tokens == -1:
weights.append(task.abs_weight / self.max_length)
else:
weights.append(task.abs_weight / task.ave_tokens)
weights = np.array(weights)
else:
weights = np.array([0 if task.exhausted else task.weight for task in self.tasks])
if self.weight_by_size:
sizes = np.array([task.size() for task in self.tasks], dtype=np.float32)
weights *= sizes
self.weights = weights / weights.sum()
def __iter__(self):
for task in self.tasks:
task.iterate()
return self
def __next__(self):
step = 1
while True:
if self.remain == 0:
print("Rank {}, All task exhaust !!!!".format(bmt.rank()))
raise StopIteration
if self.vdc_sampling:
idx = next(self.vdc_gen)(self.weights)
else:
idx = np.random.choice(len(self.weights), p=self.weights)
data = next(self.tasks[idx])
if step % self.update_weights_frequency == 0:
self.update_weights()
if data is None:
if self.tasks[idx].allow_repeat:
# _runtime_ave = self.tasks[idx].ave_tokens
print("Rank {}, dataset {} exhaust, repeat...".format(bmt.rank(), self.tasks[idx].dataset_name))
# self.tasks[idx] = SegmentedDataset(
# self.tasks[idx].cfg, self.tokenizer, self.max_length, transform_func=self.tasks[idx].transform_func, nthreads=self.nthreads, prefetch_slice=self.prefetch_slice
# )
# self.tasks[idx].ave_tokens_update(_runtime_ave)
self.tasks[idx].reset()
else:
print("Rank {}, dataset {} exhaust, not repeat.".format(bmt.rank(), self.tasks[idx].dataset_name))
self.tasks[idx].exhaust = True
self.remain -= 1
continue
step += 1
return dict(
task_id=idx,
input=data[0],
target=data[1],
index=data[2],
is_long=self.tasks[idx].cfg.get("is_long", False),
)
class UnpadBatchedMixedDataset(torch.utils.data.IterableDataset):
def __init__(self, mixed_dataset, batch_size, max_length, pose_prob=0.0, pose_scaling_factor=1.0, compact=False):
self.max_total_length = batch_size * max_length
self.batch_size = 1
# setting compact=True concats segments orignated from the same input
# into a long sequence. the relative order of segments should be preserved
# in mixed_dataset, e.g.,
# - ok: task1_seg1, task2_seg1, task1_seg2, task1_seg3
# - not_ok: task1_seg1, task1_seg3, task2_seg1, task1_seg2
self.compact = compact
self.total_length = 0
self.task2seqs = defaultdict(list)
self.mixed_dataset = mixed_dataset
self._max_length = max_length
self._pose_prob = pose_prob
self._pose_scaling_factor = pose_scaling_factor
if self._pose_prob > 0.0:
self._scaled_max_length = int(self.max_total_length * self._pose_scaling_factor)
else:
self._scaled_max_length = max_length
def put(self, sample):
self.total_length += len(sample["target"])
task_id = sample["task_id"]
if self.compact and self.task2seqs[task_id]:
last = self.task2seqs[task_id][-1]
if last["target"][-1] != self.mixed_dataset.eos_token_id:
# concatenate sequantial segments for longer context modeling: why not?
last["input"].extend(sample["input"])
last["target"].extend(sample["target"])
return
self.task2seqs[task_id].append(sample)
def _pose_preprocess(
self,
input_ids: NDArray[np.int32],
) -> NDArray[np.int32]:
"""[PoSE](https://arxiv.org/abs/2309.10400v2)
GitHub implementation: https://github.com/dwzhu-pku/PoSE/blob/master/src/train_pose.py#L156
"""
len_chunk = min(len(input_ids), self._max_length)
len_input = len(input_ids)
# Chunk input randomly to fit max_length if needed
lt1 = 0
rt1 = random.randint(0, (len_chunk + 1) // 2) # Fist chunk only contains 1/2 tokens at most
rt2 = random.randint(lt1 + len_chunk, len_input) # Second chunk can randomly shift when not filled max_length
lt2 = rt2 - (len_chunk - (rt1 - lt1)) # assure all tokens are used
chunked_input_ids = np.concatenate([input_ids[lt1:rt1], input_ids[lt2:rt2]], axis=-1)
# Generate PoSE position ids
position_ids = np.arange(len(chunked_input_ids), dtype=np.int32)
len_position_ids = len(position_ids)
lt = 0
rt = random.randint(lt, self._scaled_max_length - len_position_ids)
position_ids[: rt1 - lt1] += lt
position_ids[rt1 - lt1 :] += rt
return position_ids
def pop(self):
indexes = defaultdict(list)
lengths = []
inputs = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
targets = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=IGNORE_TGT)
task_ids = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=-1)
position_ids = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
span_begin = 0
for samples in self.task2seqs.values():
while samples:
sample = samples.pop()
span_end = span_begin + len(sample["input"])
inputs[0, span_begin:span_end] = torch.tensor(sample["input"], dtype=torch.int32)
targets[0, span_begin:span_end] = torch.tensor(sample["target"], dtype=torch.int32)
task_ids[0, span_begin:span_end] = torch.tensor(sample["task_id"], dtype=torch.int32)
if not sample["is_long"] and self._pose_prob > 0.0 and random.uniform(0, 1) < self._pose_prob:
_span_position_ids = self._pose_preprocess(sample["input"])
else:
_span_position_ids = np.arange(len(sample["input"]), dtype=np.int32)
position_ids[0, span_begin:span_end] = torch.from_numpy(_span_position_ids)
# position_ids[0, span_begin:span_end] = torch.arange(len(sample["input"]), dtype=torch.int32)
lengths.append(len(sample["target"]))
indexes[int(sample["task_id"])].append(sample["index"])
self.total_length -= len(sample["target"])
span_begin = span_end
cu_seqlens = torch.cat(
[torch.tensor([0] + lengths).cumsum(dim=-1), torch.tensor([self.max_total_length], dtype=torch.int32)],
dim=0,
).int()
batch = {
"inputs": inputs,
"targets": targets,
"task_ids": task_ids,
"indexes": indexes,
# adhere to flash attention interface
"cu_seqlens": cu_seqlens,
"max_seqlen": int(torch.max(cu_seqlens[1:] - cu_seqlens[:-1])),
"lengths": torch.tensor(sum(lengths)).int(),
"task_names": self.mixed_dataset.names,
"position_ids": position_ids,
}
return batch
def will_be_full(self, sample):
return self.total_length + len(sample["target"]) > self.max_total_length
def __iter__(self):
for sample in self.mixed_dataset:
if self.will_be_full(sample):
yield self.pop()
self.put(sample)
class CudaPrefetcher(Iterable):
"""
Wrap around a batch iterator for asynchornously copying data to gpu to shield memcpy latency.
"""
def __init__(self, loader, tp_size=1, tp_rank=0):
self.loader = iter(loader)
self.tp_size = tp_size
self.tp_rank = tp_rank
self.stream = torch.cuda.Stream()
self.preload()
def preload(self):
try:
if self.tp_size > 1:
if self.tp_rank == 0:
data = next(self.loader)
print("Rank {}, Preload data done.".format(bmt.rank()))
d = {}
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "wb") as fb:
for key in data.keys():
if isinstance(data[key], torch.Tensor):
np_cur_data = data[key].cpu().numpy()
bs = np_cur_data.tobytes()
fb.write(bs)
d[key] = ["TORCH", str(np_cur_data.dtype), len(bs)] + list(np_cur_data.shape)
elif isinstance(data[key], np.ndarray):
bs = data[key].tobytes()
fb.write(bs)
d[key] = ["NUMPY", str(data[key].dtype), len(bs)] + list(data[key].shape)
else:
d[key] = data[key]
try:
_ = json.dumps(d)
except TypeError:
print(d)
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "w") as f:
json.dump(d, f)
bmt.synchronize()
if self.tp_rank != 0:
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "r") as f:
data = json.load(f)
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "rb") as fb:
bs = fb.read()
offset = 0
for key in data.keys():
if isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "NUMPY":
nw_offset = offset + data[key][2]
data[key] = np.frombuffer(bs[offset:nw_offset], dtype=data[key][1]).reshape(
data[key][3:]
)
offset = nw_offset
elif isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "TORCH":
nw_offset = offset + data[key][2]
data[key] = torch.from_numpy(
np.frombuffer(bs[offset:nw_offset], dtype=data[key][1])
.reshape(data[key][3:])
.copy()
)
offset = nw_offset
self.data = data
else:
self.data = next(self.loader)
except StopIteration:
self.data = None
return
with torch.cuda.stream(self.stream):
for key in self.data.keys():
if isinstance(self.data[key], torch.Tensor):
self.data[key] = self.data[key].cuda(non_blocking=True)
def __next__(self):
torch.cuda.current_stream().wait_stream(self.stream)
for key in self.data.keys():
if isinstance(self.data[key], torch.Tensor):
self.data[key].record_stream(torch.cuda.current_stream())
data = copy.deepcopy(self.data)
self.preload()
return data
def __iter__(self):
return self

View File

@ -0,0 +1,828 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
#
# @author: ouzebin <ouzebin@zhihu.com>
# @date: 2023/09/27
import copy
import ctypes
import functools
import importlib
import json
import logging
import os
import random
from collections import defaultdict
from collections import OrderedDict
from multiprocessing import Lock
from multiprocessing import Process
from multiprocessing.shared_memory import SharedMemory
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Union
import bmtrain as bmt
import numpy as np
import torch
from numpy.typing import NDArray
from fm9g.dataset import PrefetchDecodeDataset
from fm9g.utils.bitset import BitSet
from fm9g.utils.vdc_sampling import van_der_corput
from fm9g.utils.vdc_sampling import van_der_corput_sampling_gen
logger = logging.getLogger(__name__)
IGNORE_TGT = -100
def load_dataset_cfgs(cfg_path, cfg_json_str=None):
if cfg_json_str is not None:
cfgs = json.loads(cfg_json_str)
else:
with open(cfg_path, "r", encoding="utf-8") as fin:
cfgs = json.load(fin)
transform_basedir = os.path.dirname(os.path.abspath(cfg_path))
path_dict = None
platform_config_path = os.getenv("PLATFORM_CONFIG_PATH")
try:
with open(platform_config_path, "r") as f:
platform_cfg = json.load(f)
path_dict = platform_cfg["dataset_map"]
if bmt.rank() == 0:
logger.info(f"Loaded jeeves platform config from '{platform_config_path}', update dataset paths...")
except Exception as e:
if bmt.rank() == 0:
logger.info(f"Failing to load jeeves platform config '{platform_config_path}', error message:\n{str(e)}")
task_name2dataset_name = dict()
for idx, cfg in enumerate(cfgs):
assert "dataset_name" in cfg and isinstance(cfg["dataset_name"], str)
assert "task_name" in cfg and isinstance(cfg["task_name"], str)
# to be delibrately annoying :)
if cfg["task_name"] in task_name2dataset_name:
raise ValueError(
f"task_name '{cfg['task_name']}' in dataset '{cfg['dataset_name']}'"
f"has already been used in '{task_name2dataset_name[cfg['task_name']]}'."
)
task_name2dataset_name[cfg["task_name"]] = cfg["dataset_name"]
assert "path" in cfg and isinstance(cfg["path"], str)
# if path_dict is not None:
# cfg["path"] = os.path.join(path_dict[cfg["dataset_name"]], cfg["path"])
# dealing with optional configs
if "weight" in cfg:
assert isinstance(cfg["weight"], (float, int))
else:
cfg["weight"] = 1.0
if "oversize_rule" in cfg:
assert cfg["oversize_rule"] in ("drop", "head", "segment")
else:
cfg["oversize_rule"] = "segment"
if "transforms" in cfg:
assert isinstance(cfg["transforms"], str)
# dealing with relative path
if not cfg["transforms"].startswith("/"):
cfg["transforms"] = os.path.join(transform_basedir, cfg["transforms"])
if not cfg["transforms"]:
cfg["transforms"] = None
else:
cfg["transforms"] = None
if "incontext_weight" in cfg:
assert isinstance(cfg["incontext_weight"], (list, tuple))
else:
cfg["incontext_weight"] = [1.0]
cfg["id"] = idx
# dataset and iterator will be built
return cfgs
def data2ids(data, tokenizer, max_length):
text = "\n".join(
[
data.get("title", "").strip(),
data.get("question", "").strip(),
data.get("answer", "").strip(),
data.get("abstract", "").strip(),
data.get("text", "").strip(),
data.get("code", "").strip(),
]
).strip()
if not text:
logger.warning(f"Warning: skip invalid sample without valid fields: {data}")
yield from ()
return
# suppress the annoying warning from tokenizer
ids = (
[tokenizer.bos_id]
+ tokenizer.encode(text, max_length=int(1e12), truncation=True)
+ [tokenizer.eos_id]
)
src_ids = ids[0:-1]
tgt_ids = ids[0:-1] # do not shift because it'll be shifted during loss calculation.
if len(src_ids) > max_length:
for st in range(0, len(src_ids), max_length):
yield src_ids[st : st + max_length], tgt_ids[st : st + max_length]
else:
yield src_ids, tgt_ids
def cricket_data2ids(data, tokenizer, max_length: int, oversize_rule="segment", do_compact=False):
assert oversize_rule in ("drop", "head", "segment")
if data is None:
yield from ()
return
if "output" not in data or not data["output"]:
yield from ()
return
if "input" not in data or data["input"] is None:
data["input"] = ""
src_ids = [tokenizer.bos_id]
tgt_ids = []
has_input = False
is_segment_reenter = False
# Use incremental tokenization to avoid waiting for a long document
MAX_CHUNK_LENGTH = max_length * 10
for part in ("input", "output"):
l, r = 0, min(MAX_CHUNK_LENGTH, len(data[part]))
while l < len(data[part]):
try:
current_slice = data[part][l:r]
if not current_slice:
break
#token_ids = tokenizer.encode(current_slice, add_special_tokens=False)
token_ids = tokenizer.encode(current_slice)
except:
#print("Error in data[part][l:r] {}".format(data))
yield from ()
return
if part == "input":
if len(token_ids) > 0:
has_input = True
if len(token_ids) >= max_length - 2: # input len must < max_length
yield from ()
return
src_ids.extend(token_ids)
tgt_ids.extend([IGNORE_TGT] * len(token_ids))
l = r
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
else:
if len(token_ids) + len(tgt_ids) >= max_length:
if oversize_rule == "drop":
yield from ()
return
elif oversize_rule == "head":
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
src_ids.extend(selected_token_ids[:-1])
tgt_ids.extend(selected_token_ids)
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids[:max_length], tgt_ids[:max_length]
return
elif oversize_rule == "segment":
instruction_rest_space = max_length - 1 - len(token_ids)
if has_input: # is instruction data
if (
do_compact
and len(src_ids) >= 128 # avoid too short instruction info lost
and instruction_rest_space / len(src_ids) > 0.8
): # can be squeezed into max length
inputs_len = len(src_ids)
keep_len = instruction_rest_space // 2
src_ids = src_ids[:keep_len] + src_ids[inputs_len - keep_len :]
tgt_ids = [IGNORE_TGT] * (len(src_ids) - 1)
src_ids.extend(token_ids)
tgt_ids.extend(token_ids)
tgt_ids.append(tokenizer.eos_id)
assert len(src_ids) < max_length, f"len src_ids: {len(src_ids)}"
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids, tgt_ids
else: # else use head rule
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
src_ids.extend(selected_token_ids[:-1])
tgt_ids.extend(selected_token_ids)
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids[:max_length], tgt_ids[:max_length]
return
else: # normal segment
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
src_ids.extend(selected_token_ids)
tgt_ids.extend(selected_token_ids)
assert len(src_ids) == max_length + 1, f"len src_ids: {len(src_ids)}"
assert len(tgt_ids) == max_length, f"len tgt_ids: {len(tgt_ids)}"
yield src_ids[:max_length], tgt_ids[:max_length]
src_ids = src_ids[max_length:]
tgt_ids = tgt_ids[max_length:]
# sliding input str window
consumed_str = tokenizer.decode(selected_token_ids)
l += len(consumed_str)
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
is_segment_reenter = True
else:
if (is_segment_reenter and len(token_ids) > 8) or (
not is_segment_reenter and len(token_ids) > 0
): # is segmented LM data
src_ids.extend(token_ids)
tgt_ids.extend(token_ids)
tgt_ids.append(tokenizer.eos_id)
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids, tgt_ids
else:
yield from ()
return
class SegmentedDataset(torch.utils.data.IterableDataset):
def __init__(
self,
cfg,
tokenizer,
max_length=1024,
transform_func=None,
nthreads=1,
prefetch_slice=3,
slice_size=500,
do_compact=False,
):
super(SegmentedDataset, self).__init__()
self.segment = functools.partial(
cricket_data2ids, tokenizer=tokenizer, max_length=max_length, do_compact=do_compact
)
self.cfg = cfg
self.max_length = max_length
self.nthreads = nthreads
self.transform_func = transform_func
self.prefetch_slice = prefetch_slice
self.slice_size = slice_size
self.abs_weight = cfg.get("abs_weight", None)
self.task_name = cfg["task_name"]
self.dataset_name = cfg["dataset_name"]
self.oversize_rule = cfg["oversize_rule"]
self.dataset = PrefetchDecodeDataset(path=cfg["path"], allow_repeat=cfg.get("allow_repeat", True))
self.exhausted = False
self.iterator = None
self.counter = 0
self.allow_repeat = cfg.get("allow_repeat", True)
self.used = BitSet()
self.init_ave_tokens()
def init_ave_tokens(
self,
):
try:
shm = SharedMemory(name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
except FileNotFoundError:
bmt.print_rank(
"Create Shared Memory {}".format(f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
)
shm = SharedMemory(
create=True,
size=ctypes.sizeof(ctypes.c_float),
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}',
)
# 使用共享内存
shared_value = ctypes.c_float.from_buffer(shm.buf)
_ave_tokens = self.cfg.get(
"avg_tokens", self.cfg.get("ave_tokens", self.cfg.get("ave_tokens_per_line", -1))
)
if _ave_tokens > self.max_length:
_ave_tokens = self.max_length
bmt.print_rank(
"Warning: avg_tokens {} is larger than max_length {}, set to max_length".format(
_ave_tokens, self.max_length
)
)
shared_value.value = _ave_tokens
# 不再需要 shared_value 时,删除引用
del shared_value
# 现在可以安全地关闭共享内存
shm.close()
bmt.print_rank("Init ave_tokens for task {}: {}".format(self.task_name, self.ave_tokens))
@property
def ave_tokens(
self,
):
existing_shm = SharedMemory(
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}'
) # -1 # default length
shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
tmp = shared_value.value
del shared_value
existing_shm.close()
return tmp
def ave_tokens_update(self, length):
existing_shm = SharedMemory(
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}'
) # -1 # default length
shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
if shared_value.value < 0:
shared_value.value = float(length)
else:
shared_value.value = 0.98 * shared_value.value + 0.02 * length
del shared_value
existing_shm.close()
def size(self):
return self.dataset.size()
def __iter__(self):
self.iterate()
return self
def reset(self):
self.exhausted = False
if self.iterator is not None:
self.iterator.close()
self.iterator = None
self.used = BitSet()
print("Rank {}, Reset dataset:{} done.".format(bmt.rank(), self.dataset_name))
def transform(self, data: dict) -> dict:
weight = np.array(self.cfg["incontext_weight"], dtype=np.float32)
weight = weight / weight.sum()
num_incontext = np.random.choice(weight.shape[0], p=weight)
return self.transform_func(data, num_incontext, random.Random())
def segment_iterate(self, sample_iter):
for index, data in self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used):
for src_ids, tgt_ids in self.segment(self.transform(data)):
self.ave_tokens_update(len(src_ids)) # 0 for input ids
yield src_ids, tgt_ids, index
def iterate(self):
# make the dataset itself an iterator
sample_iter = self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used)
self.iterator = self.segment_iterate(sample_iter)
def __next__(self):
# advance the task iterator
if self.iterator is None:
self.iterate()
try:
return next(self.iterator)
except StopIteration:
self.exhausted = True
return None
def load_state_dict(self, state_dict):
if state_dict.get("exhausted", False):
self.exhausted = True
self.used = BitSet()
else:
used = state_dict.get("used", BitSet())
if len(used) == len(self.dataset):
self.exhausted = True
self.used = BitSet()
else:
self.exhausted = False
self.used = used
self.ave_tokens_update(state_dict.get("ave_tokens", -1))
def state_dict(self):
if len(self.used) == len(self.dataset):
return dict(exhausted=True, used=BitSet(), ave_tokens=self.ave_tokens)
else:
return dict(exhausted=False, used=self.used, ave_tokens=self.ave_tokens)
def update_state(self, indice):
self.used.update(indice)
class MixedIndexedDataset(torch.utils.data.IterableDataset):
def __init__(
self,
cfg_path: str,
cfg_json_str,
tokenizer,
max_length,
weight_by_size: bool = True,
nthreads=5,
prefetch_slice=100,
parallel_loading=False,
vdc_sampling=False,
update_weights_frequency=1,
seed=42,
):
super(MixedIndexedDataset, self).__init__()
self.set_seed(seed + bmt.rank())
self.weight_by_size = weight_by_size
self.tokenizer = tokenizer
self.eos_token_id = self.tokenizer.eos_id
self.bos_token_id = self.tokenizer.bos_id
self.path2transform = dict()
self.task_dict = OrderedDict()
self.nthreads = nthreads
self.prefetch_slice = prefetch_slice
# useful for indexing
self.tasks = []
self.names = []
# ending of iteration
self.remain = 0
self.max_length = max_length
self.vdc_sampling = vdc_sampling
if self.vdc_sampling:
self._vdc_values = [van_der_corput(i) for i in range(10**6)] # 精度提高 10^{-6}
self.vdc_gen = van_der_corput_sampling_gen(self._vdc_values)
self.update_weights_frequency = update_weights_frequency
self.path2transform = dict()
cfgs = load_dataset_cfgs(cfg_path, cfg_json_str)
_sum_weight = sum([cfg["abs_weight"] for cfg in cfgs])
_weights = {cfg["task_name"]: cfg["abs_weight"] / _sum_weight for cfg in cfgs}
bmt.print_rank("Absolute Weight of DataSet {}".format(_weights))
if parallel_loading:
self.parallel_load(cfgs, max_workers=None)
else:
self.sequential_load(cfgs)
self.weights = None
self.update_weights()
def set_seed(self, seed):
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
def load_task(self, cfg):
logger.info(f"Loading {cfg['path']}")
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
task = SegmentedDataset(
cfg,
self.tokenizer,
self.max_length,
transform_func=transform_func,
nthreads=self.nthreads,
prefetch_slice=self.prefetch_slice,
do_compact=cfg.get("do_compact", False), # dataset level do_compact
)
return task
def sequential_load(self, cfgs):
self.cfgs = cfgs
for cfg in cfgs:
# python3.7 and later preserves insertion order to dictionary
logger.info(f"Loading {cfg['path']}")
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
task = SegmentedDataset(
cfg,
self.tokenizer,
self.max_length,
transform_func=transform_func,
nthreads=self.nthreads,
prefetch_slice=self.prefetch_slice,
do_compact=cfg.get("do_compact", False), # dataset level do_compact
)
self.task_dict[task.task_name] = task
self.tasks.append(task)
self.names.append(task.task_name)
self.remain += 1
self.weights = None
self.update_weights()
def load_state_dict(self, state_dict):
missing_keys = []
for name, task in self.task_dict.items():
if name in state_dict:
task.load_state_dict(state_dict[name])
else:
missing_keys.append(name)
self.update_weights()
return missing_keys
def save_state_dict(self, path):
state_dict = {}
for name, task in self.task_dict.items():
_state_dict = task.state_dict()
if isinstance(_state_dict["used"], BitSet):
bitset = _state_dict["used"]
_file_name = bitset.save(path)
_state_dict["used"] = _file_name
state_dict[name] = _state_dict
else:
state_dict[name] = task.state_dict()
torch.save(state_dict, path)
logger.info("Dataset state saved")
def update_states(self, task_ids, indice):
is_dict = isinstance(indice, dict)
uniq = torch.unique(task_ids)
for idx in uniq:
idx = idx.item()
indexes = indice[idx] if is_dict else indice[task_ids == idx].tolist()
self.tasks[idx].update_state(indexes)
def get_transform_func(self, module_name: str, transform_script_path):
if transform_script_path is None:
# allow null transform
return lambda data, num_incontext, rand: data
module_name = "fm9g_live.transforms.{}".format(module_name)
if transform_script_path not in self.path2transform:
loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
spec = importlib.util.spec_from_loader(loader.name, loader)
if spec is None:
raise RuntimeError("Spec is none! {}".format(module_name))
mod = importlib.util.module_from_spec(spec)
self.path2transform[transform_script_path] = {
"loader": loader,
"module": mod,
"last_mtime": 0,
}
transform_script_info = self.path2transform[transform_script_path]
curr_mtime = float(transform_script_info["loader"].path_stats(transform_script_path)["mtime"])
if curr_mtime > transform_script_info["last_mtime"]:
transform_script_info["last_mtime"] = curr_mtime
transform_script_info["loader"].exec_module(transform_script_info["module"])
transform_func = getattr(transform_script_info["module"], "transform", None)
if transform_func is None:
raise NotImplementedError("Find no transform funcion in script '{}'".format(transform_script_path))
return transform_func
def update_weights(self):
task0 = self.tasks[0]
if task0.abs_weight is not None: # 这一份config是指定绝对比例的
weights = []
for task in self.tasks:
if task.exhausted:
weights.append(0)
else:
if task.ave_tokens == -1:
weights.append(task.abs_weight / self.max_length)
else:
weights.append(task.abs_weight / task.ave_tokens)
weights = np.array(weights)
else:
weights = np.array([0 if task.exhausted else task.weight for task in self.tasks])
if self.weight_by_size:
sizes = np.array([task.size() for task in self.tasks], dtype=np.float32)
weights *= sizes
self.weights = weights / weights.sum()
def __iter__(self):
for task in self.tasks:
task.iterate()
return self
def __next__(self):
step = 1
while True:
if self.remain == 0:
print("Rank {}, All task exhaust !!!!".format(bmt.rank()))
raise StopIteration
if self.vdc_sampling:
idx = next(self.vdc_gen)(self.weights)
else:
idx = np.random.choice(len(self.weights), p=self.weights)
data = next(self.tasks[idx])
if step % self.update_weights_frequency == 0:
self.update_weights()
if data is None:
if self.tasks[idx].allow_repeat:
# _runtime_ave = self.tasks[idx].ave_tokens
print("Rank {}, dataset {} exhaust, repeat...".format(bmt.rank(), self.tasks[idx].dataset_name))
# self.tasks[idx] = SegmentedDataset(
# self.tasks[idx].cfg, self.tokenizer, self.max_length, transform_func=self.tasks[idx].transform_func, nthreads=self.nthreads, prefetch_slice=self.prefetch_slice
# )
# self.tasks[idx].ave_tokens_update(_runtime_ave)
self.tasks[idx].reset()
else:
print("Rank {}, dataset {} exhaust, not repeat.".format(bmt.rank(), self.tasks[idx].dataset_name))
self.tasks[idx].exhaust = True
self.remain -= 1
continue
step += 1
return dict(
task_id=idx,
input=data[0],
target=data[1],
index=data[2],
is_long=self.tasks[idx].cfg.get("is_long", False),
)
class UnpadBatchedMixedDataset(torch.utils.data.IterableDataset):
def __init__(self, mixed_dataset, batch_size, max_length, pose_prob=0.0, pose_scaling_factor=1.0, compact=False):
self.max_total_length = batch_size * max_length
self.batch_size = 1
# setting compact=True concats segments orignated from the same input
# into a long sequence. the relative order of segments should be preserved
# in mixed_dataset, e.g.,
# - ok: task1_seg1, task2_seg1, task1_seg2, task1_seg3
# - not_ok: task1_seg1, task1_seg3, task2_seg1, task1_seg2
self.compact = compact
self.total_length = 0
self.task2seqs = defaultdict(list)
self.mixed_dataset = mixed_dataset
self._max_length = max_length
self._pose_prob = pose_prob
self._pose_scaling_factor = pose_scaling_factor
if self._pose_prob > 0.0:
self._scaled_max_length = int(self.max_total_length * self._pose_scaling_factor)
else:
self._scaled_max_length = max_length
def put(self, sample):
self.total_length += len(sample["target"])
task_id = sample["task_id"]
if self.compact and self.task2seqs[task_id]:
last = self.task2seqs[task_id][-1]
if last["target"][-1] != self.mixed_dataset.eos_token_id:
# concatenate sequantial segments for longer context modeling: why not?
last["input"].extend(sample["input"])
last["target"].extend(sample["target"])
return
self.task2seqs[task_id].append(sample)
def _pose_preprocess(
self,
input_ids: NDArray[np.int32],
) -> NDArray[np.int32]:
"""[PoSE](https://arxiv.org/abs/2309.10400v2)
GitHub implementation: https://github.com/dwzhu-pku/PoSE/blob/master/src/train_pose.py#L156
"""
len_chunk = min(len(input_ids), self._max_length)
len_input = len(input_ids)
# Chunk input randomly to fit max_length if needed
lt1 = 0
rt1 = random.randint(0, (len_chunk + 1) // 2) # Fist chunk only contains 1/2 tokens at most
rt2 = random.randint(lt1 + len_chunk, len_input) # Second chunk can randomly shift when not filled max_length
lt2 = rt2 - (len_chunk - (rt1 - lt1)) # assure all tokens are used
chunked_input_ids = np.concatenate([input_ids[lt1:rt1], input_ids[lt2:rt2]], axis=-1)
# Generate PoSE position ids
position_ids = np.arange(len(chunked_input_ids), dtype=np.int32)
len_position_ids = len(position_ids)
lt = 0
rt = random.randint(lt, self._scaled_max_length - len_position_ids)
position_ids[: rt1 - lt1] += lt
position_ids[rt1 - lt1 :] += rt
return position_ids
def pop(self):
indexes = defaultdict(list)
lengths = []
inputs = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
targets = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=IGNORE_TGT)
task_ids = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=-1)
position_ids = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
span_begin = 0
for samples in self.task2seqs.values():
while samples:
sample = samples.pop()
span_end = span_begin + len(sample["input"])
inputs[0, span_begin:span_end] = torch.tensor(sample["input"], dtype=torch.int32)
targets[0, span_begin:span_end] = torch.tensor(sample["target"], dtype=torch.int32)
task_ids[0, span_begin:span_end] = torch.tensor(sample["task_id"], dtype=torch.int32)
if not sample["is_long"] and self._pose_prob > 0.0 and random.uniform(0, 1) < self._pose_prob:
_span_position_ids = self._pose_preprocess(sample["input"])
else:
_span_position_ids = np.arange(len(sample["input"]), dtype=np.int32)
position_ids[0, span_begin:span_end] = torch.from_numpy(_span_position_ids)
# position_ids[0, span_begin:span_end] = torch.arange(len(sample["input"]), dtype=torch.int32)
lengths.append(len(sample["target"]))
indexes[int(sample["task_id"])].append(sample["index"])
self.total_length -= len(sample["target"])
span_begin = span_end
cu_seqlens = torch.cat(
[torch.tensor([0] + lengths).cumsum(dim=-1), torch.tensor([self.max_total_length], dtype=torch.int32)],
dim=0,
).int()
batch = {
"inputs": inputs,
"targets": targets,
"task_ids": task_ids,
"indexes": indexes,
# adhere to flash attention interface
"cu_seqlens": cu_seqlens,
"max_seqlen": int(torch.max(cu_seqlens[1:] - cu_seqlens[:-1])),
"lengths": torch.tensor(sum(lengths)).int(),
"task_names": self.mixed_dataset.names,
"position_ids": position_ids,
}
return batch
def will_be_full(self, sample):
return self.total_length + len(sample["target"]) > self.max_total_length
def __iter__(self):
for sample in self.mixed_dataset:
if self.will_be_full(sample):
yield self.pop()
self.put(sample)
class CudaPrefetcher(Iterable):
"""
Wrap around a batch iterator for asynchornously copying data to gpu to shield memcpy latency.
"""
def __init__(self, loader, tp_size=1, tp_rank=0):
self.loader = iter(loader)
self.tp_size = tp_size
self.tp_rank = tp_rank
self.stream = torch.cuda.Stream()
self.preload()
def preload(self):
try:
if self.tp_size > 1:
if self.tp_rank == 0:
data = next(self.loader)
print("Rank {}, Preload data done.".format(bmt.rank()))
d = {}
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "wb") as fb:
for key in data.keys():
if isinstance(data[key], torch.Tensor):
np_cur_data = data[key].cpu().numpy()
bs = np_cur_data.tobytes()
fb.write(bs)
d[key] = ["TORCH", str(np_cur_data.dtype), len(bs)] + list(np_cur_data.shape)
elif isinstance(data[key], np.ndarray):
bs = data[key].tobytes()
fb.write(bs)
d[key] = ["NUMPY", str(data[key].dtype), len(bs)] + list(data[key].shape)
else:
d[key] = data[key]
try:
_ = json.dumps(d)
except TypeError:
print(d)
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "w") as f:
json.dump(d, f)
bmt.synchronize()
if self.tp_rank != 0:
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "r") as f:
data = json.load(f)
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "rb") as fb:
bs = fb.read()
offset = 0
for key in data.keys():
if isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "NUMPY":
nw_offset = offset + data[key][2]
data[key] = np.frombuffer(bs[offset:nw_offset], dtype=data[key][1]).reshape(
data[key][3:]
)
offset = nw_offset
elif isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "TORCH":
nw_offset = offset + data[key][2]
data[key] = torch.from_numpy(
np.frombuffer(bs[offset:nw_offset], dtype=data[key][1])
.reshape(data[key][3:])
.copy()
)
offset = nw_offset
self.data = data
else:
self.data = next(self.loader)
except StopIteration:
self.data = None
return
with torch.cuda.stream(self.stream):
for key in self.data.keys():
if isinstance(self.data[key], torch.Tensor):
self.data[key] = self.data[key].cuda(non_blocking=True)
def __next__(self):
torch.cuda.current_stream().wait_stream(self.stream)
for key in self.data.keys():
if isinstance(self.data[key], torch.Tensor):
self.data[key].record_stream(torch.cuda.current_stream())
data = copy.deepcopy(self.data)
self.preload()
return data
def __iter__(self):
return self

View File

@ -0,0 +1,827 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
#
# @author: ouzebin <ouzebin@zhihu.com>
# @date: 2023/09/27
import copy
import ctypes
import functools
import importlib
import importlib.util
import json
import logging
import multiprocessing as mp
import os
import random
from collections import defaultdict
from collections import OrderedDict
from multiprocessing import Lock
from multiprocessing import Process
from multiprocessing.shared_memory import SharedMemory
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Union
import bmtrain as bmt
import numpy as np
import requests
import torch
from numpy.typing import NDArray
from tenacity import retry
from tenacity import stop_after_attempt
from tenacity import wait_fixed
from tenacity import wait_random
from fm9g.dataset import PrefetchDecodeDataset
from fm9g.utils.bitset import BitSet
from fm9g.utils.vdc_sampling import van_der_corput
from fm9g.utils.vdc_sampling import van_der_corput_sampling_gen
from .flask_ps import app as flask_ps
logger = logging.getLogger(__name__)
IGNORE_TGT = -100
def load_dataset_cfgs(cfg_path, cfg_json_str=None):
if cfg_json_str is not None:
cfgs = json.loads(cfg_json_str)
else:
with open(cfg_path, "r", encoding="utf-8") as fin:
cfgs = json.load(fin)
transform_basedir = os.path.dirname(os.path.abspath(cfg_path))
path_dict = None
platform_config_path = os.getenv("PLATFORM_CONFIG_PATH")
try:
with open(platform_config_path, "r") as f:
platform_cfg = json.load(f)
path_dict = platform_cfg["dataset_map"]
if bmt.rank() == 0:
logger.info(f"Loaded jeeves platform config from '{platform_config_path}', update dataset paths...")
except Exception as e:
if bmt.rank() == 0:
logger.info(f"Failing to load jeeves platform config '{platform_config_path}', error message:\n{str(e)}")
task_name2dataset_name = dict()
for idx, cfg in enumerate(cfgs):
assert "dataset_name" in cfg and isinstance(cfg["dataset_name"], str)
assert "task_name" in cfg and isinstance(cfg["task_name"], str)
# to be delibrately annoying :)
if cfg["task_name"] in task_name2dataset_name:
raise ValueError(
f"task_name '{cfg['task_name']}' in dataset '{cfg['dataset_name']}'"
f"has already been used in '{task_name2dataset_name[cfg['task_name']]}'."
)
task_name2dataset_name[cfg["task_name"]] = cfg["dataset_name"]
assert "path" in cfg and isinstance(cfg["path"], str)
# if path_dict is not None:
# cfg["path"] = os.path.join(path_dict[cfg["dataset_name"]], cfg["path"])
# dealing with optional configs
if "weight" in cfg:
assert isinstance(cfg["weight"], (float, int))
else:
cfg["weight"] = 1.0
if "oversize_rule" in cfg:
assert cfg["oversize_rule"] in ("drop", "head", "segment")
else:
cfg["oversize_rule"] = "segment"
if "transforms" in cfg:
assert isinstance(cfg["transforms"], str)
# dealing with relative path
if not cfg["transforms"].startswith("/"):
cfg["transforms"] = os.path.join(transform_basedir, cfg["transforms"])
if not cfg["transforms"]:
cfg["transforms"] = None
else:
cfg["transforms"] = None
if "incontext_weight" in cfg:
assert isinstance(cfg["incontext_weight"], (list, tuple))
else:
cfg["incontext_weight"] = [1.0]
cfg["id"] = idx
# dataset and iterator will be built
return cfgs
def data2ids(data, tokenizer, max_length):
text = "\n".join(
[
data.get("title", "").strip(),
data.get("question", "").strip(),
data.get("answer", "").strip(),
data.get("abstract", "").strip(),
data.get("text", "").strip(),
data.get("code", "").strip(),
]
).strip()
if not text:
logger.warning(f"Warning: skip invalid sample without valid fields: {data}")
yield from ()
return
# suppress the annoying warning from tokenizer
ids = (
[tokenizer.bos_token_id]
+ tokenizer.encode(text, max_length=int(1e12), truncation=True)
+ [tokenizer.eos_token_id]
)
src_ids = ids[0:-1]
tgt_ids = ids[0:-1] # do not shift because it'll be shifted during loss calculation.
if len(src_ids) > max_length:
for st in range(0, len(src_ids), max_length):
yield src_ids[st : st + max_length], tgt_ids[st : st + max_length]
else:
yield src_ids, tgt_ids
def cricket_data2ids(data, tokenizer, max_length: int, oversize_rule="segment", do_compact=False):
assert oversize_rule in ("drop", "head", "segment")
if data is None:
yield from ()
return
if "output" not in data or not data["output"]:
yield from ()
return
if "input" not in data:
data["input"] = ""
src_ids = [tokenizer.bos_token_id]
tgt_ids = []
has_input = False
is_segment_reenter = False
# Use incremental tokenization to avoid waiting for a long document
MAX_CHUNK_LENGTH = max_length * 10
for part in ("input", "output"):
l, r = 0, min(MAX_CHUNK_LENGTH, len(data[part]))
while l < len(data[part]):
current_slice = data[part][l:r]
if not current_slice:
break
token_ids = tokenizer.encode(current_slice, add_special_tokens=False)
if part == "input":
if len(token_ids) > 0:
has_input = True
if len(token_ids) >= max_length - 2: # input len must < max_length
yield from ()
return
src_ids.extend(token_ids)
tgt_ids.extend([IGNORE_TGT] * len(token_ids))
l = r
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
else:
if len(token_ids) + len(tgt_ids) >= max_length:
if oversize_rule == "drop":
yield from ()
return
elif oversize_rule == "head":
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
src_ids.extend(selected_token_ids[:-1])
tgt_ids.extend(selected_token_ids)
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids[:max_length], tgt_ids[:max_length]
return
elif oversize_rule == "segment":
instruction_rest_space = max_length - 1 - len(token_ids)
if has_input: # is instruction data
if (
do_compact
and len(src_ids) >= 128 # avoid too short instruction info lost
and instruction_rest_space / len(src_ids) > 0.8
): # can be squeezed into max length
inputs_len = len(src_ids)
keep_len = instruction_rest_space // 2
src_ids = src_ids[:keep_len] + src_ids[inputs_len - keep_len :]
tgt_ids = [IGNORE_TGT] * (len(src_ids) - 1)
src_ids.extend(token_ids)
tgt_ids.extend(token_ids)
tgt_ids.append(tokenizer.eos_token_id)
assert len(src_ids) < max_length, f"len src_ids: {len(src_ids)}"
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids, tgt_ids
else: # else use head rule
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
src_ids.extend(selected_token_ids[:-1])
tgt_ids.extend(selected_token_ids)
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids[:max_length], tgt_ids[:max_length]
return
else: # normal segment
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
src_ids.extend(selected_token_ids)
tgt_ids.extend(selected_token_ids)
assert len(src_ids) == max_length + 1, f"len src_ids: {len(src_ids)}"
assert len(tgt_ids) == max_length, f"len tgt_ids: {len(tgt_ids)}"
yield src_ids[:max_length], tgt_ids[:max_length]
src_ids = src_ids[max_length:]
tgt_ids = tgt_ids[max_length:]
# sliding input str window
consumed_str = tokenizer.decode(selected_token_ids)
l += len(consumed_str)
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
is_segment_reenter = True
else:
if (is_segment_reenter and len(token_ids) > 8) or (
not is_segment_reenter and len(token_ids) > 0
): # is segmented LM data
src_ids.extend(token_ids)
tgt_ids.extend(token_ids)
tgt_ids.append(tokenizer.eos_token_id)
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids, tgt_ids
else:
yield from ()
return
class SegmentedDataset(torch.utils.data.IterableDataset):
def __init__(
self,
cfg,
tokenizer,
max_length=1024,
transform_func=None,
nthreads=1,
prefetch_slice=3,
slice_size=500,
do_compact=False,
is_local=False,
):
def get_full_qualified_name(func):
module_name = func.__module__
qual_name = func.__qualname__
return f"{module_name}.{qual_name}"
super(SegmentedDataset, self).__init__()
self.segment = functools.partial(
cricket_data2ids, tokenizer=tokenizer, max_length=max_length, do_compact=do_compact
)
self.cfg = cfg
self.max_length = max_length
self.nthreads = nthreads
self.transform_func = transform_func
self.prefetch_slice = prefetch_slice
self.slice_size = slice_size
self.abs_weight = cfg.get("abs_weight", None)
self.task_name = cfg["task_name"]
self.dataset_name = cfg["dataset_name"]
self.oversize_rule = cfg["oversize_rule"]
self.dataset = PrefetchDecodeDataset(path=cfg["path"], allow_repeat=cfg.get("allow_repeat", False))
self.exhausted = False
self.iterator = None
self.counter = 0
self.allow_repeat = cfg.get("allow_repeat", True)
self.used = set()
self.is_local = is_local
self.port_offset = random.randint(1, 8000) + 1000
if is_local or bmt.rank() == 0:
addr = os.environ["MASTER_ADDR"]
port = int(os.environ["MASTER_PORT"]) + self.port_offset
_avg_tokens = cfg.get("ave_tokens_per_line", -1)
_avg_tokens = cfg.get("avg_tokens", _avg_tokens)
requests.post(f"http://{addr}:{port}/avg_tokens/{self.task_name}?action=set&length={_avg_tokens}")
@retry(stop=stop_after_attempt(3), wait=wait_random(5, 20))
def set_avg_tokens(self, avg_tokens):
addr = os.environ["MASTER_ADDR"]
port = int(os.environ["MASTER_PORT"]) + self.port_offset
url = f"http://{addr}:{port}/avg_tokens/{self.task_name}?action=set&length={avg_tokens}"
response = requests.post(url)
if response.status_code != 200:
self.reset_port_offset()
print(f"Failed to set avg_tokens for task {self.task_name}, request url: {url}")
raise RuntimeError(f"Failed to set avg_tokens for task {self.task_name}, request url: {url}")
@retry(stop=stop_after_attempt(3), wait=wait_random(5, 20))
def update_avg_tokens_by_ema(self, length):
addr = os.environ["MASTER_ADDR"]
port = int(os.environ["MASTER_PORT"]) + self.port_offset
url = f"http://{addr}:{port}/avg_tokens/{self.task_name}?action=update&length={length}"
response = requests.post(url)
if response.status_code != 200:
self.reset_port_offset()
print(f"Failed to update avg_tokens for task {self.task_name}, request url: {url}")
raise RuntimeError(f"Failed to update avg_tokens for task {self.task_name}, request url: {url}")
@property
@retry(stop=stop_after_attempt(3), wait=wait_random(5, 20))
def avg_tokens(self):
addr = os.environ["MASTER_ADDR"]
port = int(os.environ["MASTER_PORT"]) + self.port_offset
url = f"http://{addr}:{port}/avg_tokens/{self.task_name}"
response = requests.get(url)
if response.status_code != 200:
self.reset_port_offset()
print(f"Failed to get avg_tokens for task {self.task_name}, request url: {url}")
raise RuntimeError(f"Failed to get avg_tokens for task {self.task_name}, request url: {url}")
data = response.json()
return data["avg_tokens"]
def reset_port_offset(self):
self.port_offset = random.randint(1, 8000) + 1000
def size(self):
return self.dataset.size()
def __iter__(self):
self.iterate()
return self
def reset(self):
self.exhausted = False
if self.iterator is not None:
self.iterator.close()
self.iterator = None
self.used = BitSet()
print("Rank {}, Reset dataset:{} done.".format(bmt.rank(), self.dataset_name))
def transform(self, data: dict) -> dict:
weight = np.array(self.cfg["incontext_weight"], dtype=np.float32)
weight = weight / weight.sum()
num_incontext = np.random.choice(weight.shape[0], p=weight)
return self.transform_func(data, num_incontext, random.Random())
def segment_iterate(self, sample_iter):
for index, data in self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used):
for src_ids, tgt_ids in self.segment(self.transform(data)):
self.update_avg_tokens_by_ema(len(src_ids)) # 0 for input ids
yield src_ids, tgt_ids, index
def iterate(self):
# make the dataset itself an iterator
sample_iter = self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used)
self.iterator = self.segment_iterate(sample_iter)
def __next__(self):
# advance the task iterator
if self.iterator is None:
self.iterate()
try:
return next(self.iterator)
except StopIteration:
self.exhausted = True
return None
def load_state_dict(self, state_dict):
if state_dict.get("exhausted", False):
self.exhausted = True
self.used = BitSet()
else:
used = state_dict.get("used", BitSet())
if len(used) == len(self.dataset):
self.exhausted = True
self.used = BitSet()
else:
self.exhausted = False
self.used = used
_avg_tokens = state_dict.get("ave_tokens", -1)
_avg_tokens = state_dict.get("avg_tokens", _avg_tokens)
if self.avg_tokens == -1 or self.avg_tokens < _avg_tokens:
self.set_avg_tokens(_avg_tokens)
def state_dict(self):
if len(self.used) == len(self.dataset):
return dict(exhausted=True, used=BitSet(), avg_tokens=self.avg_tokens)
else:
return dict(exhausted=False, used=self.used, avg_tokens=self.avg_tokens)
def update_state(self, indice):
self.used.update(indice)
class MixedIndexedDataset(torch.utils.data.IterableDataset):
def __init__(
self,
cfg_path: str,
cfg_json_str,
tokenizer,
max_length,
weight_by_size: bool = True,
nthreads=5,
prefetch_slice=100,
parallel_loading=False,
vdc_sampling=False,
update_weights_frequency=1,
seed=42,
):
if bmt.rank() == 0:
port = int(os.environ["MASTER_PORT"]) + 2188
self.flask_ps_proc = mp.Process(target=flask_ps.run, kwargs={"host": "0.0.0.0", "port": port})
self.flask_ps_proc.start()
super(MixedIndexedDataset, self).__init__()
self.set_seed(seed + bmt.rank())
self.weight_by_size = weight_by_size
self.tokenizer = tokenizer
self.eos_token_id = self.tokenizer.eos_token_id
self.bos_token_id = self.tokenizer.bos_token_id
self.path2transform = dict()
self.task_dict = OrderedDict()
self.nthreads = nthreads
self.prefetch_slice = prefetch_slice
# useful for indexing
self.tasks = []
self.names = []
# ending of iteration
self.remain = 0
self.max_length = max_length
self.vdc_sampling = vdc_sampling
if self.vdc_sampling:
self._vdc_values = [van_der_corput(i) for i in range(100000)]
self.vdc_gen = van_der_corput_sampling_gen(self._vdc_values)
self.update_weights_frequency = update_weights_frequency
self.path2transform = dict()
cfgs = load_dataset_cfgs(cfg_path, cfg_json_str)
_sum_weight = sum([cfg["abs_weight"] for cfg in cfgs])
_weights = {cfg["task_name"]: cfg["abs_weight"] / _sum_weight for cfg in cfgs}
bmt.print_rank("Absolute Weight of DataSet {}".format(_weights))
if parallel_loading:
self.parallel_load(cfgs, max_workers=None)
else:
self.sequential_load(cfgs)
self.weights = None
self.update_weights()
def set_seed(self, seed):
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
def load_task(self, cfg):
logger.info(f"Loading {cfg['path']}")
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
task = SegmentedDataset(
cfg,
self.tokenizer,
self.max_length,
transform_func=transform_func,
nthreads=self.nthreads,
prefetch_slice=self.prefetch_slice,
do_compact=cfg.get("do_compact", False), # dataset level do_compact
)
return task
def sequential_load(self, cfgs):
self.cfgs = cfgs
for cfg in cfgs:
# python3.7 and later preserves insertion order to dictionary
logger.info(f"Loading {cfg['path']}")
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
task = SegmentedDataset(
cfg,
self.tokenizer,
self.max_length,
transform_func=transform_func,
nthreads=self.nthreads,
prefetch_slice=self.prefetch_slice,
do_compact=cfg.get("do_compact", False), # dataset level do_compact
)
self.task_dict[task.task_name] = task
self.tasks.append(task)
self.names.append(task.task_name)
self.remain += 1
self.weights = None
self.update_weights()
def load_state_dict(self, state_dict):
missing_keys = []
for name, task in self.task_dict.items():
if name in state_dict:
task.load_state_dict(state_dict[name])
else:
missing_keys.append(name)
self.update_weights()
return missing_keys
def save_state_dict(self, path):
state_dict = {}
for name, task in self.task_dict.items():
_state_dict = task.state_dict()
if isinstance(_state_dict["used"], BitSet):
bitset = _state_dict["used"]
_file_name = bitset.save(path)
_state_dict["used"] = _file_name
state_dict[name] = _state_dict
else:
state_dict[name] = task.state_dict()
torch.save(state_dict, path)
logger.info("Dataset state saved")
def update_states(self, task_ids, indice):
is_dict = isinstance(indice, dict)
uniq = torch.unique(task_ids)
for idx in uniq:
idx = idx.item()
indexes = indice[idx] if is_dict else indice[task_ids == idx].tolist()
self.tasks[idx].update_state(indexes)
def get_transform_func(self, module_name: str, transform_script_path):
if transform_script_path is None:
# allow null transform
return lambda data, num_incontext, rand: data
if "/" in module_name:
module_name = "fm9g_live.transforms.{}".format(module_name.split("/")[-1])
else:
module_name = "fm9g_live.transforms.{}".format(module_name)
if transform_script_path not in self.path2transform:
# loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
# spec = importlib.util.spec_from_loader(loader.name, loader)
spec = importlib.util.spec_from_file_location(module_name, transform_script_path)
if spec is None:
raise RuntimeError("Spec is none! {}".format(module_name))
mod = importlib.util.module_from_spec(spec)
self.path2transform[transform_script_path] = {
"module": mod,
"last_mtime": 0,
}
transform_script_info = self.path2transform[transform_script_path]
curr_mtime = float(os.path.getmtime(transform_script_path))
if curr_mtime > transform_script_info["last_mtime"]:
transform_script_info["last_mtime"] = curr_mtime
# load module
spec.loader.exec_module(transform_script_info["module"])
transform_func = getattr(transform_script_info["module"], "transform", None)
if transform_func is None:
raise NotImplementedError("Find no transform funcion in script '{}'".format(transform_script_path))
return transform_func
def update_weights(self):
task0 = self.tasks[0]
if task0.abs_weight is not None: # 这一份config是指定绝对比例的
weights = []
for task in self.tasks:
if task.exhausted:
weights.append(0)
else:
if task.avg_tokens == -1:
weights.append(task.abs_weight / self.max_length)
else:
weights.append(task.abs_weight / task.avg_tokens)
weights = np.array(weights)
else:
weights = np.array([0 if task.exhausted else task.weight for task in self.tasks])
if self.weight_by_size:
sizes = np.array([task.size() for task in self.tasks], dtype=np.float32)
weights *= sizes
self.weights = weights / weights.sum()
def __iter__(self):
for task in self.tasks:
task.iterate()
return self
def __next__(self):
step = 1
while True:
if self.remain == 0:
print("Rank {}, All task exhaust !!!!".format(bmt.rank()))
raise StopIteration
if self.vdc_sampling:
idx = next(self.vdc_gen)(self.weights)
else:
idx = np.random.choice(len(self.weights), p=self.weights)
data = next(self.tasks[idx])
if step % self.update_weights_frequency == 0:
self.update_weights()
if data is None:
if self.tasks[idx].allow_repeat:
# _runtime_ave = self.tasks[idx].avg_tokens
print("Rank {}, dataset {} exhaust, repeat...".format(bmt.rank(), self.tasks[idx].dataset_name))
# self.tasks[idx] = SegmentedDataset(
# self.tasks[idx].cfg, self.tokenizer, self.max_length, transform_func=self.tasks[idx].transform_func, nthreads=self.nthreads, prefetch_slice=self.prefetch_slice
# )
# self.tasks[idx].avg_tokens_update(_runtime_ave)
self.tasks[idx].reset()
else:
print("Rank {}, dataset {} exhaust, not repeat.".format(bmt.rank(), self.tasks[idx].dataset_name))
self.tasks[idx].exhaust = True
self.remain -= 1
continue
step += 1
return dict(
task_id=idx,
input=data[0],
target=data[1],
index=data[2],
is_long=self.tasks[idx].cfg.get("is_long", False),
)
class UnpadBatchedMixedDataset(torch.utils.data.IterableDataset):
def __init__(self, mixed_dataset, batch_size, max_length, pose_prob=0.0, pose_scaling_factor=1.0, compact=False):
self.max_total_length = batch_size * max_length
self.batch_size = 1
# setting compact=True concats segments orignated from the same input
# into a long sequence. the relative order of segments should be preserved
# in mixed_dataset, e.g.,
# - ok: task1_seg1, task2_seg1, task1_seg2, task1_seg3
# - not_ok: task1_seg1, task1_seg3, task2_seg1, task1_seg2
self.compact = compact
self.total_length = 0
self.task2seqs = defaultdict(list)
self.mixed_dataset = mixed_dataset
self._max_length = max_length
self._pose_prob = pose_prob
self._pose_scaling_factor = pose_scaling_factor
if self._pose_prob > 0.0:
self._scaled_max_length = int(self.max_total_length * self._pose_scaling_factor)
else:
self._scaled_max_length = max_length
def put(self, sample):
self.total_length += len(sample["target"])
task_id = sample["task_id"]
if self.compact and self.task2seqs[task_id]:
last = self.task2seqs[task_id][-1]
if last["target"][-1] != self.mixed_dataset.eos_token_id:
# concatenate sequantial segments for longer context modeling: why not?
last["input"].extend(sample["input"])
last["target"].extend(sample["target"])
return
self.task2seqs[task_id].append(sample)
def _pose_preprocess(
self,
input_ids: NDArray[np.int32],
) -> NDArray[np.int32]:
"""[PoSE](https://arxiv.org/abs/2309.10400v2)
GitHub implementation: https://github.com/dwzhu-pku/PoSE/blob/master/src/train_pose.py#L156
"""
len_chunk = min(len(input_ids), self._max_length)
len_input = len(input_ids)
# Chunk input randomly to fit max_length if needed
lt1 = 0
rt1 = random.randint(0, (len_chunk + 1) // 2) # Fist chunk only contains 1/2 tokens at most
rt2 = random.randint(lt1 + len_chunk, len_input) # Second chunk can randomly shift when not filled max_length
lt2 = rt2 - (len_chunk - (rt1 - lt1)) # assure all tokens are used
chunked_input_ids = np.concatenate([input_ids[lt1:rt1], input_ids[lt2:rt2]], axis=-1)
# Generate PoSE position ids
position_ids = np.arange(len(chunked_input_ids), dtype=np.int32)
len_position_ids = len(position_ids)
lt = 0
rt = random.randint(lt, self._scaled_max_length - len_position_ids)
position_ids[: rt1 - lt1] += lt
position_ids[rt1 - lt1 :] += rt
return position_ids
def pop(self):
indexes = defaultdict(list)
lengths = []
inputs = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
targets = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=IGNORE_TGT)
task_ids = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=-1)
position_ids = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
span_begin = 0
for samples in self.task2seqs.values():
while samples:
sample = samples.pop()
span_end = span_begin + len(sample["input"])
inputs[0, span_begin:span_end] = torch.tensor(sample["input"], dtype=torch.int32)
targets[0, span_begin:span_end] = torch.tensor(sample["target"], dtype=torch.int32)
task_ids[0, span_begin:span_end] = torch.tensor(sample["task_id"], dtype=torch.int32)
if not sample["is_long"] and self._pose_prob > 0.0 and random.uniform(0, 1) < self._pose_prob:
_span_position_ids = self._pose_preprocess(sample["input"])
else:
_span_position_ids = np.arange(len(sample["input"]), dtype=np.int32)
position_ids[0, span_begin:span_end] = torch.from_numpy(_span_position_ids)
# position_ids[0, span_begin:span_end] = torch.arange(len(sample["input"]), dtype=torch.int32)
lengths.append(len(sample["target"]))
indexes[int(sample["task_id"])].append(sample["index"])
self.total_length -= len(sample["target"])
span_begin = span_end
cu_seqlens = torch.cat(
[torch.tensor([0] + lengths).cumsum(dim=-1), torch.tensor([self.max_total_length], dtype=torch.int32)],
dim=0,
).int()
batch = {
"inputs": inputs,
"targets": targets,
"task_ids": task_ids,
"indexes": indexes,
# adhere to flash attention interface
"cu_seqlens": cu_seqlens,
"max_seqlen": int(torch.max(cu_seqlens[1:] - cu_seqlens[:-1])),
"lengths": torch.tensor(sum(lengths)).int(),
"task_names": self.mixed_dataset.names,
"position_ids": position_ids,
}
return batch
def will_be_full(self, sample):
return self.total_length + len(sample["target"]) > self.max_total_length
def __iter__(self):
for sample in self.mixed_dataset:
if self.will_be_full(sample):
yield self.pop()
self.put(sample)
class CudaPrefetcher(Iterable):
"""
Wrap around a batch iterator for asynchornously copying data to gpu to shield memcpy latency.
"""
def __init__(self, loader, tp_size=1, tp_rank=0):
self.loader = iter(loader)
self.tp_size = tp_size
self.tp_rank = tp_rank
self.stream = torch.cuda.Stream()
self.preload()
def preload(self):
try:
if self.tp_size > 1:
if self.tp_rank == 0:
data = next(self.loader)
print("Rank {}, Preload data done.".format(bmt.rank()))
d = {}
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "wb") as fb:
for key in data.keys():
if isinstance(data[key], torch.Tensor):
np_cur_data = data[key].cpu().numpy()
bs = np_cur_data.tobytes()
fb.write(bs)
d[key] = ["TORCH", str(np_cur_data.dtype), len(bs)] + list(np_cur_data.shape)
elif isinstance(data[key], np.ndarray):
bs = data[key].tobytes()
fb.write(bs)
d[key] = ["NUMPY", str(data[key].dtype), len(bs)] + list(data[key].shape)
else:
d[key] = data[key]
try:
_ = json.dumps(d)
except TypeError:
print(d)
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "w") as f:
json.dump(d, f)
bmt.synchronize()
if self.tp_rank != 0:
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "r") as f:
data = json.load(f)
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "rb") as fb:
bs = fb.read()
offset = 0
for key in data.keys():
if isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "NUMPY":
nw_offset = offset + data[key][2]
data[key] = np.frombuffer(bs[offset:nw_offset], dtype=data[key][1]).reshape(
data[key][3:]
)
offset = nw_offset
elif isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "TORCH":
nw_offset = offset + data[key][2]
data[key] = torch.from_numpy(
np.frombuffer(bs[offset:nw_offset], dtype=data[key][1])
.reshape(data[key][3:])
.copy()
)
offset = nw_offset
self.data = data
else:
self.data = next(self.loader)
except StopIteration:
self.data = None
return
with torch.cuda.stream(self.stream):
for key in self.data.keys():
if isinstance(self.data[key], torch.Tensor):
self.data[key] = self.data[key].cuda(non_blocking=True)
def __next__(self):
torch.cuda.current_stream().wait_stream(self.stream)
data = copy.deepcopy(self.data)
self.preload()
return data
def __iter__(self):
return self

View File

@ -12,3 +12,4 @@ from .position_embedding import RotaryEmbedding
from .position_embedding import RotaryEmbeddingESM
from .position_embedding import SegmentPositionEmbedding
from .transformer import Encoder
#from _attention_pp_sp import OpAttnPipeSP

View File

@ -0,0 +1,79 @@
import torch
import torch.nn.functional as F
import bmtrain as bmt
def _linear_backward(grad_output, x, weight, bias):
grad_x = grad_weight = grad_bias = None
if x.requires_grad:
grad_x = grad_output.matmul(weight)
if weight.requires_grad:
grad_weight = grad_output.reshape(-1,
grad_output.shape[-1]).t().matmul(x.reshape(-1, x.shape[-1]))
if bias is not None and bias.requires_grad:
grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)
return grad_x, grad_weight, grad_bias
class OpAttnPipeSP(torch.autograd.Function):
@staticmethod
def forward(ctx, q_w, k_w, v_w, q_b, w_b, v_b, x, cache_kv, cache_kv_inp, cu_seqlens_q, cu_seqlens_k, max_seqlen):
ctx.save_for_backward(x, q_w, k_w, v_w, q_b, w_b, v_b)
if cache_kv.numel() = 0:
q = F.linear(x, q_w, q_b)
k = F.linear(x, k_w, w_b)
v = F.linear(x, v_w, v_b)
else:
q = F.linear(x, q_w, q_b)
k = F.linear(x, k_w, w_b)
v = F.linear(x, v_w, v_b)
k = torch.cat([cache_kv[0], k], dim=1)
v = torch.cat([cache_kv[1], v], dim=1)
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen, max_seqlen, 0, causal=True, window_size=(-1,-1), alibi_slopes=None, deterministic=False, return_attn_probs=False
)
ctx.save_for_backward(
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
)
ctx.max_seqlen_q = max_seqlen
ctx.max_seqlen_k = max_seqlen
return F.linear(x, weight, bias)
@staticmethod
def backward(ctx, grad_output):
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
_flash_attn_varlen_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_k,
ctx.max_seqlen_q,
ctx.max_seqlen_k,
ctx.dropout_p,
ctx.softmax_scale,
False,
(-1,-1),
None,
False,
rng_state=rng_state,
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
d_xq, d_wq, d_bq = _linear_backward(dq, x, q_w, q_b)
d_xq, d_wq, d_bq = _linear_backward(dq, x, q_w, q_b)
d_xk, d_wk, d_bk = _linear_backward(dk, x, k_w, k_b)
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None

View File

@ -13,13 +13,13 @@ from einops import rearrange
from .linear import ColumnParallelLinear
from .linear import Linear
from .position_embedding import apply_chatglm_rotary_pos_emb
try:
from flash_attn.flash_attn_interface import _flash_attn_varlen_backward
from flash_attn.flash_attn_interface import _flash_attn_varlen_forward
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except:
flash_attn_varlen_func = None
from flash_attn.flash_attn_interface import flash_attn_varlen_func
#try:
# from flash_attn.flash_attn_interface import _flash_attn_varlen_backward
# from flash_attn.flash_attn_interface import _flash_attn_varlen_forward
# from flash_attn.flash_attn_interface import flash_attn_varlen_func
#except:
# flash_attn_varlen_func = None
try:
from flash_attn.bert_padding import pad_input
@ -54,6 +54,8 @@ class OpFlash(torch.autograd.Function):
dropout_p,
ctx.softmax_scale,
causal=causal,
window_size=(-1, -1),
alibi_slopes=None,
return_softmax=False,
)
if record:
@ -85,6 +87,9 @@ class OpFlash(torch.autograd.Function):
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
(-1,-1),
None,
False,
rng_state=rng_state,
)
return None, None, dq, dk, dv, None, None, None, None
@ -294,12 +299,28 @@ class Attention(bmt.DistributedModule):
h_q, h_k = position_bias(
h_q, h_k, -3, cu_seqlens=cu_seqlens, max_length=max_seqlen, position_ids=position_ids
)
# score = flash_attn_varlen_func(
# h_q, h_k, h_v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, self.dropout_p, causal=True
# )
score = OpFlash.apply(
self, not torch.is_grad_enabled(), h_q, h_k, h_v, cu_seqlens, max_seqlen, self.dropout_p, True
print(type(h_q), type(cu_seqlens), type(max_seqlen), type(self.dropout_p))
print("h_q: ", h_q)
print("cu_seqlens: ", cu_seqlens)
print("max_seqlen: ", max_seqlen)
score = flash_attn_varlen_func(
h_q,
h_k,
h_v,
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen,
self.dropout_p,
causal=True,
deterministic=True,
)
print(type(h_q), type(cu_seqlens), type(max_seqlen), type(self.dropout_p))
# Rongqiao change
#score = OpFlash.apply(
# self, not torch.is_grad_enabled(), h_q, h_k, h_v, cu_seqlens, max_seqlen, self.dropout_p, True
#)
score = score.view(batch_size, len_q, -1)

22
FM_9G/fm9g/metrics/ema.py Normal file
View File

@ -0,0 +1,22 @@
from typing import Optional
class EMAValue(object):
def __init__(self, init_value: Optional[float] = None, decay_factor: float = 0.999) -> None:
super().__init__()
self._value = init_value
self._decay_factor = decay_factor
@property
def value(self) -> Optional[float]:
return self._value
def update(self, value: float) -> None:
if self._value is None:
self._value = value
else:
self._value = self._decay_factor * self._value + (1 - self._decay_factor) * value
def update_with_return(self, value: float) -> Optional[float]:
self.update(value)
return self._value

View File

@ -0,0 +1 @@
from .fm9g import FM9GTokenizer

Some files were not shown because too many files have changed in this diff Show More