Compare commits

...

7 Commits

Author SHA1 Message Date
carboncoo c89395164e 1119 2024-11-19 11:01:06 +08:00
carboncoo 8e693d5876 1119 2024-11-19 11:00:26 +08:00
carboncoo 415c624322 Read me 2024-11-19 10:48:12 +08:00
carboncoo 4139ba5dfe 1119 2024-11-19 10:44:01 +08:00
carboncoo a8d431c14f 1119 2024-11-19 10:42:49 +08:00
carboncoo 1857f60d1e 1118 2024-11-18 17:59:15 +08:00
carboncoo a041469104 1118 2024-11-18 17:47:43 +08:00
191 changed files with 2704 additions and 354074 deletions

View File

@@ -1,4 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright @2024, QiYuan Inc

View File

@@ -1,20 +0,0 @@
import random
def rand(n: int, r: random.Random):
return int(r.random() * n)
def transform(data, num_sample: int, r: random.Random):
if 'input' in data:
_input = "<用户>"+data['input']+"<AI>"
else:
_input = ""
if 'output' in data:
_output = data['output']
else:
_output = ""
return {"input": _input,
"output": _output,
}
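For illustration, here is a minimal, self-contained sketch of what this transform does to a single record: it prefixes the input with the <用户> role marker and appends <AI>, leaving the output untouched. The logic is restated so the snippet runs on its own; the sample record is hypothetical.

import random

# Same logic as the transform above, restated for a runnable example.
def add_userai_transform(data, num_sample: int, r: random.Random):
    _input = "<用户>" + data["input"] + "<AI>" if "input" in data else ""
    _output = data.get("output", "")
    return {"input": _input, "output": _output}

sample = {"input": "What does this script do?", "output": "It wraps the prompt in role markers."}
print(add_userai_transform(sample, num_sample=1, r=random.Random(0)))
# {'input': '<用户>What does this script do?<AI>', 'output': 'It wraps the prompt in role markers.'}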

View File

@@ -1,20 +0,0 @@
import random
def rand(n: int, r: random.Random):
return int(r.random() * n)
def transform(data, num_sample: int, r: random.Random):
if 'input' in data:
_input = data['input']
else:
_input = ""
if 'output' in data:
_output = data['output']
else:
_output = ""
return {"input": _input,
"output": _output,
}

View File

@@ -1,134 +0,0 @@
[
{
"dataset_name": "humanevallike_clean_dedup",
"task_name": "humanevallike_clean_dedup",
"abs_weight": 0.2,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/humanevallike_clean_dedup",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 995339,
"ave_tokens_per_line": 100,
"total_tokens": 0.1
},
{
"dataset_name": "leetcode_pass_code_0125",
"task_name": "leetcode_pass_code_0125",
"abs_weight": 0.006,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/leetcode_pass_code_0125",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 10724,
"ave_tokens_per_line": 200,
"total_tokens": 0.002
},
{
"dataset_name": "logiv2Annotate",
"task_name": "logiv2Annotate",
"abs_weight": 0.004,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logiv2Annotate",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 12566,
"ave_tokens_per_line": 512,
"total_tokens": 0.006
},
{
"dataset_name": "mmlu_enhance",
"task_name": "mmlu_enhance",
"abs_weight": 0.1,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mmlu_enhance",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 169771,
"ave_tokens_per_line": 300,
"total_tokens": 0.05
},
{
"dataset_name": "mtbench_like",
"task_name": "mtbench_like",
"abs_weight": 0.2,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mtbench_like",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 319080,
"ave_tokens_per_line": 500,
"total_tokens": 0.15
},
{
"dataset_name": "ultra_dataset_new",
"task_name": "ultra_dataset_new",
"abs_weight": 2.0,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/ultra_dataset_new",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 385045,
"ave_tokens_per_line": 200.296266559615,
"total_tokens": 2.0
},
{
"dataset_name": "sft_data_zh_wowru",
"task_name": "sft_data_zh_wowru",
"abs_weight": 1.0,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/sft_data_zh_wowru",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 2963260,
"ave_tokens_per_line": 200.296266559615,
"total_tokens": 1
},
{
"dataset_name": "math_data",
"task_name": "math_data",
"abs_weight": 0.003,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/math_data",
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
"allow_repeat": true,
"nlines": 2963260,
"ave_tokens_per_line": 200.296266559615,
"total_tokens": 0.005
},
{
"dataset_name": "t0",
"task_name": "t0",
"abs_weight": 0.1,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/t0",
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
"allow_repeat": true,
"nlines": 1650309,
"ave_tokens_per_line": 500.296266559615,
"total_tokens": 0.82
},
{
"dataset_name": "wikihow",
"task_name": "wikihow",
"abs_weight": 0.1,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/wikihow",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 180128,
"ave_tokens_per_line": 900.296266559615,
"total_tokens": 0.16
},
{
"dataset_name": "reclor",
"task_name": "reclor",
"abs_weight": 0.002,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/reclor",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 4174,
"ave_tokens_per_line": 700.296266559615,
"total_tokens": 0.003
},
{
"dataset_name": "logic_test_lx_0127",
"task_name": "logic_test_lx_0127",
"abs_weight": 0.001,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logic_test_lx_0127",
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
"allow_repeat": true,
"nlines": 2800,
"ave_tokens_per_line": 200.96266559615,
"total_tokens": 0.0004
}
]
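The fields above drive how datasets are mixed: path and transforms point at the data and its preprocessing script, while abs_weight, nlines, and ave_tokens_per_line describe its size and sampling weight. As a rough illustration (this helper is not part of the repo, and reading abs_weight as a relative sampling weight is an assumption), such a config can be summarized like this:

import json

# Hypothetical summary helper for a dataset-mixture config in the format above.
with open("dataset_configs/fm9g_sft.json") as f:  # path follows the launch script's dataset_configs/<name>.json layout
    datasets = json.load(f)

total_weight = sum(d["abs_weight"] for d in datasets)
for d in datasets:
    est_tokens = d["nlines"] * d["ave_tokens_per_line"]  # rough token-count estimate
    print("{:<28} share={:6.2%}  ~{:.2f}B tokens".format(
        d["dataset_name"], d["abs_weight"] / total_weight, est_tokens / 1e9))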

View File

@@ -1,28 +0,0 @@
{
"vocab_size": 122753,
"dropout_p": 0.0,
"eps": 1e-05,
"half": true,
"half_type": "bf16",
"use_flash_attn": true,
"flash_attn_mask_shape": "2d",
"dim_model": 2304,
"dim_ff": 5760,
"dim_head": 64,
"num_heads": 36,
"num_kv_heads": 36,
"num_layers": 40,
"activate_fn": "silu",
"init_std": 0.10,
"scale": true,
"scale_emb": 12,
"scale_depth": 1.4,
"dim_model_base": 256,
"model_type": "fm9g",
"architectures": [
"FM9GForCausalLM"
],
"qk_norm": false,
"tie_lm_head": true,
"ffn_gated": true
}
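This model config is what get_model() in the training script below consumes. A minimal sketch of loading it directly, mirroring that function without the distributed setup; the file name model_configs/2.4b.json is an assumption based on the launch script's layout:

from fm9g.dragonfly.modeling_dragonfly import Dragonfly, DragonflyConfig

# Load the JSON config above and build the model, as get_model() does.
config = DragonflyConfig.from_json_file("model_configs/2.4b.json")
model = Dragonfly(config)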

View File

@@ -1,548 +0,0 @@
# coding=utf-8
# Copyright 2024 QiYuan Inc.
import inspect
import json
import math
import os
import re
import sys
import time
from collections import defaultdict
from itertools import chain
from typing import Any
from typing import Dict
from typing import List
from typing import Union
import bmtrain as bmt
import numpy as np
import torch
from bmtrain import nccl
from bmtrain.global_var import config as bmt_config
sys.path.append("../../")
from fm9g.arguments import get_args
from fm9g.dragonfly.modeling_dragonfly import Dragonfly
from fm9g.dragonfly.modeling_dragonfly import DragonflyConfig
from fm9g.dragonfly.training_tasks.pretrain_indexed import CudaPrefetcher
from fm9g.dragonfly.training_tasks.pretrain_indexed import MixedIndexedDataset
from fm9g.dragonfly.training_tasks.pretrain_indexed import UnpadBatchedMixedDataset
from fm9g.utils import exporter
from fm9g.utils import logger
from fm9g.utils.exporter import save_every_step_stats
from fm9g.utils.training_stats import num_non_embedding_parameters
from fm9g.utils.training_stats import num_parameters
def get_tokenizer(args):
from transformers import LlamaTokenizerFast
tokenizer = LlamaTokenizerFast(vocab_file=args.tokenizer_path)
return tokenizer
def get_model(args):
config = DragonflyConfig.from_json_file(args.model_config)
config.tp = 1 if args.tp_size != 1 else 0 # TODO
config.pose_prob = args.pose_prob
config.pose_scaling_factor = args.pose_scaling_factor
config.rope_scaling_type = args.rope_scaling_type
config.rope_scaling_factor = args.rope_scaling_factor
config.orig_max_length = args.orig_max_length
bmt.print_rank("model config: {}".format(config))
bmt.print_rank("bmt config: {}".format(bmt.config))
model = Dragonfly(config)
if args.load is not None:
bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
exporter.load_model_ckpt(args, model)
else:
bmt.print_rank("args.load is None, start to initialize parameters")
bmt.init_parameters(model)
return model
def get_optimizer(args, model):
scale_lr_group = []
normal_group = []
scale_lr_group_name, normal_group_name = [], []
for n, p in model.named_parameters():
if n.endswith(".weight") and "layernorm" not in n and "embedding" not in n and "lm_head" not in n:
scale_lr_group.append(p)
scale_lr_group_name.append(n)
else:
normal_group.append(p)
normal_group_name.append(n)
bmt.print_rank(scale_lr_group_name, normal_group_name)
param_groups = [
{"params": scale_lr_group, "lr": args.lr / model.config.scale_width},
{"params": normal_group, "lr": args.lr},
]
if args.offload:
optimizer = bmt.optim.AdamOffloadOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
else:
optimizer = bmt.optim.AdamOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
if args.load is not None and args.load_grad:
exporter.load_optimizer_ckpt(args, optimizer)
bmt.print_rank("optimizer is loaded!")
return optimizer
def get_learning_rate_scheduler(args, optimizer):
from fm9g.training_utils.lr_scheduler import Cosine
from fm9g.training_utils.lr_scheduler import WarmupStableDrop
end_iter = args.train_iters
    if 0 < args.warmup_iters < 1:  # support warmup_iters given as a fraction of total steps
warmup_iters = int(end_iter * args.warmup_iters)
else:
warmup_iters = int(args.warmup_iters)
    if 0 < args.drop_iters < 1:  # support drop_iters given as a fraction of total steps
drop_iters = int(end_iter * args.drop_iters)
else:
drop_iters = int(args.drop_iters)
if args.lr_scheduler == "cosine":
lr_scheduler = Cosine(
optimizer,
start_lr=args.lr,
warmup_iter=warmup_iters,
            end_iter=end_iter,  # formerly lr_decay_iter
num_iter=args.start_step,
#lr_end_restart=args.lr_end_restart,
#resume_no_optimze=args.resume_no_optimze,
)
elif args.lr_scheduler == "warmupstabledrop":
lr_scheduler = WarmupStableDrop(
optimizer,
start_lr=args.lr,
warmup_iter=warmup_iters,
            end_iter=end_iter,  # formerly lr_decay_iter
drop_iter=drop_iters,
num_iter=args.start_step,
resume_no_optimze=args.resume_no_optimze,
)
return lr_scheduler
def setup_model_and_optimizer(args):
start = time.time()
tokenizer = get_tokenizer(args)
bmt.synchronize()
logger.info("load tokenizer in {:.2f}s".format(time.time() - start))
start = time.time()
model = get_model(args)
logger.info("load model in {:.2f}s".format(time.time() - start))
start = time.time()
optimizer = get_optimizer(args, model)
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
bmt.synchronize()
logger.info("load lr_scheduler in {:.2f}s".format(time.time() - start))
return tokenizer, model, optimizer, lr_scheduler
def resume_training(args):
ckpts = sorted(
[z for z in chain(*[[os.path.join(x[0], y) for y in x[2]] for x in os.walk(args.save)]) if z.endswith(".pt")],
reverse=True,
key=lambda x: (int)(re.search("(\d+).pt", x)[1]),
)
# find newest job
ckpts = sorted(
ckpts,
reverse=True,
key=lambda x: (int)(re.search("job_(\d+)_ckpt", x)[1]),
)
if len(ckpts) > 0:
bmt.print_rank(f"resuming with last checkpoint: {ckpts[0]}")
args.load = ckpts[0]
# by default, do not load grad file
args.load_grad = False
args.start_step = 0
else:
# no ckpts, nothing we can do
os._exit(1)
def initialize():
args = get_args(pretrain=True)
bmt.init_distributed(seed=args.seed, tp_size=args.tp_size)
if args.save is not None:
os.makedirs(args.save, exist_ok=True)
if args.load is not None:
if args.only_load_model == 0:
if args.start_step == 0:
log_ckpt = exporter.load_log_ckpt(args)
if "iteration" in log_ckpt:
args.start_step = log_ckpt["iteration"]
else:
args.start_step = (int)(re.findall("(\d+)", args.load)[-1])
logger.info("Start from step {}".format(args.start_step))
elif args.only_load_model == 1:
logger.info("You load model ckpt, and choose to completely start the 0 step.")
else:
raise NotImplementedError
else:
logger.info("You do not load model")
return args
def see_memory(detail=False):
if detail:
res = torch.cuda.memory_summary()
else:
res = (
round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2),
round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024), 2),
round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024), 2),
)
torch.cuda.reset_peak_memory_stats()
return res
def add_mem_time(info, mem_usage, tim_usage):
torch.cuda.synchronize()
bmt.synchronize()
mem_usage[info] = see_memory()
tim_usage[info] = time.time()
return mem_usage, tim_usage
def get_task_loss_and_token(loss, task_ids, task_num, targets):
    # task_ids may contain -1 to mark invalid tokens
_task_num = task_num + 1
_task_ids = (task_ids.clone() + 1).to(torch.int64) # [batch_size, seq_len]
# gen masks
_task_mask = torch.zeros((_task_num, *_task_ids.shape), device=_task_ids.device)
_task_mask.scatter_(0, _task_ids.unsqueeze(0), 1) # [task_num, batch_size, seq_len]
_loss_mask = torch.ne(targets, -100).to(torch.int32)
_mask = _task_mask * _loss_mask.unsqueeze(0) # [task_num, batch_size, seq_len]
# calc loss and tokens
_task_losses = (loss.unsqueeze(0) * _mask).view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
_task_tokens = _mask.view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
# return token-wise avg losses and tokens
return torch.nan_to_num(_task_losses / _task_tokens, nan=0.0), _task_tokens
class ChunkAve:
def __init__(self, chunk_size=100):
self.ave_list = []
self.chunk_size = chunk_size
def record(self, time):
self.ave_list.append(time)
self.ave_list = self.ave_list[-self.chunk_size :]
def get(self):
return sum(self.ave_list) / len(self.ave_list)
def pretrain(
args,
tokenizer,
model: Dragonfly,
optimizer,
lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
):
ave_model_time = ChunkAve(chunk_size=100)
ave_iter_time = ChunkAve(chunk_size=100)
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, reduction="none")
optim_manager = bmt.optim.OptimManager(
loss_scale=None,
loss_scale_steps=args.loss_scale_steps,
loss_scale_factor=2,
max_loss_scale=args.max_loss_scale,
min_loss_scale=args.min_loss_scale,
)
optim_manager.add_optimizer(optimizer, lr_scheduler)
start_step = args.start_step
if args.tensorboard is not None and bmt.rank() == 0:
import distutils.version # noqa: F401
from tensorboardX import SummaryWriter
if not os.path.exists(args.tensorboard):
os.makedirs(args.tensorboard)
writer = SummaryWriter(log_dir=args.tensorboard)
if args.load is not None:
log_ckpt = exporter.load_log_ckpt(args)
else:
log_ckpt = {}
global_token_pass = log_ckpt.get("global_token_pass", 0.0)
global_total_task_token = defaultdict(int, log_ckpt.get("global_total_task_token", {})) # token by task
global_world_size = bmt.world_size()
bmt.print_rank("Begin preparing dataset")
if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
mixed_indexed_dataset = MixedIndexedDataset(
cfg_path=args.dataset,
cfg_json_str=None,
tokenizer=tokenizer,
max_length=args.max_length,
nthreads=args.dataloader_num_threads,
prefetch_slice=args.dataloader_prefetch,
weight_by_size=True,
)
if args.load is not None and args.only_load_model == 0 and args.load_dataloader_ckpt == 1:
exporter.load_dataloader_ckpt(args, mixed_indexed_dataset)
batched_dataset = UnpadBatchedMixedDataset(mixed_indexed_dataset, args.batch_size, args.max_length)
dataloader = torch.utils.data.DataLoader(
batched_dataset,
batch_size=None,
collate_fn=lambda x: x,
num_workers=args.dataloader_num_workers,
prefetch_factor=args.dataloader_prefetch_factor,
)
else:
def dummy_generator():
while True:
yield None
mixed_indexed_dataset = dummy_generator()
dataloader = mixed_indexed_dataset
DataIterator = CudaPrefetcher(dataloader, tp_size=args.tp_size, tp_rank=bmt.config["tp_rank"])
bmt.print_rank("Preparing dataset done.")
# inspect at init
model_inspect = bmt.inspect.inspect_model(model, "*")
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
try:
mem_usage, tim_usage = {}, {}
mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)
for iteration, data in enumerate(DataIterator, start=start_step + 1):
if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
mixed_indexed_dataset.update_states(data["task_ids"], data["indexes"])
mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)
logits = model(
input=data["inputs"],
cu_seqlens=data["cu_seqlens"],
max_seqlen=data["max_seqlen"],
position_ids=data["position_ids"],
)
# chunk targets and task_ids
data["targets"] = (
data["targets"]
.view(-1)
.chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
.view(data["targets"].shape[0], -1)
)
data["task_ids"] = (
data["task_ids"]
.view(-1)
.chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
.view(data["task_ids"].shape[0], -1)
)
_target = data["targets"].view(-1)
non_reduced_loss = loss_func(logits.view(-1, logits.size(-1)), _target)
_w = (_target != -100).int()
loss = non_reduced_loss.sum() / _w.sum().float()
global_loss = bmt.sum_loss(loss).item()
mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)
optim_manager.backward(loss)
mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)
if iteration % args.grad_accum == 0 or iteration == args.train_iters:
grad_accum_init_time = tim_usage["init"]
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
optim_manager.step()
optim_manager.zero_grad()
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
model_time = tim_usage["optim"] - grad_accum_init_time
ave_model_time.record(model_time)
else:
# dummy optim step
grad_norm = torch.Tensor([0.0]).cuda()
tim_usage["optim"] = tim_usage["backward"]
mem_usage["optim"] = mem_usage["backward"]
with torch.no_grad():
task_num = len(data["task_names"])
task_loss, task_token = get_task_loss_and_token(
non_reduced_loss, data["task_ids"], task_num, data["targets"]
)
task_loss_map: Dict[str, float] = {}
gatherd_task_loss_map = bmt.distributed.all_gather(task_loss)
gatherd_task_token_map = bmt.distributed.all_gather(task_token)
gatherd_task_loss_token_map = gatherd_task_loss_map * gatherd_task_token_map
sum_task_loss = gatherd_task_loss_token_map.sum(dim=0)
tot_task_token = gatherd_task_token_map.sum(dim=0)
ave_task_loss = sum_task_loss / tot_task_token
for i in range(task_num):
task_loss_map[data["task_names"][i]] = ave_task_loss[i].item()
global_total_task_token[data["task_names"][i]] += tot_task_token[i].item()
local_total_rate = torch.Tensor([data["lengths"].float().mean() / args.max_length]).cuda()
local_total_rate = bmt.sum_loss(local_total_rate).item()
global_token_pass += (
(global_world_size // args.tp_size) * local_total_rate * args.max_length * args.batch_size
)
bmt.print_rank(
"=========================================" + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
)
last_before_log_time = tim_usage["before_log"]
mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)
iter_time = tim_usage["before_log"] - last_before_log_time
ave_iter_time.record(iter_time)
train_info = {
"time": iter_time,
"iteration": iteration,
"loss": global_loss,
"lr": lr_scheduler.current_lr,
"token_max": local_total_rate,
"token_pass": global_token_pass,
"throughout": args.max_length * args.batch_size * local_total_rate / ave_iter_time.get() / args.tp_size,
"grad_norm": grad_norm.item(),
"mask_max": ((data["targets"] >= 0).sum(-1).float().mean() / args.max_length).item(),
"task_loss": task_loss_map,
"total_task_token": global_total_task_token,
}
global_token_pass_str = convert_to_k_and_b(global_token_pass)
bmt.print_rank(
(
"| Iter: {iteration:6d} | loss: {loss:.4f} | lr: {lr:.4e} | model_time: {model_time:.2f} | iter_time: {iter_time:.2f}| chunk_ave_time: {chunk_ave_time:.2f}"
+ " token/max: {tokenrate:.4f} | mask/max: {maskrate:.4f} | grad_norm: {grad_norm:.4f} | global_token_pass (B):"
+ "{global_token_pass} | mem_usage {mem_usage} | "
).format(
iteration=iteration,
loss=global_loss,
lr=lr_scheduler.current_lr,
model_time=model_time,
iter_time=iter_time,
chunk_ave_time=ave_iter_time.get(),
tokenrate=data["lengths"].float().mean() / args.max_length / args.batch_size,
maskrate=(data["targets"] >= 0).sum(-1).float().mean() / args.max_length / args.batch_size,
grad_norm=grad_norm.item(),
global_token_pass=global_token_pass_str,
mem_usage=max([value for key, value in mem_usage.items()]),
)
)
bmt.print_rank(
"task_loss:\t| "
+ " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
+ " |"
)
if iteration % 10 == 0:
bmt.print_rank(
"task_tokens (B):\t| "
+ " | ".join(
[
"{}: {:.4f}".format(task_name, task_token / 10**9)
for task_name, task_token in global_total_task_token.items()
]
)
+ " |"
)
if iteration % args.inspect_iters == 0:
model_inspect = bmt.inspect.inspect_model(model, "*")
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
if args.log_dir is not None and bmt.rank() == 0:
if args.save is not None:
save_every_step_stats(train_info, args.save)
if args.tensorboard is not None and bmt.rank() == 0:
writer.add_scalar("Loss/train", global_loss, iteration)
writer.add_scalar("Optimizer/lr", lr_scheduler.current_lr, iteration)
writer.add_scalar("Optimizer/scale", optim_manager.loss_scale, iteration)
writer.add_scalar("Optimizer/grad_norm", grad_norm.item(), iteration)
for task_name, loss in task_loss_map.items():
if not math.isnan(loss):
writer.add_scalar("Loss/train/{}".format(task_name), loss, iteration)
# -------- save file. If need to backup by Klara platform, use export.xx_save --------
log_ckpt = {
"global_total_task_token": global_total_task_token,
"global_token_pass": global_token_pass,
"iteration": iteration,
}
if args.save is not None and iteration % args.save_iters == 0:
exporter.export(
model,
mixed_indexed_dataset,
tokenizer,
optimizer,
iteration,
args,
log_ckpt=log_ckpt,
final_save=False,
)
if iteration == args.train_iters and args.stop_when_end == 1:
break
except Exception as e:
print(f"train loop err: {e}")
raise e
finally:
pass
exporter.export(model, mixed_indexed_dataset, tokenizer, optimizer, -1, args, final_save=False)
def convert_to_k_and_b(number):
    if number >= 1e9:  # greater than or equal to 1 billion
b_number = number / 1e9
return f"{b_number:.2f}B"
    elif number >= 1e6:  # greater than or equal to 1 million
k_number = number / 1e6
return f"{k_number:.2f}M"
elif number >= 1e3:
k_number = number / 1e3
return f"{k_number:.2f}K"
else:
return str(number)
def main():
args = initialize()
bmt.synchronize()
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
bmt.print_rank("finish loading")
bmt.print_rank(
"Number of parameter {}, Number of non-e parameter {}".format(
num_parameters(model), num_non_embedding_parameters(model)
)
)
bmt.print_rank("args: {}".format(args))
pretrain(args, tokenizer, model, optimizer, lr_scheduler)
if __name__ == "__main__":
main()
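One detail of get_learning_rate_scheduler() above that is easy to miss: warmup_iters and drop_iters can be passed either as absolute step counts or, when strictly between 0 and 1, as fractions of train_iters. A minimal sketch of that resolution rule:

# Mirrors the warmup_iters / drop_iters handling in get_learning_rate_scheduler().
def resolve_iters(value: float, end_iter: int) -> int:
    return int(end_iter * value) if 0 < value < 1 else int(value)

assert resolve_iters(0.1, 10_000) == 1_000  # fraction: 10% of train_iters
assert resolve_iters(20, 10_000) == 20      # absolute step count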

View File

@@ -1,234 +0,0 @@
#!/bin/bash
#export OMP_NUM_THREADS=16
declare -A args # Declare an associative array to store arguments and values
args["model_unique"]="2b_0701"
args["resume_ckpt"]=""
args["config"]="2.4b"
args["flash"]="cuda"
args["batch_size"]="1"
args["max_length"]="4096"
args["save_iters"]="500"
args["train_iters"]="10"
args["dataset_config"]="fm9g_sft"
args["local"]="False"
args["dataloader"]="indexed"
args["save"]="True"
args["dataloader_num_threads"]=1
args["dataloader_prefetch"]=1
args["dataloader_prefetch_factor"]=1
args["dataloader_num_workers"]=1
args["lr"]="1e-5"
args["warmup_iters"]="20"
args["drop_iters"]="0.1"
args["tokenizer_path"]="./tokenizer/tokenizer.model" # /user/tc_agi/klara/baichuan2/baichuan2.tokenizer.model
args["load_grad"]="False"
args["grad_ckpt_num"]="160"
args["exp_group"]=""
args["ignore_cuda_oom"]="1"
args["tensorboard_all_tasks"]="0"
args["stop_when_end"]="0"
args["only_run_dataloader"]="0"
args["eps"]="1e-6"
args["inspect_iters"]="100"
args["strict_state_dict"]="1"
args["only_load_model"]="1"
args["lr_scheduler"]="cosine"
args["resume_no_optimze"]="0"
args["tp_size"]="1"
args["parallel_load_datastate"]="8"
args["async_save"]="False"
args["load_dataloader_ckpt"]="0"
args["drop_begin"]="-1"
args["drop_rate"]="0.5"
args["use_checkpoint"]="0"
# Loop through the arguments
for ((i=1; i<=$#; i++)); do
arg="${!i}"
# Check if the argument starts with "--"
if [[ "$arg" == --* ]]; then
arg_name="${arg:2}" # Remove leading "--"
valueid=$((i+1))
# Get the value of the argument if it exists
if ((i+1 <= $#)); then
args["$arg_name"]="${!valueid}"
i=$((i+1)) # Skip the next argument (its value)
else
args["$arg_name"]="" # Set empty value if no value provided
fi
fi
done
# Use Python to read the JSON config file and update the Bash args dictionary
while read -r key value; do
args["$key"]="$value"
done < <(python -c 'import json, sys; obj = json.load(open("train_configs/'${args['config']}'.json"))["pretrain"]; print("\n".join(["{} {}".format(k, v) for k, v in obj.items()]))')
# Apply the command-line args once more so they take precedence over the JSON config
# Loop through the arguments
for ((i=1; i<=$#; i++)); do
arg="${!i}"
# Check if the argument starts with "--"
if [[ "$arg" == --* ]]; then
arg_name="${arg:2}" # Remove leading "--"
valueid=$((i+1))
# Get the value of the argument if it exists
if ((i+1 <= $#)); then
args["$arg_name"]="${!valueid}"
i=$((i+1)) # Skip the next argument (its value)
else
args["$arg_name"]="" # Set empty value if no value provided
fi
fi
done
# Print the values of the arguments
echo "----------- CMD args ----------"
for key in "${!args[@]}"; do
echo "$key: ${args[$key]}"
done
echo "--------- END CMD args --------"
if [[ ${args["flash"]} == "triton" ]]; then
sudo cp /usr/local/cuda-11.6/compat/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03
sudo ln /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so
echo "triton flash"
fi
GPUS_PER_NODE=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader | wc -l)
# GPUS_PER_NODE=1
echo "Using ${GPUS_PER_NODE} GPU each machine"
if [[ ${args["model_unique"]} == "" ]]; then
    MODEL_UNIQUE=${JEEVES_JOB_ID} # output location; constructed automatically when not provided
    # JOBID+CreateTime, the unique identifier of this run. On the platform, the checkpoint can be
    # retrieved from /projects/${PROJECTID}-${PROJECTNAME}/checkpoints/${MODEL_UNIQUE}
    # and the tensorboard logs from /projects/${PROJECTID}-${PROJECTNAME}/tensorboard/${MODEL_UNIQUE}
else
    MODEL_UNIQUE=${args["model_unique"]} # output location provided explicitly
fi
echo "model_unique: "$MODEL_UNIQUE
# --------------- Run arguments ---------------
OPTS+=" --model-config model_configs/"${args['config']}".json" # [CHANGE]
OPTS+=" --batch-size ${args["batch_size"]}"
OPTS+=" --train-iters ${args["train_iters"]}"
OPTS+=" --save-iters ${args["save_iters"]}"
OPTS+=" --save-name fm9g_live_checkpoint"
OPTS+=" --max-length ${args["max_length"]}"
OPTS+=" --lr ${args["lr"]}"
OPTS+=" --inspect-iters ${args["inspect_iters"]}"
OPTS+=" --warmup-iters ${args["warmup_iters"]}"
OPTS+=" --drop-iters ${args["drop_iters"]}"
OPTS+=" --lr_scheduler ${args["lr_scheduler"]}"
OPTS+=" --offload"
#OPTS+=" --vocab ./tokenizer/vocab.txt"
OPTS+=" --flash ${args["flash"]}"
OPTS+=" --tensorboard_all_tasks ${args["tensorboard_all_tasks"]}"
OPTS+=" --ignore_cuda_oom ${args["ignore_cuda_oom"]}"
OPTS+=" --stop_when_end ${args["stop_when_end"]}"
OPTS+=" --only_run_dataloader ${args["only_run_dataloader"]}"
OPTS+=" --eps ${args["eps"]}"
OPTS+=" --strict_state_dict ${args["strict_state_dict"]}"
OPTS+=" --only_load_model ${args["only_load_model"]}"
OPTS+=" --resume_no_optimze ${args["resume_no_optimze"]}"
OPTS+=" --tokenizer_path ${args["tokenizer_path"]}"
OPTS+=" --weight-decay 0.1"
OPTS+=" --tp-size ${args["tp_size"]}"
OPTS+=" --parallel_load_datastate ${args["parallel_load_datastate"]}"
OPTS+=" --load_dataloader_ckpt ${args["load_dataloader_ckpt"]}"
OPTS+=" --drop_begin ${args["drop_begin"]}"
OPTS+=" --drop_rate ${args["drop_rate"]}"
OPTS+=" --use_checkpoint ${args["use_checkpoint"]}"
if [[ ${args["load_grad"]} == "True" ]]; then
OPTS+=" --load-grad"
OPTS+=" --grad-ckpt-num ${args["grad_ckpt_num"]}"
fi
if [[ ${args["async_save"]} == "True" ]]; then
OPTS+=" --async_save"
fi
if [[ ${args["dataloader"]} == "indexed" ]]; then
OPTS+=" --dataloader_num_threads ${args["dataloader_num_threads"]}"
OPTS+=" --dataloader_prefetch ${args["dataloader_prefetch"]}"
OPTS+=" --dataloader_num_workers ${args["dataloader_num_workers"]}"
OPTS+=" --dataloader_prefetch_factor ${args["dataloader_prefetch_factor"]}"
fi
# --------------- Output file paths ---------------
## checkpoint
if [[ ${args["save"]} == "True" ]]; then
OPTS+=" --save ./data/checkpoints/${MODEL_UNIQUE}/"
OPTS+=" --save-model ./not_exist/${MODEL_UNIQUE}/"
else
echo "won't save model"
fi
## logs/local/logs is equivalent to the ./datalogs symlink
mkdir -p ./data/checkpoints/logs/${MODEL_UNIQUE}
OPTS+=" --log-dir ./data/checkpoints/logs/${MODEL_UNIQUE}"
OPTS+=" --tensorboard ./data/tensorboard/${args["exp_group"]}${MODEL_UNIQUE}/"
if [[ ${args["local"]} == "True" ]]; then
current_dir=$(pwd)
OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
else
current_dir=$(pwd)
OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
echo "Platform config:"${PLATFORM_CONFIG_PATH}
fi
## Checkpoint loading is compatible with CHECKPOINT and LATEST_CHECKPOINT. When debugging, it is recommended not to load a checkpoint, as startup is faster.
if [ "${args["resume_ckpt"]}" != "" ]; then
OPTS+=" --load ./data/checkpoints/${MODEL_UNIQUE}/${args["resume_ckpt"]}"
else
echo "No checkpoint to load"
fi
filename="pretrain_dragonfly"
if [[ ${args["local"]} == "True" ]]; then
PRETRAIN_ENTRY="$filename.py"
else
PRETRAIN_ENTRY="$filename.py"
fi
GPUS_PER_NODE=1
NNODES=1
RANK=0
MASTER_ENDPOINT=g3006
MASTER_PORT=23456
#CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --master_addr=${MASTER_ENDPOINT} --master_port=${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ENDPOINT}:${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
echo "-------final CMD is------"
echo "${CMD}"
echo "-------final CMD end------"
$CMD

File diff suppressed because it is too large

View File

@@ -1,9 +0,0 @@
{
"pretrain": {
"train_iters": 1000000000,
"batch_size": 1,
"max_length": 4096,
"n_gpus": 8,
"lr": 0.01
}
}
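This appears to be one of the train_configs/<config>.json files whose "pretrain" section the launch script flattens into its Bash args dictionary via an inline python -c call. A Python rendering of that flattening step (the concrete file name is an assumption; the script builds it from args["config"]):

import json

# Flatten the "pretrain" section into "key value" lines, which the launch
# script's `while read -r key value` loop merges back into its args dict.
obj = json.load(open("train_configs/2.4b.json"))["pretrain"]
print("\n".join("{} {}".format(k, v) for k, v in obj.items()))
# train_iters 1000000000
# batch_size 1
# ...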

View File

@@ -1,20 +0,0 @@
import random
def rand(n: int, r: random.Random):
return int(r.random() * n)
def transform(data, num_sample: int, r: random.Random):
if 'input' in data:
_input = "<用户>"+data['input']+"<AI>"
else:
_input = ""
if 'output' in data:
_output = data['output']
else:
_output = ""
return {"input": _input,
"output": _output,
}

View File

@@ -1,20 +0,0 @@
import random
def rand(n: int, r: random.Random):
return int(r.random() * n)
def transform(data, num_sample: int, r: random.Random):
if 'input' in data:
_input = data['input']
else:
_input = ""
if 'output' in data:
_output = data['output']
else:
_output = ""
return {"input": _input,
"output": _output,
}

View File

@@ -1,134 +0,0 @@
[
{
"dataset_name": "humanevallike_clean_dedup",
"task_name": "humanevallike_clean_dedup",
"abs_weight": 0.2,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/humanevallike_clean_dedup",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 995339,
"ave_tokens_per_line": 100,
"total_tokens": 0.1
},
{
"dataset_name": "leetcode_pass_code_0125",
"task_name": "leetcode_pass_code_0125",
"abs_weight": 0.006,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/leetcode_pass_code_0125",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 10724,
"ave_tokens_per_line": 200,
"total_tokens": 0.002
},
{
"dataset_name": "logiv2Annotate",
"task_name": "logiv2Annotate",
"abs_weight": 0.004,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logiv2Annotate",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 12566,
"ave_tokens_per_line": 512,
"total_tokens": 0.006
},
{
"dataset_name": "mmlu_enhance",
"task_name": "mmlu_enhance",
"abs_weight": 0.1,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mmlu_enhance",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 169771,
"ave_tokens_per_line": 300,
"total_tokens": 0.05
},
{
"dataset_name": "mtbench_like",
"task_name": "mtbench_like",
"abs_weight": 0.2,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mtbench_like",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 319080,
"ave_tokens_per_line": 500,
"total_tokens": 0.15
},
{
"dataset_name": "ultra_dataset_new",
"task_name": "ultra_dataset_new",
"abs_weight": 2.0,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/ultra_dataset_new",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 385045,
"ave_tokens_per_line": 200.296266559615,
"total_tokens": 2.0
},
{
"dataset_name": "sft_data_zh_wowru",
"task_name": "sft_data_zh_wowru",
"abs_weight": 1.0,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/sft_data_zh_wowru",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 2963260,
"ave_tokens_per_line": 200.296266559615,
"total_tokens": 1
},
{
"dataset_name": "math_data",
"task_name": "math_data",
"abs_weight": 0.003,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/math_data",
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
"allow_repeat": true,
"nlines": 2963260,
"ave_tokens_per_line": 200.296266559615,
"total_tokens": 0.005
},
{
"dataset_name": "t0",
"task_name": "t0",
"abs_weight": 0.1,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/t0",
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
"allow_repeat": true,
"nlines": 1650309,
"ave_tokens_per_line": 500.296266559615,
"total_tokens": 0.82
},
{
"dataset_name": "wikihow",
"task_name": "wikihow",
"abs_weight": 0.1,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/wikihow",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 180128,
"ave_tokens_per_line": 900.296266559615,
"total_tokens": 0.16
},
{
"dataset_name": "reclor",
"task_name": "reclor",
"abs_weight": 0.002,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/reclor",
"transforms": "0124_hq_data/general/script_cpmc.py",
"allow_repeat": true,
"nlines": 4174,
"ave_tokens_per_line": 700.296266559615,
"total_tokens": 0.003
},
{
"dataset_name": "logic_test_lx_0127",
"task_name": "logic_test_lx_0127",
"abs_weight": 0.001,
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logic_test_lx_0127",
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
"allow_repeat": true,
"nlines": 2800,
"ave_tokens_per_line": 200.96266559615,
"total_tokens": 0.0004
}
]

View File

@@ -1,568 +0,0 @@
# coding=utf-8
# Copyright 2022 ModelBest Inc.
import inspect
import json
import math
import os
import re
import sys
import time
from collections import defaultdict
from itertools import chain
from typing import Any
from typing import Dict
from typing import List
from typing import Union
import bmtrain as bmt
import numpy as np
import torch
from bmtrain import nccl
from bmtrain.global_var import config as bmt_config
sys.path.append("../../")
from fm9g.arguments import get_args
from fm9g.dragonfly.modeling_dragonfly import Dragonfly
from fm9g.dragonfly.modeling_dragonfly import DragonflyConfig
from fm9g.dragonfly.training_tasks.pretrain_indexed_9g import CudaPrefetcher
from fm9g.dragonfly.training_tasks.pretrain_indexed_9g import MixedIndexedDataset
from fm9g.dragonfly.training_tasks.pretrain_indexed_9g import UnpadBatchedMixedDataset
from fm9g.utils import exporter
from fm9g.utils import logger
from fm9g.utils.exporter import save_every_step_stats
from fm9g.utils.training_stats import num_non_embedding_parameters
from fm9g.utils.training_stats import num_parameters
def get_tokenizer(args):
from fm9g.tokenizer import FM9GTokenizer
tokenizer = FM9GTokenizer(path=args.vocab)
return tokenizer
def get_model(args):
config = DragonflyConfig.from_json_file(args.model_config)
config.tp = 1 if args.tp_size != 1 else 0 # TODO
config.pose_prob = args.pose_prob
config.pose_scaling_factor = args.pose_scaling_factor
config.rope_scaling_type = args.rope_scaling_type
config.rope_scaling_factor = args.rope_scaling_factor
config.orig_max_length = args.orig_max_length
config.use_checkpoint = True if args.use_checkpoint == 1 else False
bmt.print_rank("model config: {}".format(config))
bmt.print_rank("bmt config: {}".format(bmt.config))
model = Dragonfly(config)
if args.load is not None:
bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
exporter.load_model_ckpt(args, model)
else:
bmt.print_rank("args.load is None, start to initialize parameters")
bmt.init_parameters(model)
return model
def get_optimizer(args, model):
scale_lr_group = []
normal_group = []
scale_lr_group_name, normal_group_name = [], []
for n, p in model.named_parameters():
if n.endswith(".weight") and "layernorm" not in n and "embedding" not in n and "lm_head" not in n:
scale_lr_group.append(p)
scale_lr_group_name.append(n)
else:
normal_group.append(p)
normal_group_name.append(n)
bmt.print_rank(scale_lr_group_name, normal_group_name)
param_groups = [
{"params": scale_lr_group, "lr": args.lr / model.config.scale_width},
{"params": normal_group, "lr": args.lr},
]
if args.offload:
optimizer = bmt.optim.AdamOffloadOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
else:
optimizer = bmt.optim.AdamOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
if args.load is not None and args.load_grad:
exporter.load_optimizer_ckpt(args, optimizer)
bmt.print_rank("optimizer is loaded!")
return optimizer
def get_learning_rate_scheduler(args, optimizer):
from fm9g.training_utils.lr_scheduler import Cosine
from fm9g.training_utils.lr_scheduler import WarmupStableDrop
from fm9g.training_utils.lr_scheduler import WarmupStableExp
end_iter = args.train_iters
    if 0 < args.warmup_iters < 1:  # support warmup_iters given as a fraction of total steps
warmup_iters = int(end_iter * args.warmup_iters)
else:
warmup_iters = int(args.warmup_iters)
    if 0 < args.drop_iters < 1:  # support drop_iters given as a fraction of total steps
drop_iters = int(end_iter * args.drop_iters)
else:
drop_iters = int(args.drop_iters)
if args.lr_scheduler == "cosine":
lr_scheduler = Cosine(
optimizer,
start_lr=args.lr,
warmup_iter=warmup_iters,
            end_iter=end_iter,  # formerly lr_decay_iter
num_iter=args.start_step,
)
# lr_end_restart=args.lr_end_restart,
# resume_no_optimze=args.resume_no_optimze,
#)
elif args.lr_scheduler == "warmupstabledrop":
lr_scheduler = WarmupStableDrop(
optimizer,
start_lr=args.lr,
warmup_iter=warmup_iters,
            end_iter=end_iter,  # formerly lr_decay_iter
drop_iter=drop_iters,
num_iter=args.start_step,
resume_no_optimze=args.resume_no_optimze,
)
elif args.lr_scheduler == "warmupstableexp":
lr_scheduler = WarmupStableExp(
optimizer,
start_lr=args.lr,
warmup_iter=warmup_iters,
            drop_begin=args.drop_begin,  # formerly lr_decay_iter
drop_iter=drop_iters,
drop_rate=args.drop_rate,
num_iter=args.start_step,
resume_no_optimze=args.resume_no_optimze,
)
return lr_scheduler
def setup_model_and_optimizer(args):
start = time.time()
tokenizer = get_tokenizer(args)
bmt.synchronize()
logger.info("load tokenizer in {:.2f}s".format(time.time() - start))
start = time.time()
model = get_model(args)
logger.info("load model in {:.2f}s".format(time.time() - start))
start = time.time()
optimizer = get_optimizer(args, model)
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
bmt.synchronize()
logger.info("load lr_scheduler in {:.2f}s".format(time.time() - start))
return tokenizer, model, optimizer, lr_scheduler
def resume_training(args):
ckpts = sorted(
[z for z in chain(*[[os.path.join(x[0], y) for y in x[2]] for x in os.walk(args.save)]) if z.endswith(".pt")],
reverse=True,
key=lambda x: (int)(re.search("(\d+).pt", x)[1]),
)
# find newest job
ckpts = sorted(
ckpts,
reverse=True,
key=lambda x: (int)(re.search("job_(\d+)_ckpt", x)[1]),
)
if len(ckpts) > 0:
bmt.print_rank(f"resuming with last checkpoint: {ckpts[0]}")
args.load = ckpts[0]
# by default, do not load grad file
args.load_grad = False
args.start_step = 0
else:
# no ckpts, nothing we can do
os._exit(1)
def initialize():
args = get_args(pretrain=True)
bmt.init_distributed(seed=args.seed, tp_size=args.tp_size)
if args.save is not None:
os.makedirs(args.save, exist_ok=True)
if args.load is not None:
if args.only_load_model == 0:
if args.start_step == 0:
log_ckpt = exporter.load_log_ckpt(args)
if "iteration" in log_ckpt:
args.start_step = log_ckpt["iteration"]
else:
args.start_step = (int)(re.findall("(\d+)", args.load)[-1])
logger.info("Start from step {}".format(args.start_step))
elif args.only_load_model == 1:
logger.info("You load model ckpt, and choose to completely start the 0 step.")
else:
raise NotImplementedError
else:
logger.info("You do not load model")
return args
def see_memory(detail=False):
if detail:
res = torch.cuda.memory_summary()
else:
res = (
round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2),
round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024), 2),
round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024), 2),
)
torch.cuda.reset_peak_memory_stats()
return res
def add_mem_time(info, mem_usage, tim_usage):
torch.cuda.synchronize()
bmt.synchronize()
mem_usage[info] = see_memory()
tim_usage[info] = time.time()
return mem_usage, tim_usage
def get_task_loss_and_token(loss, task_ids, task_num, targets):
    # task_ids may contain -1 to mark invalid tokens
_task_num = task_num + 1
_task_ids = (task_ids.clone() + 1).to(torch.int64) # [batch_size, seq_len]
# gen masks
_task_mask = torch.zeros((_task_num, *_task_ids.shape), device=_task_ids.device)
_task_mask.scatter_(0, _task_ids.unsqueeze(0), 1) # [task_num, batch_size, seq_len]
_loss_mask = torch.ne(targets, -100).to(torch.int32)
_mask = _task_mask * _loss_mask.unsqueeze(0) # [task_num, batch_size, seq_len]
# calc loss and tokens
_task_losses = (loss.unsqueeze(0) * _mask).view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
_task_tokens = _mask.view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
# return token-wise avg losses and tokens
return torch.nan_to_num(_task_losses / _task_tokens, nan=0.0), _task_tokens
class ChunkAve:
def __init__(self, chunk_size=100):
self.ave_list = []
self.chunk_size = chunk_size
def record(self, time):
self.ave_list.append(time)
self.ave_list = self.ave_list[-self.chunk_size :]
def get(self):
return sum(self.ave_list) / len(self.ave_list)
def pretrain(
args,
tokenizer,
model: Dragonfly,
optimizer,
lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
):
ave_model_time = ChunkAve(chunk_size=100)
ave_iter_time = ChunkAve(chunk_size=100)
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, reduction="none")
optim_manager = bmt.optim.OptimManager(
loss_scale=bmt.world_size(),
loss_scale_steps=args.loss_scale_steps,
loss_scale_factor=2,
max_loss_scale=bmt.world_size(),
min_loss_scale=bmt.world_size(),
)
optim_manager.add_optimizer(optimizer, lr_scheduler)
start_step = args.start_step
if args.tensorboard is not None and bmt.rank() == 0:
import distutils.version # noqa: F401
from tensorboardX import SummaryWriter
if not os.path.exists(args.tensorboard):
os.makedirs(args.tensorboard)
writer = SummaryWriter(log_dir=args.tensorboard)
if args.load is not None:
log_ckpt = exporter.load_log_ckpt(args)
else:
log_ckpt = {}
global_token_pass = log_ckpt.get("global_token_pass", 0.0)
global_total_task_token = defaultdict(int, log_ckpt.get("global_total_task_token", {})) # token by task
global_world_size = bmt.world_size()
if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
mixed_indexed_dataset = MixedIndexedDataset(
cfg_path=args.dataset,
cfg_json_str=None,
tokenizer=tokenizer,
max_length=args.max_length,
nthreads=args.dataloader_num_threads,
prefetch_slice=args.dataloader_prefetch,
weight_by_size=True,
)
if args.load is not None and args.only_load_model == 0 and args.load_dataloader_ckpt == 1:
exporter.load_dataloader_ckpt(args, mixed_indexed_dataset)
batched_dataset = UnpadBatchedMixedDataset(mixed_indexed_dataset, args.batch_size, args.max_length)
dataloader = torch.utils.data.DataLoader(
batched_dataset,
batch_size=None,
collate_fn=lambda x: x,
num_workers=args.dataloader_num_workers,
prefetch_factor=args.dataloader_prefetch_factor,
)
else:
def dummy_generator():
while True:
yield None
mixed_indexed_dataset = dummy_generator()
dataloader = mixed_indexed_dataset
DataIterator = CudaPrefetcher(dataloader, tp_size=args.tp_size, tp_rank=bmt.config["tp_rank"])
bmt.print_rank("Preparing dataset done.")
# inspect at init
model_inspect = bmt.inspect.inspect_model(model, "*")
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
try:
mem_usage, tim_usage = {}, {}
mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)
for iteration, data in enumerate(DataIterator, start=start_step + 1):
if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
mixed_indexed_dataset.update_states(data["task_ids"], data["indexes"])
mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)
logits = model(
input=data["inputs"],
cu_seqlens=data["cu_seqlens"],
max_seqlen=data["max_seqlen"],
position_ids=data["position_ids"],
)
#print("logits: ", logits)
# chunk targets and task_ids
data["targets"] = (
data["targets"]
.view(-1)
.chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
.view(data["targets"].shape[0], -1)
)
data["task_ids"] = (
data["task_ids"]
.view(-1)
.chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
.view(data["task_ids"].shape[0], -1)
)
_target = data["targets"].view(-1)
non_reduced_loss = loss_func(logits.view(-1, logits.size(-1)), _target)
_w = (_target != -100).int()
loss = non_reduced_loss.sum() / _w.sum().float()
global_loss = bmt.sum_loss(loss).item()
mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)
optim_manager.backward(loss)
mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)
if iteration % args.grad_accum == 0 or iteration == args.train_iters:
grad_accum_init_time = tim_usage["init"]
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
optim_manager.step()
optim_manager.zero_grad()
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
model_time = tim_usage["optim"] - grad_accum_init_time
ave_model_time.record(model_time)
else:
# dummy optim step
grad_norm = torch.Tensor([0.0]).cuda()
tim_usage["optim"] = tim_usage["backward"]
mem_usage["optim"] = mem_usage["backward"]
model_time = tim_usage["optim"] - tim_usage['init']
with torch.no_grad():
task_num = len(data["task_names"])
task_loss, task_token = get_task_loss_and_token(
non_reduced_loss, data["task_ids"], task_num, data["targets"]
)
task_loss_map: Dict[str, float] = {}
gatherd_task_loss_map = bmt.distributed.all_gather(task_loss)
gatherd_task_token_map = bmt.distributed.all_gather(task_token)
gatherd_task_loss_token_map = gatherd_task_loss_map * gatherd_task_token_map
sum_task_loss = gatherd_task_loss_token_map.sum(dim=0)
tot_task_token = gatherd_task_token_map.sum(dim=0)
ave_task_loss = sum_task_loss / tot_task_token
for i in range(task_num):
task_loss_map[data["task_names"][i]] = ave_task_loss[i].item()
global_total_task_token[data["task_names"][i]] += tot_task_token[i].item()
local_total_rate = torch.Tensor(
[data["lengths"].float().mean() / (args.max_length * args.batch_size)]
).cuda()
local_total_rate = bmt.sum_loss(local_total_rate).item()
global_token_pass += (
(global_world_size // args.tp_size) * local_total_rate * args.max_length * args.batch_size
)
bmt.print_rank(
"=========================================" + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
)
last_before_log_time = tim_usage["before_log"]
mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)
iter_time = tim_usage["before_log"] - last_before_log_time
ave_iter_time.record(iter_time)
train_info = {
"time": iter_time,
"iteration": iteration,
"loss": global_loss,
"lr": lr_scheduler.current_lr,
"token_max": local_total_rate,
"token_pass": global_token_pass,
"throughout": args.max_length * args.batch_size * local_total_rate / ave_iter_time.get() / args.tp_size,
"grad_norm": grad_norm.item(),
"mask_max": ((data["targets"] >= 0).sum(-1).float().mean() / args.max_length).item(),
"task_loss": task_loss_map,
"total_task_token": global_total_task_token,
}
global_token_pass_str = convert_to_k_and_b(global_token_pass)
time_report_str = "{model_time:.2f}={forward_time:.2f}+{backward_time:.2f}+{optim_time:.2f}".format(model_time=model_time, forward_time=tim_usage['forward']-tim_usage['init'], backward_time=tim_usage['backward']-tim_usage['forward'], optim_time=tim_usage['optim'] - tim_usage['backward'])
bmt.print_rank(
(
"| Iter: {iteration:6d} | loss: {loss:.4f} | lr: {lr:.4e} | model_time: {model_time} | iter_time: {iter_time:.2f}| chunk_ave_time: {chunk_ave_time:.2f}"
+ " token/max: {tokenrate:.4f} | mask/max: {maskrate:.4f} | grad_norm: {grad_norm:.4f} | global_token_pass (B):"
+ "{global_token_pass} | mem_usage {mem_usage} | "
).format(
iteration=iteration,
loss=global_loss,
lr=lr_scheduler.current_lr,
model_time=time_report_str,
iter_time=iter_time,
chunk_ave_time=ave_iter_time.get(),
tokenrate=data["lengths"].float().mean() / args.max_length / args.batch_size,
maskrate=(data["targets"] >= 0).sum(-1).float().mean() / args.max_length / args.batch_size,
grad_norm=grad_norm.item(),
global_token_pass=global_token_pass_str,
mem_usage=max([value for key, value in mem_usage.items()]),
)
)
bmt.print_rank(
"task_loss:\t| "
+ " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
+ " |"
)
if iteration % 10 == 0:
bmt.print_rank(
"task_tokens (B):\t| "
+ " | ".join(
[
"{}: {:.4f}".format(task_name, task_token / 10**9)
for task_name, task_token in global_total_task_token.items()
]
)
+ " |"
)
if iteration % args.inspect_iters == 0:
model_inspect = bmt.inspect.inspect_model(model, "*")
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
if args.log_dir is not None and bmt.rank() == 0:
if args.save is not None:
save_every_step_stats(train_info, args.save)
if args.tensorboard is not None and bmt.rank() == 0:
writer.add_scalar("Loss/train", global_loss, iteration)
writer.add_scalar("Optimizer/lr", lr_scheduler.current_lr, iteration)
writer.add_scalar("Optimizer/scale", optim_manager.loss_scale, iteration)
writer.add_scalar("Optimizer/grad_norm", grad_norm.item(), iteration)
for task_name, loss in task_loss_map.items():
if not math.isnan(loss):
writer.add_scalar("Loss/train/{}".format(task_name), loss, iteration)
# -------- save file. If need to backup by Klara platform, use export.xx_save --------
log_ckpt = {
"global_total_task_token": global_total_task_token,
"global_token_pass": global_token_pass,
"iteration": iteration,
}
if args.save is not None and iteration % args.save_iters == 0:
exporter.export(
model,
mixed_indexed_dataset,
tokenizer,
optimizer,
iteration,
args,
log_ckpt=log_ckpt,
final_save=False,
async_save=args.async_save,
)
if iteration == args.train_iters and args.stop_when_end == 1:
break
except Exception as e:
print(f"train loop err: {e}")
raise e
finally:
pass
exporter.export(model, mixed_indexed_dataset, tokenizer, optimizer, -1, args, final_save=False)
def convert_to_k_and_b(number):
    if number >= 1e9:  # greater than or equal to 1 billion
b_number = number / 1e9
return f"{b_number:.2f}B"
    elif number >= 1e6:  # greater than or equal to 1 million
k_number = number / 1e6
return f"{k_number:.2f}M"
elif number >= 1e3:
k_number = number / 1e3
return f"{k_number:.2f}K"
else:
return str(number)
def main():
args = initialize()
bmt.synchronize()
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
bmt.print_rank("finish loading")
bmt.print_rank(
"Number of parameter {}, Number of non-e parameter {}".format(
num_parameters(model), num_non_embedding_parameters(model)
)
)
bmt.print_rank("args: {}".format(args))
print("begining training")
pretrain(args, tokenizer, model, optimizer, lr_scheduler)
if __name__ == "__main__":
main()

View File

@@ -1,234 +0,0 @@
#!/bin/bash
#export OMP_NUM_THREADS=16
declare -A args # Declare an associative array to store arguments and values
args["model_unique"]="8b_0702"
args["resume_ckpt"]=""
args["config"]="8b"
args["flash"]="cuda"
args["batch_size"]="1"
args["max_length"]="4096"
args["save_iters"]="500"
args["train_iters"]="10"
args["dataset_config"]="fm9g_sft"
args["local"]="False"
args["dataloader"]="indexed"
args["save"]="True"
args["dataloader_num_threads"]=1
args["dataloader_prefetch"]=2
args["dataloader_prefetch_factor"]=32
args["dataloader_num_workers"]=2
args["lr"]="1e-5"
args["warmup_iters"]="20"
args["drop_iters"]="0.1"
args["tokenizer_path"]="./tokenizer/tokenizer.model" # /user/tc_agi/klara/baichuan2/baichuan2.tokenizer.model
args["load_grad"]="False"
args["grad_ckpt_num"]="160"
args["exp_group"]=""
args["ignore_cuda_oom"]="1"
args["tensorboard_all_tasks"]="0"
args["stop_when_end"]="0"
args["only_run_dataloader"]="0"
args["eps"]="1e-6"
args["inspect_iters"]="100"
args["strict_state_dict"]="1"
args["only_load_model"]="1"
args["lr_scheduler"]="cosine"
args["resume_no_optimze"]="0"
args["tp_size"]="1"
args["parallel_load_datastate"]="16"
args["async_save"]="False"
args["load_dataloader_ckpt"]="0"
args["drop_begin"]="-1"
args["drop_rate"]="0.5"
args["use_checkpoint"]="1"
# Loop through the arguments
for ((i=1; i<=$#; i++)); do
arg="${!i}"
# Check if the argument starts with "--"
if [[ "$arg" == --* ]]; then
arg_name="${arg:2}" # Remove leading "--"
valueid=$((i+1))
# Get the value of the argument if it exists
if ((i+1 <= $#)); then
args["$arg_name"]="${!valueid}"
i=$((i+1)) # Skip the next argument (its value)
else
args["$arg_name"]="" # Set empty value if no value provided
fi
fi
done
# Use Python to read the JSON config file and update the Bash args dictionary
while read -r key value; do
args["$key"]="$value"
done < <(python -c 'import json, sys; obj = json.load(open("train_configs/'${args['config']}'.json"))["pretrain"]; print("\n".join(["{} {}".format(k, v) for k, v in obj.items()]))')
# Apply the command-line args once more so they take precedence over the JSON config
# Loop through the arguments
for ((i=1; i<=$#; i++)); do
arg="${!i}"
# Check if the argument starts with "--"
if [[ "$arg" == --* ]]; then
arg_name="${arg:2}" # Remove leading "--"
valueid=$((i+1))
# Get the value of the argument if it exists
if ((i+1 <= $#)); then
args["$arg_name"]="${!valueid}"
i=$((i+1)) # Skip the next argument (its value)
else
args["$arg_name"]="" # Set empty value if no value provided
fi
fi
done
# Print the values of the arguments
echo "----------- CMD args ----------"
for key in "${!args[@]}"; do
echo "$key: ${args[$key]}"
done
echo "--------- END CMD args --------"
if [[ ${args["flash"]} == "triton" ]]; then
sudo cp /usr/local/cuda-11.6/compat/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03
sudo ln /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so
echo "triton flash"
fi
GPUS_PER_NODE=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader | wc -l)
# GPUS_PER_NODE=1
echo "Using ${GPUS_PER_NODE} GPU each machine"
if [[ ${args["model_unique"]} == "" ]]; then
    MODEL_UNIQUE=${JEEVES_JOB_ID} # output location; constructed automatically when not provided
    # JOBID+CreateTime, the unique identifier of this run. On the platform, the checkpoint can be
    # retrieved from /projects/${PROJECTID}-${PROJECTNAME}/checkpoints/${MODEL_UNIQUE}
    # and the tensorboard logs from /projects/${PROJECTID}-${PROJECTNAME}/tensorboard/${MODEL_UNIQUE}
else
    MODEL_UNIQUE=${args["model_unique"]} # output location provided explicitly
fi
echo "model_unique: "$MODEL_UNIQUE
# --------------- Run arguments ---------------
OPTS+=" --model-config model_configs/"${args['config']}".json" # [CHANGE]
OPTS+=" --batch-size ${args["batch_size"]}"
OPTS+=" --train-iters ${args["train_iters"]}"
OPTS+=" --save-iters ${args["save_iters"]}"
OPTS+=" --save-name fm9g_live_checkpoint"
OPTS+=" --max-length ${args["max_length"]}"
OPTS+=" --lr ${args["lr"]}"
OPTS+=" --inspect-iters ${args["inspect_iters"]}"
OPTS+=" --warmup-iters ${args["warmup_iters"]}"
OPTS+=" --drop-iters ${args["drop_iters"]}"
OPTS+=" --lr_scheduler ${args["lr_scheduler"]}"
OPTS+=" --offload"
OPTS+=" --vocab ./tokenizer/vocab.txt"
OPTS+=" --flash ${args["flash"]}"
OPTS+=" --tensorboard_all_tasks ${args["tensorboard_all_tasks"]}"
OPTS+=" --ignore_cuda_oom ${args["ignore_cuda_oom"]}"
OPTS+=" --stop_when_end ${args["stop_when_end"]}"
OPTS+=" --only_run_dataloader ${args["only_run_dataloader"]}"
OPTS+=" --eps ${args["eps"]}"
OPTS+=" --strict_state_dict ${args["strict_state_dict"]}"
OPTS+=" --only_load_model ${args["only_load_model"]}"
OPTS+=" --resume_no_optimze ${args["resume_no_optimze"]}"
OPTS+=" --tokenizer_path ${args["tokenizer_path"]}"
OPTS+=" --weight-decay 0.1"
OPTS+=" --tp-size ${args["tp_size"]}"
OPTS+=" --parallel_load_datastate ${args["parallel_load_datastate"]}"
OPTS+=" --load_dataloader_ckpt ${args["load_dataloader_ckpt"]}"
OPTS+=" --drop_begin ${args["drop_begin"]}"
OPTS+=" --drop_rate ${args["drop_rate"]}"
OPTS+=" --use_checkpoint ${args["use_checkpoint"]}"
if [[ ${args["load_grad"]} == "True" ]]; then
OPTS+=" --load-grad"
OPTS+=" --grad-ckpt-num ${args["grad_ckpt_num"]}"
fi
if [[ ${args["async_save"]} == "True" ]]; then
OPTS+=" --async_save"
fi
if [[ ${args["dataloader"]} == "indexed" ]]; then
OPTS+=" --dataloader_num_threads ${args["dataloader_num_threads"]}"
OPTS+=" --dataloader_prefetch ${args["dataloader_prefetch"]}"
OPTS+=" --dataloader_num_workers ${args["dataloader_num_workers"]}"
OPTS+=" --dataloader_prefetch_factor ${args["dataloader_prefetch_factor"]}"
fi
# --------------- Output file paths ---------------
## checkpoint
if [[ ${args["save"]} == "True" ]]; then
OPTS+=" --save ./data/checkpoints/${MODEL_UNIQUE}/"
OPTS+=" --save-model ./not_exist/${MODEL_UNIQUE}/"
else
echo "won't save model"
fi
## logs/local/logs is equivalent to the ./datalogs symlink
mkdir -p ./data/checkpoints/logs/${MODEL_UNIQUE}
OPTS+=" --log-dir ./data/checkpoints/logs/${MODEL_UNIQUE}"
OPTS+=" --tensorboard ./data/tensorboard/${args["exp_group"]}${MODEL_UNIQUE}/"
if [[ ${args["local"]} == "True" ]]; then
current_dir=$(pwd)
OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
else
current_dir=$(pwd)
OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
echo "Platform config:"${PLATFORM_CONFIG_PATH}
fi
## checkpoint: compatible with both CHECKPOINT and LATEST_CHECKPOINT. When debugging, it is recommended not to load a checkpoint, so startup is faster
if [ "${args["resume_ckpt"]}" != "" ]; then
OPTS+=" --load ./data/checkpoints/${MODEL_UNIQUE}/${args["resume_ckpt"]}"
else
echo "No checkpoint to load"
fi
filename="pretrain_dragonfly"
if [[ ${args["local"]} == "True" ]]; then
PRETRAIN_ENTRY="$filename.py"
else
PRETRAIN_ENTRY="$filename.py"
fi
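# NOTE: the values below hard-code a single-node, 8-GPU run and override the GPU count detected above; adjust them for multi-node training.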
GPUS_PER_NODE=8
NNODES=1
RANK=0
MASTER_ENDPOINT=g3006
MASTER_PORT=12345
#CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --master_addr=${MASTER_ENDPOINT} --master_port=${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ENDPOINT}:${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
echo "-------final CMD is------"
echo "${CMD}"
echo "-------final CMD end------"
$CMD

View File

@ -1,9 +0,0 @@
{
"pretrain": {
"train_iters": 20000,
"batch_size": 1,
"max_length": 4096,
"n_gpus": 8,
"lr": 1e-5
}
}

126
FM_9G/chat_model.py Normal file
View File

@ -0,0 +1,126 @@
import gc
from io import BytesIO
import requests
import timm
import torch
from PIL import Image
from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from torchvision.transforms import InterpolationMode, transforms
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
import os,sys
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
from vis_fm9g.generation.vllm_fm9g import VLLMFM9GBeamSearch
from vis_fm9g.model.fm9g import FM9GConfig, FM9GTorch
from vis_fm9g.model.vlu_fm9g import VLU_FM9G
from vis_fm9g.tokenizer.fm9g_tokenizer import FM9GTokenizer
from vis_fm9g.utils.constants import SYSTEM
def chat(model, image, question, context, tokenizer, query_nums=64, vision_hidden_states=None, max_length=1024):
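# On the first turn (no prior context), prepend query_nums image-placeholder tokens wrapped in
# im_start/im_end so the vision features can later be spliced into the prompt; subsequent turns
# reuse the accumulated context, which already contains the placeholders.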
if not context:
question = tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end + question
final_input = f'{SYSTEM}<用户>{question}<AI>'
else:
final_input = f'{context}<用户>{question}<AI>'
data_list = [
{'input': final_input}
]
res, vision_hidden_states = model.generate(
data_list=data_list,
max_inp_length=2048,
beam_size=3,
img_list=[[image]],
max_length=max_length,
repetition_penalty=1.1,
temperature=0.7,
length_penalty=3,
return_vision_hidden_states=True
)
answer = res[0]
context = final_input + answer
return answer, context, vision_hidden_states
def load_llm(llm_path):
config = FM9GConfig.from_json_file(llm_path)
config.use_flash_attn = False
cpm_model = FM9GTorch(config)
return cpm_model
def load_vpm(vision_encoder, drop_vision_last_layer=False):
model = timm.create_model(
vision_encoder,
pretrained=False,
num_classes=0,
dynamic_img_size=True,
dynamic_img_pad=True
)
if isinstance(model, timm.models.VisionTransformer):
if model.attn_pool is not None:
model.attn_pool = torch.nn.Identity()
if drop_vision_last_layer:
model.blocks[-1] = torch.nn.Identity()
return model
def load_vis_fm9g(llm_path, vision_encoder):
llm = load_llm(llm_path)
vpm = load_vpm(vision_encoder, drop_vision_last_layer=False)
vision_dim = vpm.embed_dim
model = VLU_FM9G(llm, vpm, vision_dim, query_num=256)
return model
def load_tokenizer(vocabs_path):
return FM9GTokenizer(vocabs_path)
def load_transform(img_size):
transform = transforms.Compose([
transforms.Resize((img_size, img_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD)
])
return transform
if __name__ == '__main__':
root = "checkpoint/"
llm_path = root + 'config.json'
vocabs_path = root + 'vocabs.txt'
model_checkpoint = root + 'sharded'
vision_encoder = 'eva02_enormous_patch14_clip_224.laion2b_plus'
img_size = 448
with init_empty_weights():
model = load_vis_fm9g(llm_path, vision_encoder)
model = load_checkpoint_and_dispatch(model, model_checkpoint, device_map="auto", max_memory={0: "24GiB", 1: "24GiB"}, no_split_module_classes=['EvaBlockPostNorm'])
model.eval()
tokenizer = load_tokenizer(vocabs_path)
transform = load_transform(img_size)
beam_search = VLLMFM9GBeamSearch(model, tokenizer, transform)
# image input
url = 'test.jpg'
image = Image.open(url).convert('RGB')
# text prompt
prompt = '这幅图描述了什么?'  # "What does this picture depict?"
answer, context, _ = chat(
beam_search, image, prompt, context=None, tokenizer=tokenizer, query_nums=256
)
print(answer)

View File

@ -1 +0,0 @@
from .arguments import get_args

View File

@ -1,425 +0,0 @@
# coding=utf-8
# Copyright 2020 The OpenBMB team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
def add_model_config_args(parser: argparse.ArgumentParser):
"""Model arguments"""
group = parser.add_argument_group("model", "model configuration")
group.add_argument("--model-config", type=str, help="model configuration file")
group.add_argument("--vocab", type=str, default=None, help="model vocabulary file")
group.add_argument("--eps", type=float, default=1e-5, help="eps in layernorm")
# group.add_argument("--qk_norm", action="store_true", default=False, help="qk layernorm")
return parser
def add_training_args(parser: argparse.ArgumentParser):
"""Training arguments."""
group = parser.add_argument_group("train", "training configurations")
group.add_argument("--platform-config", type=str, default="platform_config.json", help="Path to platform config")
group.add_argument("--dataset", type=str, default="dataset.json", help="Path to dataset")
group.add_argument("--val-dataset", type=str, default="dataset.json", help="Path to val dataset")
group.add_argument(
"--load",
type=str,
default=None,
help="Path to a directory containing a model checkpoint.",
)
group.add_argument(
"--load-grad",
action="store_true",
default=False,
help="Load the gradient states",
)
group.add_argument(
"--grad-ckpt-num",
type=int,
default=0,
help="grad file num (only work when --load-grad from files less than world-size )",
)
group.add_argument(
"--load-start-step",
action="store_true",
default=False,
help="Load the step state from checkpoints",
)
group.add_argument(
"--save",
type=str,
default=None,
help="Output directory to save checkpoints to.",
)
group.add_argument(
"--save-name",
type=str,
default=None,
help="Output filename to save checkpoints to.",
)
group.add_argument(
"--save-model",
type=str,
default=None,
help="Output directory to save model to.",
)
group.add_argument(
"--tensorboard",
type=str,
default=None,
help="tensorboard directory",
)
group.add_argument("--inspect-iters", type=int, default=1000, help="number of inspecting")
group.add_argument("--batch-size", type=int, default=32, help="Data Loader batch size")
group.add_argument("--num-micro-batches", type=int, default=16)
group.add_argument("--clip-grad", type=float, default=1.0, help="gradient clipping")
group.add_argument("--grad-accum", type=int, default=1, help="gradient accum steps")
group.add_argument(
"--train-iters",
type=int,
default=1000000,
help="total number of iterations to train over all training runs",
)
group.add_argument("--max-length", type=int, default=512, help="max length of input")
group.add_argument("--min-length", type=int, default=None, help="only for speed test")
group.add_argument("--seed", type=int, default=1234, help="random seed for reproducibility")
# Learning rate.
group.add_argument("--lr", type=float, default=1.0e-4, help="initial learning rate")
group.add_argument("--lr_scheduler", type=str, default="cosine", help=" learning rate scheduler")
group.add_argument("--weight-decay", type=float, default=1.0e-2, help="weight decay rate")
group.add_argument("--loss-scale", type=float, default=65536, help="loss scale")
group.add_argument("--max-loss-scale", type=float, default=float("inf"), help="loss scale")
group.add_argument("--min-loss-scale", type=float, default=1, help="loss scale")
group.add_argument("--loss-scale-steps", type=float, default=1024, help="loss scale")
group.add_argument(
"--warmup-iters",
type=float,
default=0.01,
help="percentage of data to warmup on (.01 = 1% of all " "training iters). Default 0.01",
)
group.add_argument(
"--drop-iters",
type=float,
default=0.01,
help="percentage of data to warmup on (.01 = 1% of all " "training iters). Default 0.01",
)
group.add_argument("--lr-decay-iters", type=int, default=None, help="lr decay steps")
group.add_argument("--start-step", type=int, default=0, help="step to start or continue training")
group.add_argument("--concat-data", action="store_true", help="whether we concatenate the dialogues")
group.add_argument("--offload", action="store_true", help="whether we use offload_adam")
group.add_argument("--new-bmt", action="store_true", help="new bmt without ckpt")
group.add_argument("--flash", default="none", choices=["none", "1d", "triton", "cuda"])
group.add_argument("--use-jfs-data", action="store_true", help="whether we use juicefs dataset")
group.add_argument("--tp-size", default=1, type=int)
group.add_argument("--pp-size", default=1, type=int)
group.add_argument("--bf16", action="store_true", help="whether we use bf16")
group.add_argument("--dataloader_num_threads", default=3, type=int, help="Only useful in indexed dataest.")
group.add_argument("--dataloader_prefetch", default=200, type=int, help="Only useful in indexed dataest.")
group.add_argument("--dataloader_num_workers", default=4, type=int, help="Only useful in indexed dataest.")
group.add_argument("--dataloader_prefetch_factor", default=50, type=int, help="Only useful in indexed dataest.")
group.add_argument(
"--dataloader",
default="indexed",
type=str,
help="dataloader type, 'indexed' for indexed dataset, 'normal' for normal dataset",
)
group.add_argument("--stop_when_end", default=0, type=int, help="Whether to stop training when we reach end_iter")
group.add_argument(
"--data_len_threshold",
default=512,
type=int,
help="If the average length of a sequence is less than this int, mean the sample is biased. ",
)
group.add_argument(
"--only_run_dataloader", default=0, type=int, help="Whether to only run dataloader to check data. "
)
group.add_argument(
"--only_load_model", default=0, type=int, help="Whether to only load a model ckpt, without anything else."
)
group.add_argument(
"--load_dataloader_ckpt", default=1, type=int, help="Whether to only load a model ckpt, without anything else."
)
group.add_argument(
"--resume_no_optimze",
default=0,
type=int,
help="The number of steps that does not add optimization after resume",
)
group.add_argument(
"--parallel_load_datastate",
default=256,
type=int,
help="The number of parallel workers to load dataset state",
)
group.add_argument(
"--async_save",
action="store_true",
help="whether to save artifacts asynchronously",
)
group.add_argument(
"--drop_begin",
default=-1,
type=int,
help="The number of steps that starts to drop lr"
)
group.add_argument(
"--drop_rate",
default=0.5,
type=float,
help="The number rate"
)
group.add_argument(
"--use_checkpoint",
default=1,
type=int,
help="Whether to use checkpointing."
)
return parser
def add_pretrain_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group("pretrain", "pretrain configurations")
group.add_argument(
"--save-iters",
type=int,
default=1000,
help="number of iterations between saves",
)
group.add_argument(
"--log-dir",
type=str,
default=None,
help="log directory",
)
group.add_argument(
"--worker-name",
type=str,
default=None,
help="worker name",
)
return parser
def add_tokenizer_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group("tokenizer", "tokenizer configurations")
group.add_argument(
"--tokenizer_path",
type=str,
default="",
help="tokenizer_path",
)
return parser
def add_finetune_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group("finetune", "finetune configurations")
group.add_argument("--epoch", type=int, default=1, help="number of training epochs")
group.add_argument("--task-name", type=str, default="task", help="name of training task")
group.add_argument("--save-epochs", type=int, default=1, help="number of training epochs between saves")
group.add_argument("--save-steps", type=int, default=0, help="number of training steps between saves")
group.add_argument(
"--drop-last",
action="store_true",
default=False,
help="drop data from each epoch that cannot be formed into a complete batch at the end",
)
group.add_argument("--delta-tuning", action="store_true", default=False)
group.add_argument("--each-epoch-save", default=False)
group.add_argument("--train-task-id", type=int, default=-1)
return parser
def add_rhlf_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group("rhlf", "rhlf configurations")
group.add_argument(
"--load-reward",
type=str,
default=None,
help="Path to reward model checkpoint.",
)
group.add_argument("--actor-lr", type=float, default=1.0e-5, help="actor learning rate")
group.add_argument("--critic-lr", type=float, default=1.0e-6, help="critic learning rate")
group.add_argument("--actor-loss-scale", type=float, default=65536, help="actor loss scale")
group.add_argument("--critic-loss-scale", type=float, default=65536, help="critic loss scale")
group.add_argument("--avg-reward-bias", type=float, default=0, help="reward bias")
group.add_argument("--actor-delay-step", type=int, default=0, help="actor delay step")
group.add_argument("--entropy-coef", type=float, default=-1.0, help="coef of policy entropy")
##
return parser
def add_simple_rhlf_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group("simple_rhlf", "simple rhlf configurations")
group.add_argument("--epoch", type=int, default=1, help="number of training epochs")
group.add_argument("--sample-batch-size", type=int, default=32, help="Data Loader sample batch size")
group.add_argument("--load-reward", type=str, default=None, help="Path to reward model checkpoint")
group.add_argument("--avg-reward-bias", type=float, default=0, help="reward bias")
group.add_argument("--sample-min-length", type=int, default=20, help="sample-min-length")
group.add_argument("--sample-max-inp-length", type=int, default=1024, help="sample-max-inp-length")
group.add_argument("--sample-max-length", type=int, default=64, help="sample-max-length")
group.add_argument("--sample-repetition-penalty", type=float, default=1.05, help="sample-repetition-penalty")
group.add_argument("--sample-temperature", type=float, default=1.0, help="sample-temperature")
group.add_argument("--encode-max-length", type=int, default=1024, help="encode-max-length")
group.add_argument("--generate-max-length", type=int, default=64, help="generate-max-length")
group.add_argument("--value-loss-weight", type=float, default=0.1, help="value-loss-weight")
group.add_argument("--ptx-loss-weight", type=float, default=0.001, help="ptx-loss-weight")
group.add_argument("--save-epochs", type=int, default=1, help="number of training epochs between saves")
##
return parser
def add_feedback_learning_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group("rrhf", "rrhf configurations")
group.add_argument("--length-penalty", type=float, default=1.0, help="length_penalty")
group.add_argument("--feedback-weight", type=float, default=1.0, help="feedback_weight")
group.add_argument("--sample-num", type=int, default=6, help="sample_num")
group.add_argument("--dpo-beta", type=float, default=1.0, help="dpo_beta")
group.add_argument("--stable-alignment-margin", type=float, default=1.0, help="stable_alignment_margin")
group.add_argument("--feedback-learning-type", type=str, default="RRHF", help="feedback_learning_type")
group.add_argument("--save-iters", type=int, default=1000, help="number of iterations between saves")
##
return parser
def add_model_change_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group("model_change", "model change during pretraining")
group.add_argument("--strict_state_dict", type=int, default=1, help="strict_state_dict")
##
return parser
def add_log_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group("log", "log configurations")
group.add_argument("--tensorboard_all_tasks", type=int, default=0, help="log")
return parser
def add_error_handle_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group("error_handle", "error_handle configurations")
group.add_argument(
"--ignore_cuda_oom", type=int, default=1, help="continue training by ingore the batch that causes oom"
)
return parser
def add_runtime_eval_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group("runtime eval args", "runtime evaluation by submitting a job")
group.add_argument(
"--runtime_eval",
action="store_true",
help="whether to use runtime_eval. Only if this is set to True, the following variables will be useful",
)
group.add_argument("--eval_jeeves_auth", type=str, default="", help="auth, press f12 on jeeves platform to get")
group.add_argument("--eval_project_id", type=str, default=None, help="project id")
group.add_argument("--eval_run_cmd", type=str, default="", help="cmd for eval")
group.add_argument(
"--eval_git_path",
type=str,
default="git@git.in.zhihu.com:luca/llm-bench.git",
help="git path of evaluation code",
)
group.add_argument("--eval_git_branch", type=str, default="master", help="git branch of evaluation code")
group.add_argument("--eval_node_num", type=int, default=1, help="using 1 node to evaluate")
group.add_argument("--eval_gpu_num", type=int, default=1, help="using 1 gpu per node to evaluate")
group.add_argument("--eval_tasks_config", type=str, default="", help="evaluate tasks' config")
group.add_argument("--eval_model_backend", default="torch", type=str, help="model_backend")
group.add_argument(
"--eval_at_start", action="store_true", help="whether to eval at the first epoch, default to false"
)
return parser
def add_reward_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group("reward", "reward configurations")
group.add_argument("--load-all", type=str, default=None, help="Path to a directory containing a model checkpoint.")
##
return parser
def add_long_context_extend_args(parser: argparse.ArgumentParser):
"""long context extending arguments."""
group = parser.add_argument_group("long_context_extend", "long context extend configurations")
group.add_argument("--pose-prob", default=0.0, type=float, help="Sample-level PoSE probability")
group.add_argument(
"--pose-scaling-factor",
default=1.0,
type=float,
help="PoSE scaling factor, simulate input length = max_length * pose_scaling_factor",
)
group.add_argument(
"--rope-scaling-type",
default="",
type=str,
choices=["Linear", "NTK-aware", "Dynamic NTK", "NTK-by-parts", "YaRN", ""],
help="Context scaling type",
)
group.add_argument("--rope-scaling-factor", default=1, type=int, help="Context scaling factor")
group.add_argument(
"--orig-max-length", default=8192, type=int, help="Original context length before context extending"
)
return parser
def get_args(
pretrain: bool = False,
finetune: bool = False,
rhlf: bool = False,
simple_rlhf: bool = False,
feedback_learning: bool = False,
reward: bool = False,
):
parser = argparse.ArgumentParser()
parser = add_model_config_args(parser) # config file need to be exported with model/ckpt
parser = add_training_args(parser)
if pretrain:
parser = add_pretrain_args(parser)
parser = add_runtime_eval_args(parser)
parser = add_tokenizer_args(parser)
parser = add_log_args(parser)
parser = add_error_handle_args(parser)
parser = add_model_change_args(parser)
if finetune:
parser = add_finetune_args(parser)
if rhlf:
parser = add_rhlf_args(parser)
if simple_rlhf:
parser = add_simple_rhlf_args(parser)
if feedback_learning:
parser = add_feedback_learning_args(parser)
if reward:
parser = add_reward_args(parser)
parser = add_long_context_extend_args(parser)
args = parser.parse_args()
return args

View File

@ -1,16 +0,0 @@
from .distributed_dataset import build_dataset
from .distributed_dataset import DistributedDataset
from .distributed_dataset import SimpleDataset
from .indexed_dataset import IndexedDataset
from .indexed_dataset import IndexedDatasetBuilder
from .indexed_dataset import PrefetchDecodeDataset
# from .list_dataset import ListDataset
from .utils import compact_dataset
from .utils import CudaPrefetcher
from .utils import mask_dataset
from .utils import merge_dataset
from .utils import random_range
from .utils import Range
from .utils import shuffle_dataset
from .utils import ThreadedPrefetcher

View File

@ -1,814 +0,0 @@
# coding=utf-8
# Copyright 2020 The OpenBMB team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import bisect
import io
import json
import os
import random
import string
import struct
import time
from typing import List
from typing import Optional
from typing import Set
import bmtrain as bmt
import torch
from .serializer import PickleSerializer
from .serializer import Serializer
def _random_string():
return "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
_DEFAULT_BLOCK_SIZE = 16 << 20
class FileInfo:
def __init__(
self,
file_name: str = "",
block_begin: int = 0,
block_end: int = 0,
nbytes: int = 0,
nlines: int = 0,
mask: bool = False,
block_size: int = _DEFAULT_BLOCK_SIZE,
) -> None:
self.file_name = file_name
self.block_begin = block_begin
self.block_end = block_end
self.nbytes = nbytes
self.nlines = nlines
self.mask = mask
self.block_size = block_size
def state_dict(self):
return {
"file_name": self.file_name,
"block_begin": self.block_begin,
"block_end": self.block_end,
"nbytes": self.nbytes,
"nlines": self.nlines,
"mask": self.mask,
"block_size": self.block_size,
}
def load_state_dict(self, d):
self.file_name = d["file_name"]
self.block_begin = d["block_begin"]
self.block_end = d["block_end"]
self.nbytes = d["nbytes"]
self.nlines = d["nlines"]
self.mask = d["mask"]
self.block_size = d["block_size"]
def dumps(self) -> str:
return json.dumps(self.state_dict())
def loads(self, data: str) -> "FileInfo":
self.load_state_dict(json.loads(data))
return self
def dump(self, fp: io.TextIOWrapper) -> "FileInfo":
fp.write(self.dumps())
return self
def load(self, fp: io.TextIOWrapper) -> "FileInfo":
self.loads(fp.read())
return self
def _read_info_list(meta_path: str) -> List[FileInfo]:
info: List[FileInfo] = []
while True:
try:
with open(meta_path, "r", encoding="utf-8") as f:
for line in f.readlines():
line = line.strip()
if len(line) > 0:
info.append(FileInfo().loads(line))
return info
except Exception as e:
print(
"Error: reading info list in _read_info_list!, meta_path={path}, err={err}".format(
path=meta_path, err=str(e)
)
)
time.sleep(10)
def _write_info_list(meta_path: str, info: List[FileInfo]):
base_path = os.path.dirname(meta_path)
random_fname = os.path.join(base_path, ".meta.bin.%s" % _random_string())
while True:
try:
with open(random_fname, "w", encoding="utf-8") as f:
for v in info:
f.write(v.dumps() + "\n")
os.rename(random_fname, meta_path)
return
except Exception:
print("Error: writing info list!")
time.sleep(10)
def _filtered_range(begin: int, end: int, rank: int, world_size: int, filter_set: Optional[Set[int]] = None):
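# Advance `begin` to the first id >= begin with id % world_size == rank, then return every
# world_size-th id up to `end`, optionally keeping only ids present in filter_set.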
begin = begin + (rank + (world_size - (begin % world_size))) % world_size
if filter_set is not None:
return [i for i in range(begin, end, world_size) if i in filter_set]
else:
return [i for i in range(begin, end, world_size)]
class SafeFile:
def __init__(self, fname, mode):
self.fname = None
self.mode = None
self._fp = None
self.open_file(fname, mode)
def read(self, size=-1):
if self._fp is None:
raise RuntimeError("Dataset is closed")
try:
res = self._fp.read(size)
self.offset = self._fp.tell()
return res
except Exception as e:
print("Error: reading blocks in {}! err {}".format(self.fname, str(e)))
self.close()
self.open_file(self.fname, self.mode, self.offset)
return self.read(size)
def tell(self):
if self._fp is None:
raise RuntimeError("Dataset is closed")
try:
res = self._fp.tell()
self.offset = res
return res
except Exception as e:
print("Error: telling blocks in {}! err {}".format(self.fname, str(e)))
self.close()
self.open_file(self.fname, self.mode, self.offset)
return self.tell()
def seek(self, offset, whence=0):
if self._fp is None:
raise RuntimeError("Dataset is closed")
try:
res = self._fp.seek(offset, whence)
self.offset = self._fp.tell()
return res
except Exception as e:
print("Error: seeking blocks in {}! err {}".format(self.fname, str(e)))
self.close()
self.open_file(self.fname, self.mode, self.offset)
return self.seek(offset, whence)
def close(self):
if self._fp is not None:
try:
self._fp.close()
except Exception as e:
print("Error: closing blocks in {}! err {}".format(self.fname, str(e)))
self._fp = None
def open_file(self, fname, mode, offset=None):
if not os.path.exists(fname):
print("Dataset {} does not exist".format(fname))
self.close()
time.sleep(20)
self.open_file(fname, mode, offset)
try:
self.fname = fname
self.mode = mode
self._fp = open(fname, mode)
if offset is not None:
self._fp.seek(offset, io.SEEK_SET)
self.offset = self._fp.tell()
except Exception as e:
print("Error: opening blocks in {}! err {}".format(self.fname, str(e)))
self.close()
time.sleep(20)
self.open_file(fname, mode, offset)
class DistributedDataset:
"""Open dataset in readonly mode.
`DistributedDataset` is used to read datasets in a distributed manner.
Data in this dataset will be distributed evenly in blocks to each worker in the `distributed communicator`.
**Note** When all data has been read, reading the dataset again starts over from the first item.
Args:
path (str): Path to dataset.
rank (int): Rank in distributed communicator. See: bmtrain.rank()
world_size (int): Total workers in distributed communicator. See: bmtrain.world_size()
block_size (int): Size of each block in bytes. All files in the same dataset should have the same block size. Default: 16MB
Example:
>>> dataset = DistributedDataset("/path/to/dataset")
>>> for i in range(10):
>>> dataset.read()
""" # noqa: E501
def __init__(
self,
path: str,
rank: int = 0,
world_size: int = 1,
serializer: Optional[Serializer] = None,
max_repeat_times: Optional[int] = None,
shuffle: bool = True,
) -> None:
# config
self._path = path
self._rank = rank
self._world_size = world_size
self._max_repeat_times = max_repeat_times
self._repeat_times = 0
self._shuffle = shuffle
if serializer is None:
serializer = PickleSerializer()
self.serializer = serializer
# dataset meta
self._unused_block: List[int] = []
self._unused_block_offset: List[int] = []
self._file_info: List[FileInfo] = []
self._file_ends: List[int] = []
self._total_blocks = 0
self._nbytes = 0
self._nlines = 0
# states
self._curr_block = None
self._fp = None
# cache
self._last_mod_time = 0
self._curr_fname = None
self._update_states(fast_skip=False)
self._repeat_times += 1
def _update_states(self, fast_skip: bool = True):
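# Re-read meta.bin when it has changed and refresh the block bookkeeping; files appended to the
# dataset (or newly masked files) are picked up here without restarting training.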
meta_path = os.path.join(self._path, "meta.bin")
while True:
try:
mod_time = os.stat(meta_path).st_mtime
break
except Exception as e:
print(
"Error: reading info list in DistributedDataset._update_states, "
"meta_path={path}, err={err}!".format(path=meta_path, err=str(e))
)
time.sleep(10)
if self._last_mod_time < mod_time:
# file changed
self._last_mod_time = mod_time
else:
if fast_skip:
return
info: List[FileInfo] = []
if os.path.exists(meta_path):
info = _read_info_list(meta_path)
old_len = len(self._file_info)
if old_len > len(info):
raise RuntimeError("Dataset meta file: changed unexpectly")
mask_changed = False
for i in range(old_len):
if self._file_info[i].file_name != info[i].file_name:
raise RuntimeError("Dataset meta file: changed unexpectly")
if self._file_info[i].block_begin != info[i].block_begin:
raise RuntimeError("Dataset meta file: changed unexpectly")
if self._file_info[i].block_end != info[i].block_end:
raise RuntimeError("Dataset meta file: changed unexpectly")
if self._file_info[i].mask != info[i].mask:
mask_changed = True
if info[0].block_begin != 0:
raise RuntimeError("Dataset meta file: block error (0)")
for i in range(len(info) - 1):
if info[i].block_end != info[i + 1].block_begin:
raise RuntimeError("Dataset meta file: block error (%d)" % (i + 1))
if (old_len == len(info) and not mask_changed) and fast_skip:
# fast skip
return
if len(info) > 0:
total_blocks = info[-1].block_end
self._nbytes = 0
self._nlines = 0
for v in info:
self._nbytes += v.nbytes
self._nlines += v.nlines
else:
total_blocks = 0
self._nbytes = 0
self._nlines = 0
if total_blocks > 0:
unused_block_set = set(self._unused_block)
nw_unused_block: List[int] = []
for i in range(len(info)):
v = info[i]
if not v.mask:
if i < old_len:
nw_unused_block.extend(
_filtered_range(
v.block_begin,
v.block_end,
self._rank,
self._world_size,
unused_block_set,
)
)
else:
nw_unused_block.extend(
_filtered_range(v.block_begin, v.block_end, self._rank, self._world_size)
)
# re-shuffle unused blocks
if self._shuffle:
random.shuffle(nw_unused_block)
offset_dict = {block: offset for block, offset in zip(self._unused_block, self._unused_block_offset)}
nw_unused_block_offset = [offset_dict[block] if block in offset_dict else 0 for block in nw_unused_block]
self._unused_block = nw_unused_block
self._unused_block_offset = nw_unused_block_offset
self._file_ends = []
for v in info:
self._file_ends.append(v.block_end)
else:
self._unused_block = []
self._unused_block_offset = []
self._file_ends = []
self._total_blocks = total_blocks
self._file_info = info
assert len(self._unused_block) == len(self._unused_block_offset)
assert len(self._file_ends) == len(self._file_info)
def _mask_file(self, f: FileInfo):
nw_unused_block: List[int] = []
nw_unused_block_offset: List[int] = []
for block_id, block_offset in zip(self._unused_block, self._unused_block_offset):
if block_id < f.block_begin or block_id >= f.block_end:
nw_unused_block.append(block_id)
nw_unused_block_offset.append(block_offset)
self._unused_block = nw_unused_block
self._unused_block_offset = nw_unused_block_offset
def _get_block_file(self, block_id: int):
# find block in which file
file_idx = bisect.bisect_right(self._file_ends, block_id)
return self._file_info[file_idx]
def _prepare_new_epoch(self):
if self._max_repeat_times is not None:
if self._repeat_times >= self._max_repeat_times:
raise EOFError("End of dataset")
nw_unused_block: List[int] = []
for v in self._file_info:
if not v.mask:
nw_unused_block.extend(_filtered_range(v.block_begin, v.block_end, self._rank, self._world_size))
if self._shuffle:
random.shuffle(nw_unused_block)
self._unused_block = nw_unused_block
self._unused_block_offset = [0 for _ in nw_unused_block]
self._repeat_times += 1
def _get_next_block(self):
self._update_states()
if len(self._unused_block) == 0:
self._prepare_new_epoch()
if len(self._unused_block) == 0:
raise RuntimeError("Empty dataset {}".format(self._path))
mn_block: int = self._unused_block.pop()
mn_block_offset: int = self._unused_block_offset.pop()
return mn_block, mn_block_offset
def _state_dict(self):
self._update_states()
num_unused_block = len(self._unused_block)
if (self._fp is not None) and (self._curr_block is not None):
curr_block = self._curr_block
curr_f = self._get_block_file(curr_block)
inblock_offset = self._fp.tell() - (curr_block - curr_f.block_begin) * curr_f.block_size
else:
curr_block = -1
inblock_offset = 0
return {
"states": torch.tensor(self._unused_block, dtype=torch.long, device="cpu"),
"offset": torch.tensor(self._unused_block_offset, dtype=torch.long, device="cpu"),
"block": torch.tensor(
[curr_block, inblock_offset, num_unused_block, self._repeat_times],
dtype=torch.long,
device="cpu",
),
}
def state_dict(self):
"""Returns a state dict representing the read states of the dataset.
Example:
>>> state = dataset.state_dict()
>>> dataset.load_state_dict(state)
"""
self._update_states()
num_unused_block = len(self._unused_block)
if (self._fp is not None) and (self._curr_block is not None):
curr_block = self._curr_block
curr_f = self._get_block_file(curr_block)
inblock_offset = self._fp.tell() - (curr_block - curr_f.block_begin) * curr_f.block_size
else:
curr_block = -1
inblock_offset = 0
with torch.no_grad():
if self._world_size > 1:
gpu_num_unused_block = torch.tensor([num_unused_block], dtype=torch.long).cuda()
max_unused_blocks = (
bmt.distributed.all_reduce(gpu_num_unused_block, op="max", comm=bmt.config["tp_zero_comm"])
.cpu()
.item()
)
gpu_states = torch.full((max_unused_blocks,), -1, dtype=torch.long).cuda()
gpu_states[:num_unused_block] = torch.tensor(self._unused_block, dtype=torch.long).cuda()
gpu_offset = torch.full((max_unused_blocks,), 0, dtype=torch.long).cuda()
gpu_offset[:num_unused_block] = torch.tensor(self._unused_block_offset, dtype=torch.long).cuda()
gpu_block = torch.tensor(
[curr_block, inblock_offset, num_unused_block, self._repeat_times],
dtype=torch.long,
).cuda()
global_states = bmt.distributed.all_gather(
gpu_states, comm=bmt.config["tp_zero_comm"]
).cpu() # (world_size, max_unused_blocks)
global_offset = bmt.distributed.all_gather(
gpu_offset, comm=bmt.config["tp_zero_comm"]
).cpu() # (world_size, max_unused_blocks)
global_block = bmt.distributed.all_gather(
gpu_block, comm=bmt.config["tp_zero_comm"]
).cpu() # (world_size, 4)
return {"states": global_states, "offset": global_offset, "block": global_block}
else:
return {
"states": torch.tensor([self._unused_block], dtype=torch.long, device="cpu"),
"offset": torch.tensor([self._unused_block_offset], dtype=torch.long, device="cpu"),
"block": torch.tensor(
[[curr_block, inblock_offset, num_unused_block, self._repeat_times]],
dtype=torch.long,
device="cpu",
),
}
def load_state_dict(self, state, strict: bool = True):
"""Load dataset state.
Args:
state (dict): dataset state dict.
strict (bool): If `strict` is True, world size needs to be the same as when exported.
Example:
>>> state = dataset.state_dict()
>>>
"""
block_states: torch.LongTensor = state["states"]
block_info: torch.LongTensor = state["block"]
if "offset" not in state:
block_offset: torch.LongTensor = torch.zeros_like(block_states).long()
else:
block_offset: torch.LongTensor = state["offset"]
if block_states.size(0) != self._world_size:
if strict:
raise ValueError("world_size changed (%d -> %d)" % (state["block"].size(0), self._world_size))
else:
self._curr_block = None
self._fp = None
self._curr_fname = None
self._repeat_times = int(block_info[0, 3].item())
offset_dict = {}
for i in range(block_states.size(0)):
for block, offset in zip(block_states[i].tolist(), block_offset[i].tolist()):
offset_dict[block] = offset
# re-shuffle unused blocks
nw_unused_block: List[int] = []
for i in range(block_states.size(0)):
_, _, num_unused_blocks, _ = block_info[i].tolist()
nw_unused_block.extend(
[
block_id
for block_id in block_states[i, :num_unused_blocks].tolist()
if block_id % self._world_size == self._rank
]
)
for i in range(block_states.size(0)):
curr_block, inblock_offset, num_unused_blocks, _ = block_info[i].tolist()
if curr_block < 0:
continue
if curr_block % self._world_size == self._rank:
nw_unused_block.append(curr_block)
offset_dict[curr_block] = inblock_offset
curr_block, inblock_offset, _, self._repeat_times = block_info[self._rank].tolist()
# if self._shuffle:
# random.shuffle(nw_unused_block)
nw_unused_block_offset = [
offset_dict[block] if block in offset_dict else 0 for block in nw_unused_block
]
self._unused_block = nw_unused_block
self._unused_block_offset = nw_unused_block_offset
else:
curr_block, inblock_offset, num_unused_blocks, self._repeat_times = block_info[self._rank].tolist()
if curr_block == -1:
self._curr_block = None
self._unused_block = []
self._unused_block_offset = []
else:
while True:
try:
self._curr_block = curr_block
f_info = self._get_block_file(self._curr_block)
self._open_file(
f_info.file_name,
(self._curr_block - f_info.block_begin) * f_info.block_size + inblock_offset,
)
self._unused_block = block_states[self._rank, :num_unused_blocks].tolist()
self._unused_block_offset = block_offset[self._rank, :num_unused_blocks].tolist()
break
except Exception:
print("Error: reading blocks in {}".format(f_info.file_name))
time.sleep(10)
# end
self._update_states()
def _get_file_path(self, fname):
return os.path.join(self._path, fname)
def _open_file(self, fname, offset):
if self._curr_fname != fname:
if self._fp is not None:
self._fp.close()
self._curr_fname = None
# self._fp = open(self._get_file_path(fname), "rb")
self._fp = SafeFile(self._get_file_path(fname), "rb")
self._curr_fname = fname
else:
assert self._fp is not None, "Unexpected error"
self._fp.seek(offset, io.SEEK_SET) # move to block
def read(self):
"""Read a piece of data from dataset.
Workers in different ranks will read different data.
"""
if self._curr_block is None:
next_block_id, next_block_offset = self._get_next_block()
f_info = self._get_block_file(next_block_id)
try:
self._open_file(
f_info.file_name, (next_block_id - f_info.block_begin) * f_info.block_size + next_block_offset
)
self._curr_block = next_block_id
except FileNotFoundError:
print("ERR: reading again!")
self._mask_file(f_info)
return self.read() # read again
if self._fp is None:
raise RuntimeError("Dataset is not initialized")
MAGIC = self._fp.read(1)
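# Each record starts with a 1-byte magic: 0x1F means a length-prefixed record follows,
# 0x00 means the rest of the block is padding and we move on to the next block.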
if MAGIC == b"\x1F":
# correct
size = struct.unpack("I", self._fp.read(4))[0]
data = self._fp.read(size)
return self.serializer.deserialize(data)
elif MAGIC == b"\x00":
# end of block
self._curr_block = None
return self.read() # read next block
else:
raise ValueError("Invalid magic header")
@property
def nbytes(self):
return self._nbytes
class SimpleDataset(DistributedDataset):
def __init__(
self,
path: str,
serializer: Optional[Serializer] = None,
shuffle: bool = True,
) -> None:
super().__init__(
path,
0,
1,
serializer=serializer,
max_repeat_times=1,
shuffle=shuffle,
)
def __iter__(self):
while True:
try:
data = self.read()
except EOFError:
self._repeat_times = 0
break
yield data
def __len__(self):
return self._nlines
def get_bytes(self):
return self._nbytes
class DatasetWriter:
def __init__(self, fname: str, block_size: int, serializer: Optional[Serializer] = None):
self._fname = fname
self._block_size = block_size
self._fp = open(self._fname, "wb")
self._inblock_offset = 0
self._nbytes = 0
self._nlines = 0
self._nblocks = 1
if serializer is None:
serializer = PickleSerializer()
self.serializer = serializer
def write(self, data):
"""Write a piece of data into dataset.
Args:
data (Any): Serialization will be done using pickle.
Example:
>>> writer.write( "anything you want" )
"""
byte_data = self.serializer.serialize(data)
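# Length-prefix the serialized record with a 4-byte unsigned int so readers know how many bytes to consume.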
byte_data = struct.pack("I", len(byte_data)) + byte_data
if self._inblock_offset + 2 + len(byte_data) > self._block_size:
self._fp.write(b"\x00" * (self._block_size - self._inblock_offset)) # fill the remaining space with 0
self._inblock_offset = 0
self._nblocks += 1
# we go to the next block
if self._inblock_offset + 2 + len(byte_data) > self._block_size:
raise ValueError("data is larger than block size")
self._nbytes += len(byte_data)
self._nlines += 1
self._inblock_offset += 1 + len(byte_data)
self._fp.write(b"\x1F")
self._fp.write(byte_data)
@property
def nbytes(self):
return self._nbytes
@property
def nblocks(self):
return self._nblocks
@property
def nlines(self):
return self._nlines
def close(self):
if not self._fp.closed:
self._fp.write(b"\x00" * (self._block_size - self._inblock_offset))
self._fp.close()
class DatasetBuilder:
def __init__(
self,
path: str,
dbname: str,
block_size=_DEFAULT_BLOCK_SIZE,
serializer: Optional[Serializer] = None,
) -> None:
self._block_size = block_size
self._path = path
self._dbname = dbname
if serializer is None:
serializer = PickleSerializer()
self.serializer = serializer
if not os.path.exists(self._path):
os.makedirs(self._path)
meta_path = os.path.join(self._path, "meta.bin")
info: List[FileInfo] = []
if os.path.exists(meta_path):
info = _read_info_list(meta_path)
for v in info:
if v.file_name == dbname:
raise ValueError("Dataset name exists")
self._db_path = os.path.join(self._path, self._dbname)
if os.path.exists(self._db_path):
raise ValueError("File exists `%s`" % self._db_path)
def __enter__(self):
self._writer = DatasetWriter(self._db_path, self._block_size, self.serializer)
return self._writer
def __exit__(self, exc_type, exc_value, exc_traceback):
if self._writer is None:
raise RuntimeError("Unexpected call to __exit__")
self._writer.close()
if exc_type is not None:
print("Error while writing file")
if os.path.exists(self._db_path):
os.unlink(self._db_path)
else:
meta_path = os.path.join(self._path, "meta.bin")
info: List[FileInfo] = []
if os.path.exists(meta_path):
info = _read_info_list(meta_path)
last_block = 0
if len(info) > 0:
last_block = info[-1].block_end
info.append(
FileInfo(
self._dbname,
last_block,
last_block + self._writer.nblocks,
self._writer.nbytes,
self._writer.nlines,
False,
self._block_size,
)
)
# atomic write to meta file
_write_info_list(meta_path, info)
self._writer = None
def build_dataset(
path: str,
dbname: str,
block_size: int = _DEFAULT_BLOCK_SIZE,
serializer: Optional[Serializer] = None,
):
"""Open the dataset in write mode and returns a writer.
Args:
path (str): Path to dataset.
dbname (str): The name of the file to which the data will be written. The `dbname` needs to be unique in this `dataset`.
block_size (int): Size of each block in bytes. All files in the same dataset should have the same block size. Default: 16MB
Example:
>>> with build_dataset("/path/to/dataset", "data_part_1") as writer:
>>> for i in range(10):
>>> writer.write( { "anything you want" } )
""" # noqa: E501
return DatasetBuilder(path, dbname, block_size=block_size, serializer=serializer)

View File

@ -1,460 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
#
# @author: ouzebin <ouzebin@zhihu.com>
# @date: 2023/09/27
"""
Before using IndexedDataset, build a dataset in the required format or convert an existing one.
Dataset file structure:
- <dataset name>
- data.jsonl # data in jsonl format, one sample per line
- index # records the starting byte-offset of every json line
To build from scratch, use the IndexedDatasetBuilder context manager directly:
>>> with IndexedDatasetBuilder("swear", overwrite=True) as builder:
>>> for data in [{"input": f"screw it {i}", "output": f"for god's sake {i}"} for i in range(100)]:
>>> builder.put(data)
Conversion:
For an fm9g distributed_dataset, use `fm9g.dataset.tools.distributed_to_indexed`:
$ python -m fm9g.dataset.tools.distributed_to_indexed -i <original dataset dir> -o <new dataset dir>
For existing jsonl data, use `fm9g.dataset.tools.jsonl_to_index` to build the index file (the jsonl file must first be renamed to data.jsonl):
$ python -m fm9g.dataset.tools.jsonl_to_index -p <dataset dir>
"""
import itertools
import math
import os
import queue
import random
import threading
import time
import bmtrain as bmt
import h5py
import numpy
import numpy as np
import torch
try:
import msgspec
json_decode = msgspec.json.decode
json_encode = msgspec.json.encode
except ModuleNotFoundError:
import json
json_decode = json.loads
json_encode = json.dumps
import torch
from torch.utils.data import Dataset
from fm9g.utils.bitset import BitSet
from fm9g.utils.bitset import bitset_diff
print_lock = threading.Lock()
def random_range(start, stop=None, step=None):
"""
Generator of a non-repeating random permutation with the same interface as python
`range`. Obtained from https://stackoverflow.com/a/53551417
random.shuffle(list) and random.sample(list, len(list)) require
materializing the lists, which results in a long initialization period.
"""
if stop is None:
start, stop = 0, start
if step is None:
step = 1
# Use a mapping to convert a standard range into the desired range.
mapping = lambda i: (i * step) + start
# Compute the number of numbers in this range.
maximum = int(math.ceil((stop - start) / step))
if maximum == 0:
# early return with empty range
yield from ()
return
# Seed range with a random integer.
value = random.randint(0, maximum)
# Construct an offset, multiplier, and modulus for a linear
# congruential generator. These generators are cyclic and
# non-repeating when they maintain the properties:
#
# 1) "modulus" and "offset" are relatively prime.
# 2) ["multiplier" - 1] is divisible by all prime factors of "modulus".
# 3) ["multiplier" - 1] is divisible by 4 if "modulus" is divisible by 4.
# Pick a random odd-valued offset.
offset = random.randint(0, maximum) * 2 + 1
# Pick a multiplier 1 greater than a multiple of 4.
multiplier = 4 * (maximum // 4) + 1
# Pick a modulus just big enough to generate all numbers (power of 2).
modulus = int(2 ** math.ceil(math.log2(maximum)))
# Track how many random numbers have been returned.
found = 0
while found < maximum:
# If this is a valid value, yield it in generator fashion.
if value < maximum:
found += 1
yield mapping(value)
# Calculate the next value in the sequence.
value = (value * multiplier + offset) % modulus
class Range(object):
def __init__(self, start, stop, step):
self.start = start
self.stop = stop
self.step = step
def __repr__(self):
return f"Range({self.start}, {self.stop}, {self.step})"
def iterate(self):
yield from range(self.start, self.stop, self.step)
def list(self):
return list(range(self.start, self.stop, self.step))
def subrange(self, split, nsplits):
# strided spliting range params
# e.g., [0, 3, 5, 7, 9] can be split into [0, 5, 9] and [3, 7]
return Range(self.start + self.step * split, self.stop, self.step * nsplits)
def random_iterate(self):
yield from random_range(self.start, self.stop, self.step)
def safe_print(*args, **kargs):
if "flush" in kargs:
flush = kargs["flush"]
del kargs["flush"]
else:
flush = True
with print_lock:
print(*args, **kargs, flush=flush)
def concurrent_info():
# world_size, rank = bmt.world_size(), bmt.rank()
world_size = bmt.config["world_size"] // bmt.config["tp_size"]
rank = bmt.config["topology"].tp_idx
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
nworkers, worker_id = 1, 0  # single in-process worker; worker ids are 0-based
else:
nworkers, worker_id = worker_info.num_workers, worker_info.id
# print("concurrent_info: (world_size, rank, nworkers, worker_id): {}".format((world_size, rank, nworkers, worker_id)))
return world_size, rank, nworkers, worker_id
class IndexedDataset(Dataset):
def __init__(self, path, max_retry=1, retry_sleep=5):
super().__init__()
self.path = path
self.max_retry = max_retry
self.retry_sleep = retry_sleep
self.bounds = None
self.h5file = None
self.build_index()
def size(self):
return self.bounds[-1]
def _build_index_h5(self):
index_path = os.path.join(self.path, "index.h5")
if os.path.getsize(index_path) > 104857600:
self.h5file = h5py.File(os.path.join(self.path, "index.h5"), "r")
self.bounds = self.h5file["index"]
else:
# only load index into memory when it is small (< 100 Mb)
# to avoid keeping too many file handlers
self.h5file = None
with h5py.File(index_path, "r") as hf:
self.bounds = np.array(hf["index"])
def __del__(self):
if self.h5file is not None:
self.h5file.close()
def build_index(self):
s = time.time()
txt_size = os.path.getsize(os.path.join(self.path, "index"))
if txt_size > 0.5 * 1024**3 and os.path.exists(os.path.join(self.path, "index.h5")):
source = "h5"
self._build_index_h5()
else:
source = "txt"
self._build_index_txt()
e = time.time()
bmt.print_rank("build_index_{} from {}, using {:.2f}s".format(source, self.path, e - s))
def _build_index_txt(self):
with open(os.path.join(self.path, "index"), "r") as fin:
self.bounds = [int(line) for line in fin]
self.nlines = len(self.bounds)
def safe_read(self, i_or_s, offset, size):
for retry in itertools.count():
try:
# destroy the file identifier to avoid pressure on alluxio
# buffering=0 to avoid overhead during file.seek() and open()
with open(os.path.join(self.path, "data.jsonl"), "rb", buffering=0) as fin:
fin.seek(offset)
raw = fin.read(size)
return raw
except OSError as e:
if retry >= self.max_retry:
raise OSError(f"reach maximum #retry: {retry}, the file system is broken.")
safe_print(
f"retry loading {self.path}:{i_or_s} in {self.retry_sleep} seconds due to error: '{repr(e)}'"
)
time.sleep(self.retry_sleep)
except ValueError as e:
# reading error during python io, skip
safe_print(f"skipping {self.path}:{i_or_s} due to error: '{repr(e)}'")
return None
def __repr__(self):
return (
f"IndexedDataset(path={self.path}, max_retry={self.max_retry}, retry_sleep={self.retry_sleep}) "
f"with {len(self)} entries."
)
def __len__(self):
return len(self.bounds) - 1
def bound_idx(self, key, strict=False):
# bound index within the standard range: [0, len(self))
# useful for tracing buggy entries
if strict and not (-len(self) <= key < len(self)):
raise IndexError(f"Index {key} out of range for '{self.path}'")
key = min(max(-len(self), key), len(self)) # bound key within [-len(self), len(self)]
key = key if key > 0 else key % len(self) # remap negative id to positive ones
return key
def __getitem__(self, key):
# supports list-like slicing and indexing. strided slicing is not currently supported.
# ok: self[1], self[-1], self[1:3], self[-10:-5], self[-10:-5:1], self[:5]
# not ok: self[-10:-5:2], self[:100:3]
if isinstance(key, slice):
if not (key.step == 1 or key.step is None):
raise ValueError(f"slice step should be 1 or None, not {key.step}")
start = self.bound_idx(0 if key.start is None else key.start)
stop = max(self.bound_idx(len(self) if key.stop is None else key.stop), start)
if stop == start:
# early returning empty slice
return list()
offset, size = self.bounds[start], self.bounds[stop] - self.bounds[start]
raw = self.safe_read(key, offset, size)
if raw is None:
return None
else:
return [
raw[s - offset : e - offset]
for s, e in zip(self.bounds[start:stop], self.bounds[start + 1 : stop + 1])
]
elif isinstance(key, int):
key = self.bound_idx(key, strict=True)
offset, size = self.bounds[key], self.bounds[key + 1] - self.bounds[key]
raw = self.safe_read(key, offset, size)
return raw
else:
raise TypeError(f"indices must be integers or slices, not {type(key)}")
class PrefetchDecodeDataset(IndexedDataset):
# Add prefetched sampled iterator and state_dict tracking upon the simple IndexedDataset
# Add safe decoding in iterator
def __init__(self, *args, decode=json_decode, allow_repeat=False, **kargs):
super().__init__(*args, **kargs)
self.decode = decode
self.allow_repeat = allow_repeat
def safe_decode(self, i, raw):
if raw is None:
return None
try:
return self.decode(raw)
except Exception as e:
safe_print(f"Skip decoding {self.path}:{i} due to error '{e}', raw bytes:\n{raw}")
return None
def __getitem__(self, key):
raw = super().__getitem__(key)
if raw is None:
return None
# key should be either a slice or an integer as checked in IndexedDataset
if isinstance(key, slice):
return [self.safe_decode(i, r) for i, r in zip(range(key.start, key.stop), raw)]
else:
return self.safe_decode(key, raw)
def loader(self, q, lid, keys, stop, used=None):
# concurrent prefetching worker
if used is None:
used = BitSet()
try:
for key in keys:
if stop.is_set():
break
# key is either a slice or an integer index
index = range(key.start, key.stop) if isinstance(key, slice) else [key]
unused = bitset_diff(set(index), used)
if not unused:
# skip used slice / item
continue
if not q.empty():
# avoid breaking the distributed file system with large io load
time.sleep(random.random() * 2)
# read raw data with IndexedDataset.__getitem__, suspend decoding until we really need it
raw = super().__getitem__(key)
if raw is None:
continue
# filter used data
items = [(i, s) for i, s in zip(index, raw if len(index) > 1 else [raw]) if i in unused]
random.shuffle(items)
for item in items:
q.put(item)
finally:
# signaling the end of iteration to the main thread
q.put(StopIteration(lid))
def _iterate(self, key_groups, nprefetch=1000, used=None):
# helper function for concurrent prefetching
q = queue.Queue(maxsize=nprefetch)
stop = threading.Event()
alive = set()
try:
for lid, keys in enumerate(key_groups):
loader = threading.Thread(target=self.loader, args=(q, lid, keys, stop, used), daemon=True)
loader.start()
alive.add(lid)
while True:
try:
item = q.get(block=False)
except queue.Empty:
if not alive:
# no alive loader, thus no item will be put in the queue
break
else:
# new item will be put later, wait for a while
time.sleep(0.1)
continue
if isinstance(item, StopIteration):
alive.remove(item.value)
continue
i, raw = item
data = self.safe_decode(i, raw)
if data is None:
continue
yield i, data
finally:
# ask daemon loaders to stop
stop.set()
def iterate(self, nthreads=3, prefetch_sample=100, used=None, process_group=None):
# NOTE: process_group is currently unused; concurrent_info() takes no arguments and reads rank/world info from the bmtrain config
world_size, rank, nworkers, worker_id = concurrent_info()
nloaders = world_size * nworkers * nthreads
if len(self) < nloaders:
raise ValueError(
f"more concurrent loaders ({nloaders}) than data entries ({len(self)}) in '{self.path}', "
f"please constrain either "
f"world_size={world_size}, num_workers={nworkers} or num_threads={nthreads}."
)
r = Range(0, len(self), 1)
# split index among multi-gpu workers
r = r.subrange(split=rank, nsplits=world_size)
# split index among multi-process dataloader workers
r = r.subrange(split=worker_id, nsplits=nworkers)
# split index among multi-threaded loaders
id_groups = [r.subrange(split=tid, nsplits=nthreads).random_iterate() for tid in range(nthreads)]
return self._iterate(id_groups, nprefetch=prefetch_sample, used=used)
def sliced_iterate(self, nthreads=1, prefetch_slice=3, slice_size=500, used=None):
world_size, rank, nworkers, worker_id = concurrent_info()
nloaders = world_size * nworkers * nthreads
if len(self) < nloaders:
if not self.allow_repeat:
raise ValueError(
f"more concurrent loaders ({nloaders}) than data entries ({len(self)}) in '{self.path}', "
f"please constrain either "
f"world_size={world_size}, num_workers={nworkers} or num_threads={nthreads}."
)
else:
duplicated_factor = math.ceil(nloaders / len(self))
# In this case, slice size is 1
r = Range(0, len(self), 1)
# split index among grouped multi-gpu workers
r = r.subrange(split=rank // duplicated_factor, nsplits=math.ceil(world_size / duplicated_factor))
# # split index among multi-threaded loaders
r = r.subrange(split=worker_id, nsplits=nworkers)
else:
nslices = int(math.ceil(len(self) / slice_size))
if nslices < nloaders:
safe_print(
f"fail to distribute {nslices} slices from '{self.path}' to {nloaders} concurrent loaders, "
f"reduce slice_size from {slice_size} to {len(self) // nloaders}."
)
slice_size = len(self) // nloaders
# we only iterate through start ids as they uniquely mark each slice
r = Range(0, len(self), slice_size)
# split index among multi-gpu workers
r = r.subrange(split=rank, nsplits=world_size)
# split index among multi-process dataloader workers
r = r.subrange(split=worker_id, nsplits=nworkers)
# split index among multi-threaded loaders
slice_groups = [
(slice(s, s + slice_size) for s in r.subrange(tid, nthreads).random_iterate()) for tid in range(nthreads)
]
return self._iterate(slice_groups, nprefetch=prefetch_slice * slice_size, used=used)
class IndexedDatasetBuilder:
def __init__(self, path, overwrite=False):
self.path = path
self.index_path = os.path.join(self.path, "index.h5")
self.index_path_txt = os.path.join(self.path, "index")
self.data_path = os.path.join(self.path, "data.jsonl")
if not overwrite:
assert not os.path.exists(self.data_path)
assert not os.path.exists(self.index_path)
assert not os.path.exists(self.index_path_txt)
self.fout = None
self.bounds = []
self.offset = 0
def __enter__(self):
os.makedirs(self.path, exist_ok=True)
self.fout = open(self.data_path, "wb")
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.bounds.append(self.offset)
with h5py.File(os.path.join(self.index_path), "w") as hf:
hf.create_dataset("index", data=self.bounds)
with open(self.index_path_txt, "w") as fout_txt:
for s in self.bounds:
fout_txt.write(f"{s}\n")
self.fout.close()
def put(self, data: dict):
s = json_encode(data) + b"\n"
self.bounds.append(self.offset)
self.offset += len(s)
self.fout.write(s)
if __name__ == "__main__":
with IndexedDatasetBuilder("swear", overwrite=True) as builder:
for d in [{"input": f"screw it {i}", "output": f"for god's sake {i}"} for i in range(100)]:
builder.put(d)
dataset = IndexedDataset("swear")
for i in range(10):
print(dataset[random.randint(0, len(dataset) - 1)])

View File

@ -1,61 +0,0 @@
# coding=utf-8
# Copyright 2020 The OpenBMB team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import pickle
class Serializer:
def __init__(self) -> None:
pass
def serialize(self, obj) -> bytes:
raise NotImplementedError()
def deserialize(self, data: bytes):
raise NotImplementedError()
class PickleSerializer(Serializer):
def __init__(self) -> None:
pass
def serialize(self, obj) -> bytes:
return pickle.dumps(obj)
def deserialize(self, data: bytes):
return pickle.loads(data)
class JsonSerializer(Serializer):
def __init__(self) -> None:
pass
def serialize(self, obj) -> bytes:
return json.dumps(obj, ensure_ascii=False).encode("utf-8")
def deserialize(self, data: bytes):
return json.loads(data.decode("utf-8"))
class RawSerializer(Serializer):
def __init__(self) -> None:
pass
def serialize(self, obj) -> bytes:
return obj
def deserialize(self, data: bytes):
return data
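# Illustrative usage sketch (hypothetical example): JsonSerializer round-trips a dict
# through UTF-8 JSON bytes, while RawSerializer passes bytes through unchanged.
#   s = JsonSerializer()
#   payload = s.serialize({"input": "hi", "output": "there"})
#   assert s.deserialize(payload) == {"input": "hi", "output": "there"}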

View File

@ -1,7 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
#
# @author: ouzebin <ouzebin@zhihu.com>
# @date: 2023/08/07

View File

@ -1,31 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
#
# @author: ouzebin <ouzebin@zhihu.com>
# @date: 2023/07/27
import argparse
import torch
from tqdm import tqdm
from fm9g.dataset import SimpleDataset
from fm9g.dataset.indexed_dataset import IndexedDatasetBuilder
def convert_fm9g_data(fm9g_path, out_path):
dataset = SimpleDataset(fm9g_path, shuffle=False)
with IndexedDatasetBuilder(out_path, overwrite=True) as builder:
for _ in tqdm(range(dataset._nlines), total=dataset._nlines):
builder.put(dataset.read())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input", "-i", required=True, help="Data path in fm9g format.")
parser.add_argument("--output", "-o", required=True, help="Output data path in indexed jsonline format.")
args = parser.parse_args()
convert_fm9g_data(args.input, args.output)

View File

@ -1,31 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
#
# @author: ouzebin <ouzebin@zhihu.com>
# @date: 2023/08/07
import argparse
import os
def build_index(path):
data_path = os.path.join(path, "data.jsonl")
assert os.path.exists(data_path), f"Jsonline dataset '{data_path}' not found."
offset = 0
starts = [offset]
with open(data_path, "rb") as fin:
for line in fin:
offset += len(line)
starts.append(offset)
with open(os.path.join(path, "index"), "w") as fout:
for s in starts:
fout.write(f"{s}\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--path", "-p", required=True, help="Data path.")
args = parser.parse_args()
build_index(args.path)

View File

@ -1,436 +0,0 @@
# coding=utf-8
# Copyright 2020 The OpenBMB team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
import os
import random
import shutil
import struct
from queue import Queue
from threading import Thread
from typing import Iterable
from typing import List
from typing import Optional
import torch
from ..utils.log import logger
from .distributed_dataset import _DEFAULT_BLOCK_SIZE
from .distributed_dataset import _random_string
from .distributed_dataset import _read_info_list
from .distributed_dataset import _write_info_list
from .distributed_dataset import build_dataset
from .distributed_dataset import FileInfo
from .distributed_dataset import SimpleDataset
from .serializer import RawSerializer
try:
from tqdm import tqdm
support_tqdm = True
except ModuleNotFoundError:
support_tqdm = False
_DEFAULT_SHUFFLE_BUCKET_SIZE = 1 << 30
def shuffle_dataset(
path_src: str,
path_tgt: str,
block_size: int = _DEFAULT_BLOCK_SIZE,
bucket_size: int = _DEFAULT_SHUFFLE_BUCKET_SIZE,
progress_bar: bool = False,
output_name: Optional[str] = None,
):
"""Shuffle one distributed datataset, write results to another dataset.
Args:
path_src (str): path to source dataset
path_tgt (str): path to write results
block_size (int): dataset block size (default: 16MB)
bucket_size (int): shuffle algorithm bucket size (default: 1GB)
progress_bar (bool): show progress bar
Example:
>>> shuffle_dataset("/path/to/source", "/path/to/output")
"""
if progress_bar and not support_tqdm:
raise RuntimeError("Requires `tqdm` to enable progress bar.")
ds = SimpleDataset(path_src, serializer=RawSerializer())
num_buckets = (ds.nbytes + bucket_size - 1) // bucket_size
tmp_files = [os.path.join(path_src, ".tmp.%s" % _random_string()) for _ in range(num_buckets)]
try:
# Step 1: write to bucket randomly
f_tmp = [open(fname, "wb") for fname in tmp_files]
try:
iterator = ds
if progress_bar:
iterator = tqdm(ds, desc="Shuffle step 1/2")
for data in iterator:
bucket_id = int(random.random() * num_buckets)
len_data = len(data)
f_tmp[bucket_id].write(struct.pack("I", len_data) + data)
finally:
# close all files
for fp in f_tmp:
if not fp.closed:
fp.close()
f_tmp = []
# Step 2: shuffle inside bucket
if output_name is None:
output_name = "%s.shuffle" % _random_string()
with build_dataset(
path_tgt,
output_name,
block_size=block_size,
serializer=RawSerializer(),
) as writer:
iterator = tmp_files
if progress_bar:
iterator = tqdm(tmp_files, desc="Shuffle step 2/2")
for fname in iterator:
fp = open(fname, "rb")
data_in_bucket = []
while True:
try:
raw_data = fp.read(4)
if len(raw_data) == 0:
# EOF
break
len_data = struct.unpack("I", raw_data)[0]
data_in_bucket.append(fp.read(len_data))
except EOFError:
break
random.shuffle(data_in_bucket)
for data in data_in_bucket:
writer.write(data)
fp.close()
os.unlink(fname)
finally:
# cleanup
for fname in tmp_files:
if os.path.exists(fname):
os.unlink(fname)
def compact_dataset(path: str):
"""Compact the dataset, removes blocks which the files were deleted.
**Note** This may affect the existing dataset state dict.
Args:
path (str): path to dataset
Example:
>>> compact_dataset("/path/to/dataset")
"""
meta_path = os.path.join(path, "meta.bin")
info: List[FileInfo] = []
if os.path.exists(meta_path):
info = _read_info_list(meta_path)
else:
raise ValueError("Dataset not exists")
nw_info: List[FileInfo] = []
curr_block = 0
for v in info:
if not os.path.exists(v.file_name):
# file is deleted
pass
else:
num_file_block = v.block_end - v.block_begin
nw_info.append(
FileInfo(
v.file_name,
curr_block,
curr_block + num_file_block,
v.nbytes,
v.nlines,
v.mask,
v.block_size,
)
)
curr_block += num_file_block
_write_info_list(meta_path, nw_info)
def mask_dataset(path: str, dbname: str, mask: bool = True):
"""Mask one file in dataset. Blocks in masked datasets won't be read later.
Args:
path (str): path to dataset
dbname (str): file name in this dataset which you want to mask
mask (bool): True for mask, False for unmask
Example:
>>> mask_dataset("/path/to/dataset", "data_part_1", mask=True)
"""
meta_path = os.path.join(path, "meta.bin")
info: List[FileInfo] = []
if os.path.exists(meta_path):
info = _read_info_list(meta_path)
else:
raise ValueError("Dataset not exists")
for v in info:
if v.file_name == dbname:
v.mask = mask
_write_info_list(meta_path, info)
def merge_dataset(dst: str, src: str):
meta_path_src = os.path.join(src, "meta.bin")
meta_path_dst = os.path.join(dst, "meta.bin")
info_src: List[FileInfo] = []
if os.path.exists(meta_path_src):
info_src = _read_info_list(meta_path_src)
else:
raise ValueError("Dataset not exists")
info_dst: List[FileInfo] = []
if os.path.exists(meta_path_dst):
info_dst = _read_info_list(meta_path_dst)
else:
raise ValueError("Dataset not exists")
curr_block = 0
nw_info: List[FileInfo] = []
for v in info_dst:
num_file_block = v.block_end - v.block_begin
nw_info.append(
FileInfo(
v.file_name,
curr_block,
curr_block + num_file_block,
v.nbytes,
v.nlines,
v.mask,
v.block_size,
)
)
curr_block += num_file_block
for v in info_src:
num_file_block = v.block_end - v.block_begin
dst_db_name = os.path.join(dst, v.file_name)
nw_fname = v.file_name
if os.path.exists(dst_db_name):
idx = 0
while os.path.exists(dst_db_name + "_{}".format(idx)):
idx += 1
dst_db_name = dst_db_name + "_{}".format(idx)
nw_fname = nw_fname + "_{}".format(idx)
shutil.copy(os.path.join(src, v.file_name), dst_db_name)
nw_info.append(
FileInfo(
nw_fname,
curr_block,
curr_block + num_file_block,
v.nbytes,
v.nlines,
v.mask,
v.block_size,
)
)
curr_block += num_file_block
_write_info_list(meta_path_dst, nw_info)
def to_fm9g(src_data, dst_path, dst_name):
if not os.path.exists(dst_path):
os.makedirs(dst_path)
logger.info(f"src_data: {src_data}")
logger.info(f"dst_path: {dst_path}")
logger.info(f"dst_name: {dst_name}")
tmp_dst_path = dst_path.rstrip("/") + "_tmp"
if not os.path.exists(tmp_dst_path):
os.makedirs(tmp_dst_path)
logger.info(f"write binary into: {tmp_dst_path}")
with build_dataset(tmp_dst_path, dst_name) as dataset:
if os.path.isdir(src_data):
filenames = [os.path.join(src_data, name) for name in os.listdir(src_data)]
else:
filenames = [src_data]
n_filenames = len(filenames)
for idx, filename in enumerate(filenames):
logger.info(f"deal: [{n_filenames} -> {idx}] {filename}")
if not os.path.exists(filename):
logger.error(f"not exist: {filename}")
continue
with open(filename, "r", encoding="utf-8") as fin:
for line in fin:
line = line.strip()
dataset.write(json.loads(line))
logger.info(f"shuffle binary data from {tmp_dst_path} to {dst_path}")
shuffle_dataset(tmp_dst_path, dst_path, progress_bar=True, output_name=dst_name)
if os.path.exists(tmp_dst_path):
shutil.rmtree(tmp_dst_path)
def random_range(start, stop=None, step=None):
"""
Generator of a non-repeating random permutation with the same interface as Python's
`range`. Obtained from https://stackoverflow.com/a/53551417
random.shuffle(list) and random.sample(list, len(list)) require
materializing the lists, which results in a long initialization period.
"""
if stop is None:
start, stop = 0, start
if step is None:
step = 1
# Use a mapping to convert a standard range into the desired range.
mapping = lambda i: (i * step) + start
# Compute the number of numbers in this range.
maximum = int(math.ceil((stop - start) / step))
if maximum == 0:
# early return with empty range
yield from ()
return
# Seed range with a random integer.
value = random.randint(0, maximum)
# Construct an offset, multiplier, and modulus for a linear
# congruential generator. These generators are cyclic and
# non-repeating when they maintain the properties:
#
# 1) "modulus" and "offset" are relatively prime.
# 2) ["multiplier" - 1] is divisible by all prime factors of "modulus".
# 3) ["multiplier" - 1] is divisible by 4 if "modulus" is divisible by 4.
# Pick a random odd-valued offset.
offset = random.randint(0, maximum) * 2 + 1
# Pick a multiplier 1 greater than a multiple of 4.
multiplier = 4 * (maximum // 4) + 1
# Pick a modulus just big enough to generate all numbers (power of 2).
modulus = int(2 ** math.ceil(math.log2(maximum)))
# Track how many random numbers have been returned.
found = 0
while found < maximum:
# If this is a valid value, yield it in generator fashion.
if value < maximum:
found += 1
yield mapping(value)
# Calculate the next value in the sequence.
value = (value * multiplier + offset) % modulus
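# Illustrative sketch (hypothetical values): random_range yields each index exactly once,
# in pseudo-random order, without materializing the full permutation in memory.
#   perm = list(random_range(10))
#   assert sorted(perm) == list(range(10))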
class Range(object):
def __init__(self, start, stop, step):
self.start = start
self.stop = stop
self.step = step
def __repr__(self):
return f"Range({self.start}, {self.stop}, {self.step})"
def iterate(self):
yield from range(self.start, self.stop, self.step)
def list(self):
return list(range(self.start, self.stop, self.step))
def subrange(self, split, nsplits):
# strided splitting of range params
# e.g., [0, 3, 5, 7, 9] can be split into [0, 5, 9] and [3, 7]
return Range(self.start + self.step * split, self.stop, self.step * nsplits)
def random_iterate(self):
yield from random_range(self.start, self.stop, self.step)
class CudaPrefetcher(Iterable):
"""
Wrap around a batch iterator to asynchronously copy data to the GPU, hiding memcpy latency.
"""
def __init__(self, loader):
self.loader = iter(loader)
self.stream = torch.cuda.Stream()
self.preload()
def preload(self):
try:
self.data = next(self.loader)
except StopIteration:
self.data = None
return
with torch.cuda.stream(self.stream):
for key in self.data.keys():
if isinstance(self.data[key], torch.Tensor):
self.data[key] = self.data[key].cuda(non_blocking=True)
def __next__(self):
torch.cuda.current_stream().wait_stream(self.stream)
data = self.data
self.preload()
return data
def __iter__(self):
return self
class ThreadedPrefetcher(Thread):
def __init__(self, iterable, prefetch=10):
"""
Wrap around a data iterator to hide I/O latency with a daemon thread.
"""
super(ThreadedPrefetcher, self).__init__()
self.queue = Queue(maxsize=prefetch)
self.iterable = iterable
self.daemon = True
self.start()
def run(self):
try:
for data in self.iterable:
self.queue.put(data)
except Exception as exception:
self.queue.put(exception)
finally:
self.queue.put(StopIteration())
def __next__(self):
item = self.queue.get()
if isinstance(item, Exception):
raise item
else:
return item
def __iter__(self):
return self
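# Illustrative composition sketch (names below are assumptions, not defined in this file):
# hide disk I/O behind the daemon thread, then hide host-to-device copies behind a CUDA stream.
#   batches = ThreadedPrefetcher(my_batch_iterator, prefetch=10)
#   batches = CudaPrefetcher(batches)
#   for batch in batches:
#       ...  # tensors in `batch` are already on the GPU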

View File

@ -1,8 +0,0 @@
{
"folders": [
{
"path": "../.."
}
],
"settings": {}
}

View File

@ -1,105 +0,0 @@
import torch
from transformers.configuration_utils import PretrainedConfig
class DragonflyConfig(PretrainedConfig):
model_type = "fm9g_dragonfly"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {
"num_key_value_heads": "num_kv_heads",
"hidden_act": "activate_fn",
"hidden_size": "dim_model",
"num_attention_heads": "num_heads",
"intermediate_size": "dim_ff",
"num_hidden_layers": "num_layers",
"vocab_size": "vocab_size",
"rms_norm_eps": "eps",
"scale_emb": "scale_emb",
"scale_depth": "scale_depth",
"scale": "scale",
"attention_scale": "attention_scale",
"qk_norm": "qk_norm",
"ffn_gated": "ffn_gated",
} # model specific to common
def __init__(
self,
vocab_size=122753, # TODO: do we need to change to 122880 = 960 * 128?
dim_model=4096,
num_heads=32,
num_kv_heads=32,
dim_head=128,
dim_ff=11008,
num_layers=32,
dropout_p=0.0,
activate_fn="silu",
scale=False,
scale_emb: float = 1.0,
scale_depth: float = -1,
dim_model_base: int = 256,
eps=1e-5,
init_std=0.02,
dtype="bf16",
base=10000,
qk_norm=False,
tie_lm_head=False,
max_length=8192,
pose_prob=0.0,
pose_scaling_factor=1,
rope_scaling_type="",
rope_scaling_factor=1,
orig_max_length=8192,
tp=0,
use_checkpoint=True,
**kwargs,
):
self.vocab_size = vocab_size
self.dim_model = dim_model
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.dim_head = dim_head
self.dim_ff = dim_ff
self.num_layers = num_layers
self.dropout_p = dropout_p
self.activate_fn = activate_fn
self.scale = scale
self.scale_emb = scale_emb
self._dtype = dtype
self.dim_model_base = dim_model_base
self.scale_depth = scale_depth
self.eps = eps
self.init_std = init_std
self.base = base
self.qk_norm = qk_norm
self.tie_lm_head = tie_lm_head
self.use_bfloat16 = True if self._dtype == "bf16" else False
self.pose_prob = pose_prob
self.pose_scaling_factor = pose_scaling_factor
self.rope_scaling_type = rope_scaling_type
self.rope_scaling_factor = rope_scaling_factor
self.max_length = max_length
self.orig_max_length = orig_max_length
self.use_checkpoint = use_checkpoint
print("use_checkpoint", self.use_checkpoint)
self.tp = tp
super().__init__(architectures=["fm9gDragonflyForCausalLM"])
@property
def scale_width(
self,
):
if self.scale:
return self.dim_model / self.dim_model_base
else:
return 1.0
@property
def dtype(
self,
): # -> Any | None:
if self._dtype == "bf16":
return torch.bfloat16
elif self._dtype == "fp16":
return torch.half
elif self._dtype == "float32":
return torch.float
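# Illustrative sketch (hypothetical values): attribute_map lets HuggingFace-style names
# resolve to the model-specific fields defined above.
#   cfg = DragonflyConfig(dim_model=4096, dim_model_base=256, scale=True, dtype="bf16")
#   cfg.hidden_size   # -> 4096, resolved through attribute_map
#   cfg.scale_width   # -> 16.0, i.e. dim_model / dim_model_base
#   cfg.dtype         # -> torch.bfloat16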

File diff suppressed because it is too large

View File

@ -1 +0,0 @@
from .pretrain_indexed import MixedIndexedDataset

View File

@ -1,74 +0,0 @@
import logging
from multiprocessing import Lock
from flask import Flask
from flask import jsonify
from flask import request
app = Flask(__name__)
# Get the Werkzeug logger and set its log level
log = logging.getLogger("werkzeug")
log.setLevel(logging.WARNING)
class GlobalAvgTokensStat(object):
def __init__(self, decay_factor: float = 0.98):
self._avg_tokens = {}
self.decay_factor = decay_factor
self.lock = Lock()
self.task_locks = {}
def set_avg_tokens(self, task_name, avg_tokens):
self._register_task_lock_helper(task_name)
with self.task_locks[task_name]:
self._avg_tokens[task_name] = avg_tokens
def update_avg_tokens_by_ema(self, task_name, length):
self._register_task_lock_helper(task_name)
with self.task_locks[task_name]:
if task_name in self._avg_tokens and self._avg_tokens[task_name] > 0:
self._avg_tokens[task_name] = self._avg_tokens[task_name] * self.decay_factor + length * (
1 - self.decay_factor
)
else:
self._avg_tokens[task_name] = length
def get_avg_tokens(self, task_name):
self._register_task_lock_helper(task_name)
with self.task_locks[task_name]:
return self._avg_tokens.get(task_name, -1)
def _register_task_lock_helper(self, task_name):
if task_name not in self.task_locks:
with self.lock:
if task_name not in self.task_locks:
self.task_locks[task_name] = Lock()
global_avg_tokens_stat = GlobalAvgTokensStat()
@app.route("/avg_tokens/<path:task_name>", methods=["GET"])
def get_avg_tokens(task_name):
global global_avg_tokens_stat
avg_tokens = global_avg_tokens_stat.get_avg_tokens(task_name)
return jsonify({"avg_tokens": avg_tokens})
@app.route("/avg_tokens/<path:task_name>", methods=["POST"])
def set_avg_tokens(task_name):
global global_avg_tokens_stat
action = request.args.get("action", "update", type=str)
length = request.args.get("length", -1, type=int)
if action == "set":
global_avg_tokens_stat.set_avg_tokens(task_name, length)
elif action == "update":
global_avg_tokens_stat.update_avg_tokens_by_ema(task_name, length)
else:
raise ValueError(f"Unknown action: {action}")
return jsonify({"status": "ok"})
if __name__ == "__main__":
app.run(port=5000, debug=True)
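# Illustrative client sketch (assumes the server above is running on localhost:5000;
# `requests` is not used by this module):
#   import requests
#   requests.post("http://localhost:5000/avg_tokens/my_task", params={"action": "update", "length": 327})
#   requests.get("http://localhost:5000/avg_tokens/my_task").json()  # -> {"avg_tokens": 327}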

View File

@ -1,828 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
#
# @author: ouzebin <ouzebin@zhihu.com>
# @date: 2023/09/27
import copy
import ctypes
import functools
import importlib
import json
import logging
import os
import random
from collections import defaultdict
from collections import OrderedDict
from multiprocessing import Lock
from multiprocessing import Process
from multiprocessing.shared_memory import SharedMemory
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Union
import bmtrain as bmt
import numpy as np
import torch
from numpy.typing import NDArray
from fm9g.dataset import PrefetchDecodeDataset
from fm9g.utils.bitset import BitSet
from fm9g.utils.vdc_sampling import van_der_corput
from fm9g.utils.vdc_sampling import van_der_corput_sampling_gen
logger = logging.getLogger(__name__)
IGNORE_TGT = -100
def load_dataset_cfgs(cfg_path, cfg_json_str=None):
if cfg_json_str is not None:
cfgs = json.loads(cfg_json_str)
else:
with open(cfg_path, "r", encoding="utf-8") as fin:
cfgs = json.load(fin)
transform_basedir = os.path.dirname(os.path.abspath(cfg_path))
path_dict = None
platform_config_path = os.getenv("PLATFORM_CONFIG_PATH")
try:
with open(platform_config_path, "r") as f:
platform_cfg = json.load(f)
path_dict = platform_cfg["dataset_map"]
if bmt.rank() == 0:
logger.info(f"Loaded jeeves platform config from '{platform_config_path}', update dataset paths...")
except Exception as e:
if bmt.rank() == 0:
logger.info(f"Failing to load jeeves platform config '{platform_config_path}', error message:\n{str(e)}")
task_name2dataset_name = dict()
for idx, cfg in enumerate(cfgs):
assert "dataset_name" in cfg and isinstance(cfg["dataset_name"], str)
assert "task_name" in cfg and isinstance(cfg["task_name"], str)
# to be deliberately annoying :)
if cfg["task_name"] in task_name2dataset_name:
raise ValueError(
f"task_name '{cfg['task_name']}' in dataset '{cfg['dataset_name']}'"
f"has already been used in '{task_name2dataset_name[cfg['task_name']]}'."
)
task_name2dataset_name[cfg["task_name"]] = cfg["dataset_name"]
assert "path" in cfg and isinstance(cfg["path"], str)
# if path_dict is not None:
# cfg["path"] = os.path.join(path_dict[cfg["dataset_name"]], cfg["path"])
# dealing with optional configs
if "weight" in cfg:
assert isinstance(cfg["weight"], (float, int))
else:
cfg["weight"] = 1.0
if "oversize_rule" in cfg:
assert cfg["oversize_rule"] in ("drop", "head", "segment")
else:
cfg["oversize_rule"] = "segment"
if "transforms" in cfg:
assert isinstance(cfg["transforms"], str)
# dealing with relative path
if not cfg["transforms"].startswith("/"):
cfg["transforms"] = os.path.join(transform_basedir, cfg["transforms"])
if not cfg["transforms"]:
cfg["transforms"] = None
else:
cfg["transforms"] = None
if "incontext_weight" in cfg:
assert isinstance(cfg["incontext_weight"], (list, tuple))
else:
cfg["incontext_weight"] = [1.0]
cfg["id"] = idx
# dataset and iterator will be built
return cfgs
def data2ids(data, tokenizer, max_length):
text = "\n".join(
[
data.get("title", "").strip(),
data.get("question", "").strip(),
data.get("answer", "").strip(),
data.get("abstract", "").strip(),
data.get("text", "").strip(),
data.get("code", "").strip(),
]
).strip()
if not text:
logger.warning(f"Warning: skip invalid sample without valid fields: {data}")
yield from ()
return
# suppress the annoying warning from tokenizer
ids = (
[tokenizer.bos_token_id]
+ tokenizer.encode(text, max_length=int(1e12), truncation=True)
+ [tokenizer.eos_token_id]
)
src_ids = ids[0:-1]
tgt_ids = ids[0:-1] # do not shift because it'll be shifted during loss calculation.
if len(src_ids) > max_length:
for st in range(0, len(src_ids), max_length):
yield src_ids[st : st + max_length], tgt_ids[st : st + max_length]
else:
yield src_ids, tgt_ids
def cricket_data2ids(data, tokenizer, max_length: int, oversize_rule="segment", do_compact=False):
assert oversize_rule in ("drop", "head", "segment")
if data is None:
yield from ()
return
if "output" not in data or not data["output"]:
yield from ()
return
if "input" not in data or data["input"] is None:
data["input"] = ""
src_ids = [tokenizer.bos_token_id]
tgt_ids = []
has_input = False
is_segment_reenter = False
# Use incremental tokenization to avoid waiting for a long document
MAX_CHUNK_LENGTH = max_length * 10
for part in ("input", "output"):
l, r = 0, min(MAX_CHUNK_LENGTH, len(data[part]))
while l < len(data[part]):
try:
current_slice = data[part][l:r]
if not current_slice:
break
token_ids = tokenizer.encode(current_slice, add_special_tokens=False)
except:
print("Error in data[part][l:r] {}".format(data))
yield from ()
return
if part == "input":
if len(token_ids) > 0:
has_input = True
if len(token_ids) >= max_length - 2: # input len must < max_length
yield from ()
return
src_ids.extend(token_ids)
tgt_ids.extend([IGNORE_TGT] * len(token_ids))
l = r
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
else:
if len(token_ids) + len(tgt_ids) >= max_length:
if oversize_rule == "drop":
yield from ()
return
elif oversize_rule == "head":
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
src_ids.extend(selected_token_ids[:-1])
tgt_ids.extend(selected_token_ids)
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids[:max_length], tgt_ids[:max_length]
return
elif oversize_rule == "segment":
instruction_rest_space = max_length - 1 - len(token_ids)
if has_input: # is instruction data
if (
do_compact
and len(src_ids) >= 128  # avoid losing instruction info when it is too short
and instruction_rest_space / len(src_ids) > 0.8
): # can be squeezed into max length
inputs_len = len(src_ids)
keep_len = instruction_rest_space // 2
src_ids = src_ids[:keep_len] + src_ids[inputs_len - keep_len :]
tgt_ids = [IGNORE_TGT] * (len(src_ids) - 1)
src_ids.extend(token_ids)
tgt_ids.extend(token_ids)
tgt_ids.append(tokenizer.eos_token_id)
assert len(src_ids) < max_length, f"len src_ids: {len(src_ids)}"
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids, tgt_ids
else: # else use head rule
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
src_ids.extend(selected_token_ids[:-1])
tgt_ids.extend(selected_token_ids)
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids[:max_length], tgt_ids[:max_length]
return
else: # normal segment
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
src_ids.extend(selected_token_ids)
tgt_ids.extend(selected_token_ids)
assert len(src_ids) == max_length + 1, f"len src_ids: {len(src_ids)}"
assert len(tgt_ids) == max_length, f"len tgt_ids: {len(tgt_ids)}"
yield src_ids[:max_length], tgt_ids[:max_length]
src_ids = src_ids[max_length:]
tgt_ids = tgt_ids[max_length:]
# sliding input str window
consumed_str = tokenizer.decode(selected_token_ids)
l += len(consumed_str)
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
is_segment_reenter = True
else:
if (is_segment_reenter and len(token_ids) > 8) or (
not is_segment_reenter and len(token_ids) > 0
): # is segmented LM data
src_ids.extend(token_ids)
tgt_ids.extend(token_ids)
tgt_ids.append(tokenizer.eos_token_id)
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids, tgt_ids
else:
yield from ()
return
class SegmentedDataset(torch.utils.data.IterableDataset):
def __init__(
self,
cfg,
tokenizer,
max_length=1024,
transform_func=None,
nthreads=1,
prefetch_slice=3,
slice_size=500,
do_compact=False,
):
super(SegmentedDataset, self).__init__()
self.segment = functools.partial(
cricket_data2ids, tokenizer=tokenizer, max_length=max_length, do_compact=do_compact
)
self.cfg = cfg
self.max_length = max_length
self.nthreads = nthreads
self.transform_func = transform_func
self.prefetch_slice = prefetch_slice
self.slice_size = slice_size
self.abs_weight = cfg.get("abs_weight", None)
self.task_name = cfg["task_name"]
self.dataset_name = cfg["dataset_name"]
self.oversize_rule = cfg["oversize_rule"]
self.dataset = PrefetchDecodeDataset(path=cfg["path"], allow_repeat=cfg.get("allow_repeat", True))
self.exhausted = False
self.iterator = None
self.counter = 0
self.allow_repeat = cfg.get("allow_repeat", True)
self.used = BitSet()
self.init_ave_tokens()
def init_ave_tokens(
self,
):
try:
shm = SharedMemory(name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
except FileNotFoundError:
bmt.print_rank(
"Create Shared Memory {}".format(f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
)
shm = SharedMemory(
create=True,
size=ctypes.sizeof(ctypes.c_float),
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}',
)
# use the shared memory
shared_value = ctypes.c_float.from_buffer(shm.buf)
_ave_tokens = self.cfg.get(
"avg_tokens", self.cfg.get("ave_tokens", self.cfg.get("ave_tokens_per_line", -1))
)
if _ave_tokens > self.max_length:
_ave_tokens = self.max_length
bmt.print_rank(
"Warning: avg_tokens {} is larger than max_length {}, set to max_length".format(
_ave_tokens, self.max_length
)
)
shared_value.value = _ave_tokens
# drop the reference once shared_value is no longer needed
del shared_value
# now the shared memory can be closed safely
shm.close()
bmt.print_rank("Init ave_tokens for task {}: {}".format(self.task_name, self.ave_tokens))
@property
def ave_tokens(
self,
):
existing_shm = SharedMemory(
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}'
) # -1 # default length
shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
tmp = shared_value.value
del shared_value
existing_shm.close()
return tmp
def ave_tokens_update(self, length):
existing_shm = SharedMemory(
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}'
) # -1 # default length
shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
if shared_value.value < 0:
shared_value.value = float(length)
else:
shared_value.value = 0.98 * shared_value.value + 0.02 * length
del shared_value
existing_shm.close()
def size(self):
return self.dataset.size()
def __iter__(self):
self.iterate()
return self
def reset(self):
self.exhausted = False
if self.iterator is not None:
self.iterator.close()
self.iterator = None
self.used = BitSet()
print("Rank {}, Reset dataset:{} done.".format(bmt.rank(), self.dataset_name))
def transform(self, data: dict) -> dict:
weight = np.array(self.cfg["incontext_weight"], dtype=np.float32)
weight = weight / weight.sum()
num_incontext = np.random.choice(weight.shape[0], p=weight)
return self.transform_func(data, num_incontext, random.Random())
def segment_iterate(self, sample_iter):
for index, data in self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used):
for src_ids, tgt_ids in self.segment(self.transform(data)):
self.ave_tokens_update(len(src_ids)) # 0 for input ids
yield src_ids, tgt_ids, index
def iterate(self):
# make the dataset itself an iterator
sample_iter = self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used)
self.iterator = self.segment_iterate(sample_iter)
def __next__(self):
# advance the task iterator
if self.iterator is None:
self.iterate()
try:
return next(self.iterator)
except StopIteration:
self.exhausted = True
return None
def load_state_dict(self, state_dict):
if state_dict.get("exhausted", False):
self.exhausted = True
self.used = BitSet()
else:
used = state_dict.get("used", BitSet())
if len(used) == len(self.dataset):
self.exhausted = True
self.used = BitSet()
else:
self.exhausted = False
self.used = used
self.ave_tokens_update(state_dict.get("ave_tokens", -1))
def state_dict(self):
if len(self.used) == len(self.dataset):
return dict(exhausted=True, used=BitSet(), ave_tokens=self.ave_tokens)
else:
return dict(exhausted=False, used=self.used, ave_tokens=self.ave_tokens)
def update_state(self, indice):
self.used.update(indice)
class MixedIndexedDataset(torch.utils.data.IterableDataset):
def __init__(
self,
cfg_path: str,
cfg_json_str,
tokenizer,
max_length,
weight_by_size: bool = True,
nthreads=5,
prefetch_slice=100,
parallel_loading=False,
vdc_sampling=False,
update_weights_frequency=1,
seed=42,
):
super(MixedIndexedDataset, self).__init__()
self.set_seed(seed + bmt.rank())
self.weight_by_size = weight_by_size
self.tokenizer = tokenizer
self.eos_token_id = self.tokenizer.eos_token_id
self.bos_token_id = self.tokenizer.bos_token_id
self.path2transform = dict()
self.task_dict = OrderedDict()
self.nthreads = nthreads
self.prefetch_slice = prefetch_slice
# useful for indexing
self.tasks = []
self.names = []
# ending of iteration
self.remain = 0
self.max_length = max_length
self.vdc_sampling = vdc_sampling
if self.vdc_sampling:
self._vdc_values = [van_der_corput(i) for i in range(10**6)]  # precision up to 10^{-6}
self.vdc_gen = van_der_corput_sampling_gen(self._vdc_values)
self.update_weights_frequency = update_weights_frequency
self.path2transform = dict()
cfgs = load_dataset_cfgs(cfg_path, cfg_json_str)
_sum_weight = sum([cfg["abs_weight"] for cfg in cfgs])
_weights = {cfg["task_name"]: cfg["abs_weight"] / _sum_weight for cfg in cfgs}
bmt.print_rank("Absolute Weight of DataSet {}".format(_weights))
if parallel_loading:
self.parallel_load(cfgs, max_workers=None)
else:
self.sequential_load(cfgs)
self.weights = None
self.update_weights()
def set_seed(self, seed):
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
def load_task(self, cfg):
logger.info(f"Loading {cfg['path']}")
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
task = SegmentedDataset(
cfg,
self.tokenizer,
self.max_length,
transform_func=transform_func,
nthreads=self.nthreads,
prefetch_slice=self.prefetch_slice,
do_compact=cfg.get("do_compact", False), # dataset level do_compact
)
return task
def sequential_load(self, cfgs):
self.cfgs = cfgs
for cfg in cfgs:
# Python 3.7 and later preserve dictionary insertion order
logger.info(f"Loading {cfg['path']}")
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
task = SegmentedDataset(
cfg,
self.tokenizer,
self.max_length,
transform_func=transform_func,
nthreads=self.nthreads,
prefetch_slice=self.prefetch_slice,
do_compact=cfg.get("do_compact", False), # dataset level do_compact
)
self.task_dict[task.task_name] = task
self.tasks.append(task)
self.names.append(task.task_name)
self.remain += 1
self.weights = None
self.update_weights()
def load_state_dict(self, state_dict):
missing_keys = []
for name, task in self.task_dict.items():
if name in state_dict:
task.load_state_dict(state_dict[name])
else:
missing_keys.append(name)
self.update_weights()
return missing_keys
def save_state_dict(self, path):
state_dict = {}
for name, task in self.task_dict.items():
_state_dict = task.state_dict()
if isinstance(_state_dict["used"], BitSet):
bitset = _state_dict["used"]
_file_name = bitset.save(path)
_state_dict["used"] = _file_name
state_dict[name] = _state_dict
else:
state_dict[name] = task.state_dict()
torch.save(state_dict, path)
logger.info("Dataset state saved")
def update_states(self, task_ids, indice):
is_dict = isinstance(indice, dict)
uniq = torch.unique(task_ids)
for idx in uniq:
idx = idx.item()
indexes = indice[idx] if is_dict else indice[task_ids == idx].tolist()
self.tasks[idx].update_state(indexes)
def get_transform_func(self, module_name: str, transform_script_path):
if transform_script_path is None:
# allow null transform
return lambda data, num_incontext, rand: data
module_name = "fm9g_live.transforms.{}".format(module_name)
if transform_script_path not in self.path2transform:
loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
spec = importlib.util.spec_from_loader(loader.name, loader)
if spec is None:
raise RuntimeError("Spec is none! {}".format(module_name))
mod = importlib.util.module_from_spec(spec)
self.path2transform[transform_script_path] = {
"loader": loader,
"module": mod,
"last_mtime": 0,
}
transform_script_info = self.path2transform[transform_script_path]
curr_mtime = float(transform_script_info["loader"].path_stats(transform_script_path)["mtime"])
if curr_mtime > transform_script_info["last_mtime"]:
transform_script_info["last_mtime"] = curr_mtime
transform_script_info["loader"].exec_module(transform_script_info["module"])
transform_func = getattr(transform_script_info["module"], "transform", None)
if transform_func is None:
raise NotImplementedError("Find no transform funcion in script '{}'".format(transform_script_path))
return transform_func
def update_weights(self):
task0 = self.tasks[0]
if task0.abs_weight is not None:  # this config specifies absolute weights
weights = []
for task in self.tasks:
if task.exhausted:
weights.append(0)
else:
if task.ave_tokens == -1:
weights.append(task.abs_weight / self.max_length)
else:
weights.append(task.abs_weight / task.ave_tokens)
weights = np.array(weights)
else:
weights = np.array([0 if task.exhausted else task.weight for task in self.tasks])
if self.weight_by_size:
sizes = np.array([task.size() for task in self.tasks], dtype=np.float32)
weights *= sizes
self.weights = weights / weights.sum()
def __iter__(self):
for task in self.tasks:
task.iterate()
return self
def __next__(self):
step = 1
while True:
if self.remain == 0:
print("Rank {}, All task exhaust !!!!".format(bmt.rank()))
raise StopIteration
if self.vdc_sampling:
idx = next(self.vdc_gen)(self.weights)
else:
idx = np.random.choice(len(self.weights), p=self.weights)
data = next(self.tasks[idx])
if data is None:
if self.tasks[idx].allow_repeat:
# _runtime_ave = self.tasks[idx].ave_tokens
print("Rank {}, dataset {} exhaust, repeat...".format(bmt.rank(), self.tasks[idx].dataset_name))
# self.tasks[idx] = SegmentedDataset(
# self.tasks[idx].cfg, self.tokenizer, self.max_length, transform_func=self.tasks[idx].transform_func, nthreads=self.nthreads, prefetch_slice=self.prefetch_slice
# )
# self.tasks[idx].ave_tokens_update(_runtime_ave)
self.tasks[idx].reset()
else:
print("Rank {}, dataset {} exhaust, not repeat.".format(bmt.rank(), self.tasks[idx].dataset_name))
self.tasks[idx].exhaust = True
self.remain -= 1
continue
if step % self.update_weights_frequency == 0:
self.update_weights()
step += 1
return dict(
task_id=idx,
input=data[0],
target=data[1],
index=data[2],
is_long=self.tasks[idx].cfg.get("is_long", False),
)
class UnpadBatchedMixedDataset(torch.utils.data.IterableDataset):
def __init__(self, mixed_dataset, batch_size, max_length, pose_prob=0.0, pose_scaling_factor=1.0, compact=False):
self.max_total_length = batch_size * max_length
self.batch_size = 1
# setting compact=True concatenates segments originating from the same input
# into a long sequence. the relative order of segments should be preserved
# in mixed_dataset, e.g.,
# - ok: task1_seg1, task2_seg1, task1_seg2, task1_seg3
# - not_ok: task1_seg1, task1_seg3, task2_seg1, task1_seg2
self.compact = compact
self.total_length = 0
self.task2seqs = defaultdict(list)
self.mixed_dataset = mixed_dataset
self._max_length = max_length
self._pose_prob = pose_prob
self._pose_scaling_factor = pose_scaling_factor
if self._pose_prob > 0.0:
self._scaled_max_length = int(self.max_total_length * self._pose_scaling_factor)
else:
self._scaled_max_length = max_length
def put(self, sample):
self.total_length += len(sample["target"])
task_id = sample["task_id"]
if self.compact and self.task2seqs[task_id]:
last = self.task2seqs[task_id][-1]
if last["target"][-1] != self.mixed_dataset.eos_token_id:
# concatenate sequential segments for longer context modeling: why not?
last["input"].extend(sample["input"])
last["target"].extend(sample["target"])
return
self.task2seqs[task_id].append(sample)
def _pose_preprocess(
self,
input_ids: NDArray[np.int32],
) -> NDArray[np.int32]:
"""[PoSE](https://arxiv.org/abs/2309.10400v2)
GitHub implementation: https://github.com/dwzhu-pku/PoSE/blob/master/src/train_pose.py#L156
"""
len_chunk = min(len(input_ids), self._max_length)
len_input = len(input_ids)
# Chunk input randomly to fit max_length if needed
lt1 = 0
rt1 = random.randint(0, (len_chunk + 1) // 2)  # First chunk contains at most half of the tokens
rt2 = random.randint(lt1 + len_chunk, len_input)  # Second chunk can be randomly shifted when max_length is not filled
lt2 = rt2 - (len_chunk - (rt1 - lt1))  # ensure all tokens are used
chunked_input_ids = np.concatenate([input_ids[lt1:rt1], input_ids[lt2:rt2]], axis=-1)
# Generate PoSE position ids
position_ids = np.arange(len(chunked_input_ids), dtype=np.int32)
len_position_ids = len(position_ids)
lt = 0
rt = random.randint(lt, self._scaled_max_length - len_position_ids)
position_ids[: rt1 - lt1] += lt
position_ids[rt1 - lt1 :] += rt
return position_ids
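# Illustrative sketch of the PoSE preprocessing above (hypothetical numbers): with
# max_length 4096 and a 4x scaling factor, the window is split into two chunks and the
# second chunk's position ids are shifted, e.g.
#   original positions: [0 .. 1023] + [1024 .. 4095]
#   PoSE positions:     [0 .. 1023] + [9216 .. 12287]
# so the model is exposed to position gaps it would otherwise only see at longer contexts.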
def pop(self):
indexes = defaultdict(list)
lengths = []
inputs = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
targets = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=IGNORE_TGT)
task_ids = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=-1)
position_ids = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
span_begin = 0
for samples in self.task2seqs.values():
while samples:
sample = samples.pop()
span_end = span_begin + len(sample["input"])
inputs[0, span_begin:span_end] = torch.tensor(sample["input"], dtype=torch.int32)
targets[0, span_begin:span_end] = torch.tensor(sample["target"], dtype=torch.int32)
task_ids[0, span_begin:span_end] = torch.tensor(sample["task_id"], dtype=torch.int32)
if not sample["is_long"] and self._pose_prob > 0.0 and random.uniform(0, 1) < self._pose_prob:
_span_position_ids = self._pose_preprocess(sample["input"])
else:
_span_position_ids = np.arange(len(sample["input"]), dtype=np.int32)
position_ids[0, span_begin:span_end] = torch.from_numpy(_span_position_ids)
# position_ids[0, span_begin:span_end] = torch.arange(len(sample["input"]), dtype=torch.int32)
lengths.append(len(sample["target"]))
indexes[int(sample["task_id"])].append(sample["index"])
self.total_length -= len(sample["target"])
span_begin = span_end
cu_seqlens = torch.cat(
[torch.tensor([0] + lengths).cumsum(dim=-1), torch.tensor([self.max_total_length], dtype=torch.int32)],
dim=0,
).int()
batch = {
"inputs": inputs,
"targets": targets,
"task_ids": task_ids,
"indexes": indexes,
# adhere to flash attention interface
"cu_seqlens": cu_seqlens,
"max_seqlen": int(torch.max(cu_seqlens[1:] - cu_seqlens[:-1])),
"lengths": torch.tensor(sum(lengths)).int(),
"task_names": self.mixed_dataset.names,
"position_ids": position_ids,
}
return batch
def will_be_full(self, sample):
return self.total_length + len(sample["target"]) > self.max_total_length
def __iter__(self):
for sample in self.mixed_dataset:
if self.will_be_full(sample):
yield self.pop()
self.put(sample)
class CudaPrefetcher(Iterable):
"""
Wrap around a batch iterator to asynchronously copy data to the GPU, hiding memcpy latency.
"""
def __init__(self, loader, tp_size=1, tp_rank=0):
self.loader = iter(loader)
self.tp_size = tp_size
self.tp_rank = tp_rank
self.stream = torch.cuda.Stream()
self.preload()
def preload(self):
try:
if self.tp_size > 1:
if self.tp_rank == 0:
data = next(self.loader)
print("Rank {}, Preload data done.".format(bmt.rank()))
d = {}
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "wb") as fb:
for key in data.keys():
if isinstance(data[key], torch.Tensor):
np_cur_data = data[key].cpu().numpy()
bs = np_cur_data.tobytes()
fb.write(bs)
d[key] = ["TORCH", str(np_cur_data.dtype), len(bs)] + list(np_cur_data.shape)
elif isinstance(data[key], np.ndarray):
bs = data[key].tobytes()
fb.write(bs)
d[key] = ["NUMPY", str(data[key].dtype), len(bs)] + list(data[key].shape)
else:
d[key] = data[key]
try:
_ = json.dumps(d)
except TypeError:
print(d)
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "w") as f:
json.dump(d, f)
bmt.synchronize()
if self.tp_rank != 0:
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "r") as f:
data = json.load(f)
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "rb") as fb:
bs = fb.read()
offset = 0
for key in data.keys():
if isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "NUMPY":
nw_offset = offset + data[key][2]
data[key] = np.frombuffer(bs[offset:nw_offset], dtype=data[key][1]).reshape(
data[key][3:]
)
offset = nw_offset
elif isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "TORCH":
nw_offset = offset + data[key][2]
data[key] = torch.from_numpy(
np.frombuffer(bs[offset:nw_offset], dtype=data[key][1])
.reshape(data[key][3:])
.copy()
)
offset = nw_offset
self.data = data
else:
self.data = next(self.loader)
except StopIteration:
self.data = None
return
with torch.cuda.stream(self.stream):
for key in self.data.keys():
if isinstance(self.data[key], torch.Tensor):
self.data[key] = self.data[key].cuda(non_blocking=True)
def __next__(self):
torch.cuda.current_stream().wait_stream(self.stream)
for key in self.data.keys():
if isinstance(self.data[key], torch.Tensor):
self.data[key].record_stream(torch.cuda.current_stream())
data = copy.deepcopy(self.data)
self.preload()
return data
def __iter__(self):
return self

View File

@ -1,829 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
#
# @author: ouzebin <ouzebin@zhihu.com>
# @date: 2023/09/27
import copy
import ctypes
import functools
import importlib
import json
import logging
import os
import random
from collections import defaultdict
from collections import OrderedDict
from multiprocessing import Lock
from multiprocessing import Process
from multiprocessing.shared_memory import SharedMemory
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Union
import bmtrain as bmt
import numpy as np
import torch
from numpy.typing import NDArray
from fm9g.dataset import PrefetchDecodeDataset
from fm9g.utils.bitset import BitSet
from fm9g.utils.vdc_sampling import van_der_corput
from fm9g.utils.vdc_sampling import van_der_corput_sampling_gen
logger = logging.getLogger(__name__)
IGNORE_TGT = -100
def load_dataset_cfgs(cfg_path, cfg_json_str=None):
if cfg_json_str is not None:
cfgs = json.loads(cfg_json_str)
else:
with open(cfg_path, "r", encoding="utf-8") as fin:
cfgs = json.load(fin)
transform_basedir = os.path.dirname(os.path.abspath(cfg_path))
path_dict = None
platform_config_path = os.getenv("PLATFORM_CONFIG_PATH")
try:
with open(platform_config_path, "r") as f:
platform_cfg = json.load(f)
path_dict = platform_cfg["dataset_map"]
if bmt.rank() == 0:
logger.info(f"Loaded jeeves platform config from '{platform_config_path}', update dataset paths...")
except Exception as e:
if bmt.rank() == 0:
logger.info(f"Failing to load jeeves platform config '{platform_config_path}', error message:\n{str(e)}")
task_name2dataset_name = dict()
for idx, cfg in enumerate(cfgs):
assert "dataset_name" in cfg and isinstance(cfg["dataset_name"], str)
assert "task_name" in cfg and isinstance(cfg["task_name"], str)
# to be deliberately annoying :)
if cfg["task_name"] in task_name2dataset_name:
raise ValueError(
f"task_name '{cfg['task_name']}' in dataset '{cfg['dataset_name']}'"
f"has already been used in '{task_name2dataset_name[cfg['task_name']]}'."
)
task_name2dataset_name[cfg["task_name"]] = cfg["dataset_name"]
assert "path" in cfg and isinstance(cfg["path"], str)
# if path_dict is not None:
# cfg["path"] = os.path.join(path_dict[cfg["dataset_name"]], cfg["path"])
# dealing with optional configs
if "weight" in cfg:
assert isinstance(cfg["weight"], (float, int))
else:
cfg["weight"] = 1.0
if "oversize_rule" in cfg:
assert cfg["oversize_rule"] in ("drop", "head", "segment")
else:
cfg["oversize_rule"] = "segment"
if "transforms" in cfg:
assert isinstance(cfg["transforms"], str)
# dealing with relative path
if not cfg["transforms"].startswith("/"):
cfg["transforms"] = os.path.join(transform_basedir, cfg["transforms"])
if not cfg["transforms"]:
cfg["transforms"] = None
else:
cfg["transforms"] = None
if "incontext_weight" in cfg:
assert isinstance(cfg["incontext_weight"], (list, tuple))
else:
cfg["incontext_weight"] = [1.0]
cfg["id"] = idx
# dataset and iterator will be built
return cfgs
def data2ids(data, tokenizer, max_length):
text = "\n".join(
[
data.get("title", "").strip(),
data.get("question", "").strip(),
data.get("answer", "").strip(),
data.get("abstract", "").strip(),
data.get("text", "").strip(),
data.get("code", "").strip(),
]
).strip()
if not text:
logger.warning(f"Warning: skip invalid sample without valid fields: {data}")
yield from ()
return
# suppress the annoying warning from tokenizer
ids = (
[tokenizer.bos_id]
+ tokenizer.encode(text, max_length=int(1e12), truncation=True)
+ [tokenizer.eos_id]
)
src_ids = ids[0:-1]
tgt_ids = ids[0:-1] # do not shift because it'll be shifted during loss calculation.
if len(src_ids) > max_length:
for st in range(0, len(src_ids), max_length):
yield src_ids[st : st + max_length], tgt_ids[st : st + max_length]
else:
yield src_ids, tgt_ids
def cricket_data2ids(data, tokenizer, max_length: int, oversize_rule="segment", do_compact=False):
assert oversize_rule in ("drop", "head", "segment")
if data is None:
yield from ()
return
if "output" not in data or not data["output"]:
yield from ()
return
if "input" not in data or data["input"] is None:
data["input"] = ""
src_ids = [tokenizer.bos_id]
tgt_ids = []
has_input = False
is_segment_reenter = False
# Use incremental tokenization to avoid waiting for a long document
MAX_CHUNK_LENGTH = max_length * 10
for part in ("input", "output"):
l, r = 0, min(MAX_CHUNK_LENGTH, len(data[part]))
while l < len(data[part]):
try:
current_slice = data[part][l:r]
if not current_slice:
break
#token_ids = tokenizer.encode(current_slice, add_special_tokens=False)
token_ids = tokenizer.encode(current_slice)
except:
#print("Error in data[part][l:r] {}".format(data))
yield from ()
return
if part == "input":
if len(token_ids) > 0:
has_input = True
if len(token_ids) >= max_length - 2: # input len must < max_length
yield from ()
return
src_ids.extend(token_ids)
tgt_ids.extend([IGNORE_TGT] * len(token_ids))
l = r
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
else:
if len(token_ids) + len(tgt_ids) >= max_length:
if oversize_rule == "drop":
yield from ()
return
elif oversize_rule == "head":
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
src_ids.extend(selected_token_ids[:-1])
tgt_ids.extend(selected_token_ids)
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids[:max_length], tgt_ids[:max_length]
return
elif oversize_rule == "segment":
instruction_rest_space = max_length - 1 - len(token_ids)
if has_input: # is instruction data
if (
do_compact
and len(src_ids) >= 128  # avoid losing instruction info when it is too short
and instruction_rest_space / len(src_ids) > 0.8
): # can be squeezed into max length
inputs_len = len(src_ids)
keep_len = instruction_rest_space // 2
src_ids = src_ids[:keep_len] + src_ids[inputs_len - keep_len :]
tgt_ids = [IGNORE_TGT] * (len(src_ids) - 1)
src_ids.extend(token_ids)
tgt_ids.extend(token_ids)
tgt_ids.append(tokenizer.eos_id)
assert len(src_ids) < max_length, f"len src_ids: {len(src_ids)}"
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids, tgt_ids
else: # else use head rule
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
src_ids.extend(selected_token_ids[:-1])
tgt_ids.extend(selected_token_ids)
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids[:max_length], tgt_ids[:max_length]
return
else: # normal segment
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
src_ids.extend(selected_token_ids)
tgt_ids.extend(selected_token_ids)
assert len(src_ids) == max_length + 1, f"len src_ids: {len(src_ids)}"
assert len(tgt_ids) == max_length, f"len tgt_ids: {len(tgt_ids)}"
yield src_ids[:max_length], tgt_ids[:max_length]
src_ids = src_ids[max_length:]
tgt_ids = tgt_ids[max_length:]
# sliding input str window
consumed_str = tokenizer.decode(selected_token_ids)
l += len(consumed_str)
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
is_segment_reenter = True
else:
if (is_segment_reenter and len(token_ids) > 8) or (
not is_segment_reenter and len(token_ids) > 0
): # is segmented LM data
src_ids.extend(token_ids)
tgt_ids.extend(token_ids)
tgt_ids.append(tokenizer.eos_id)
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids, tgt_ids
else:
yield from ()
return
class SegmentedDataset(torch.utils.data.IterableDataset):
def __init__(
self,
cfg,
tokenizer,
max_length=1024,
transform_func=None,
nthreads=1,
prefetch_slice=3,
slice_size=500,
do_compact=False,
):
super(SegmentedDataset, self).__init__()
self.segment = functools.partial(
cricket_data2ids, tokenizer=tokenizer, max_length=max_length, do_compact=do_compact
)
self.cfg = cfg
self.max_length = max_length
self.nthreads = nthreads
self.transform_func = transform_func
self.prefetch_slice = prefetch_slice
self.slice_size = slice_size
self.abs_weight = cfg.get("abs_weight", None)
self.task_name = cfg["task_name"]
self.dataset_name = cfg["dataset_name"]
self.oversize_rule = cfg["oversize_rule"]
self.dataset = PrefetchDecodeDataset(path=cfg["path"], allow_repeat=cfg.get("allow_repeat", True))
self.exhausted = False
self.iterator = None
self.counter = 0
self.allow_repeat = cfg.get("allow_repeat", True)
self.used = BitSet()
self.init_ave_tokens()
def init_ave_tokens(
self,
):
try:
shm = SharedMemory(name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
except FileNotFoundError:
bmt.print_rank(
"Create Shared Memory {}".format(f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
)
shm = SharedMemory(
create=True,
size=ctypes.sizeof(ctypes.c_float),
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}',
)
# use the shared memory
shared_value = ctypes.c_float.from_buffer(shm.buf)
_ave_tokens = self.cfg.get(
"avg_tokens", self.cfg.get("ave_tokens", self.cfg.get("ave_tokens_per_line", -1))
)
if _ave_tokens > self.max_length:
_ave_tokens = self.max_length
bmt.print_rank(
"Warning: avg_tokens {} is larger than max_length {}, set to max_length".format(
_ave_tokens, self.max_length
)
)
shared_value.value = _ave_tokens
# drop the reference once shared_value is no longer needed
del shared_value
# now the shared memory can be closed safely
shm.close()
bmt.print_rank("Init ave_tokens for task {}: {}".format(self.task_name, self.ave_tokens))
@property
def ave_tokens(
self,
):
existing_shm = SharedMemory(
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}'
) # -1 # default length
shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
tmp = shared_value.value
del shared_value
existing_shm.close()
return tmp
def ave_tokens_update(self, length):
existing_shm = SharedMemory(
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}'
) # -1 # default length
shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
if shared_value.value < 0:
shared_value.value = float(length)
else:
shared_value.value = 0.98 * shared_value.value + 0.02 * length
del shared_value
existing_shm.close()
def size(self):
return self.dataset.size()
def __iter__(self):
self.iterate()
return self
def reset(self):
self.exhausted = False
if self.iterator is not None:
self.iterator.close()
self.iterator = None
self.used = BitSet()
print("Rank {}, Reset dataset:{} done.".format(bmt.rank(), self.dataset_name))
def transform(self, data: dict) -> dict:
weight = np.array(self.cfg["incontext_weight"], dtype=np.float32)
weight = weight / weight.sum()
num_incontext = np.random.choice(weight.shape[0], p=weight)
return self.transform_func(data, num_incontext, random.Random())
def segment_iterate(self, sample_iter):
for index, data in self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used):
for src_ids, tgt_ids in self.segment(self.transform(data)):
self.ave_tokens_update(len(src_ids)) # 0 for input ids
yield src_ids, tgt_ids, index
def iterate(self):
# make the dataset itself an iterator
sample_iter = self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used)
self.iterator = self.segment_iterate(sample_iter)
def __next__(self):
# advance the task iterator
if self.iterator is None:
self.iterate()
try:
return next(self.iterator)
except StopIteration:
self.exhausted = True
return None
def load_state_dict(self, state_dict):
if state_dict.get("exhausted", False):
self.exhausted = True
self.used = BitSet()
else:
used = state_dict.get("used", BitSet())
if len(used) == len(self.dataset):
self.exhausted = True
self.used = BitSet()
else:
self.exhausted = False
self.used = used
self.ave_tokens_update(state_dict.get("ave_tokens", -1))
def state_dict(self):
if len(self.used) == len(self.dataset):
return dict(exhausted=True, used=BitSet(), ave_tokens=self.ave_tokens)
else:
return dict(exhausted=False, used=self.used, ave_tokens=self.ave_tokens)
def update_state(self, indice):
self.used.update(indice)
class MixedIndexedDataset(torch.utils.data.IterableDataset):
def __init__(
self,
cfg_path: str,
cfg_json_str,
tokenizer,
max_length,
weight_by_size: bool = True,
nthreads=5,
prefetch_slice=100,
parallel_loading=False,
vdc_sampling=False,
update_weights_frequency=1,
seed=42,
):
super(MixedIndexedDataset, self).__init__()
self.set_seed(seed + bmt.rank())
self.weight_by_size = weight_by_size
self.tokenizer = tokenizer
self.eos_token_id = self.tokenizer.eos_id
self.bos_token_id = self.tokenizer.bos_id
self.path2transform = dict()
self.task_dict = OrderedDict()
self.nthreads = nthreads
self.prefetch_slice = prefetch_slice
# useful for indexing
self.tasks = []
self.names = []
# ending of iteration
self.remain = 0
self.max_length = max_length
self.vdc_sampling = vdc_sampling
if self.vdc_sampling:
self._vdc_values = [van_der_corput(i) for i in range(10**6)]  # precompute to a resolution of 10^{-6}
self.vdc_gen = van_der_corput_sampling_gen(self._vdc_values)
self.update_weights_frequency = update_weights_frequency
self.path2transform = dict()
cfgs = load_dataset_cfgs(cfg_path, cfg_json_str)
_sum_weight = sum([cfg["abs_weight"] for cfg in cfgs])
_weights = {cfg["task_name"]: cfg["abs_weight"] / _sum_weight for cfg in cfgs}
bmt.print_rank("Absolute Weight of DataSet {}".format(_weights))
if parallel_loading:
self.parallel_load(cfgs, max_workers=None)
else:
self.sequential_load(cfgs)
self.weights = None
self.update_weights()
def set_seed(self, seed):
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
def load_task(self, cfg):
logger.info(f"Loading {cfg['path']}")
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
task = SegmentedDataset(
cfg,
self.tokenizer,
self.max_length,
transform_func=transform_func,
nthreads=self.nthreads,
prefetch_slice=self.prefetch_slice,
do_compact=cfg.get("do_compact", False), # dataset level do_compact
)
return task
def sequential_load(self, cfgs):
self.cfgs = cfgs
for cfg in cfgs:
# python3.7 and later preserves insertion order to dictionary
logger.info(f"Loading {cfg['path']}")
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
task = SegmentedDataset(
cfg,
self.tokenizer,
self.max_length,
transform_func=transform_func,
nthreads=self.nthreads,
prefetch_slice=self.prefetch_slice,
do_compact=cfg.get("do_compact", False), # dataset level do_compact
)
self.task_dict[task.task_name] = task
self.tasks.append(task)
self.names.append(task.task_name)
self.remain += 1
self.weights = None
self.update_weights()
def load_state_dict(self, state_dict):
missing_keys = []
for name, task in self.task_dict.items():
if name in state_dict:
task.load_state_dict(state_dict[name])
else:
missing_keys.append(name)
self.update_weights()
return missing_keys
def save_state_dict(self, path):
state_dict = {}
for name, task in self.task_dict.items():
_state_dict = task.state_dict()
if isinstance(_state_dict["used"], BitSet):
bitset = _state_dict["used"]
_file_name = bitset.save(path)
_state_dict["used"] = _file_name
state_dict[name] = _state_dict
else:
state_dict[name] = task.state_dict()
torch.save(state_dict, path)
logger.info("Dataset state saved")
def update_states(self, task_ids, indice):
is_dict = isinstance(indice, dict)
uniq = torch.unique(task_ids)
for idx in uniq:
idx = idx.item()
indexes = indice[idx] if is_dict else indice[task_ids == idx].tolist()
self.tasks[idx].update_state(indexes)
def get_transform_func(self, module_name: str, transform_script_path):
if transform_script_path is None:
# allow null transform
return lambda data, num_incontext, rand: data
module_name = "fm9g_live.transforms.{}".format(module_name)
if transform_script_path not in self.path2transform:
loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
spec = importlib.util.spec_from_loader(loader.name, loader)
if spec is None:
raise RuntimeError("Spec is none! {}".format(module_name))
mod = importlib.util.module_from_spec(spec)
self.path2transform[transform_script_path] = {
"loader": loader,
"module": mod,
"last_mtime": 0,
}
transform_script_info = self.path2transform[transform_script_path]
curr_mtime = float(transform_script_info["loader"].path_stats(transform_script_path)["mtime"])
if curr_mtime > transform_script_info["last_mtime"]:
transform_script_info["last_mtime"] = curr_mtime
transform_script_info["loader"].exec_module(transform_script_info["module"])
transform_func = getattr(transform_script_info["module"], "transform", None)
if transform_func is None:
raise NotImplementedError("Find no transform funcion in script '{}'".format(transform_script_path))
return transform_func
def update_weights(self):
task0 = self.tasks[0]
if task0.abs_weight is not None:  # this config specifies absolute mixture ratios
weights = []
for task in self.tasks:
if task.exhausted:
weights.append(0)
else:
if task.ave_tokens == -1:
weights.append(task.abs_weight / self.max_length)
else:
weights.append(task.abs_weight / task.ave_tokens)
weights = np.array(weights)
else:
weights = np.array([0 if task.exhausted else task.weight for task in self.tasks])
if self.weight_by_size:
sizes = np.array([task.size() for task in self.tasks], dtype=np.float32)
weights *= sizes
self.weights = weights / weights.sum()
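# Illustrative example (hypothetical numbers): dividing abs_weight by the running average sample
# length turns per-token ratios into per-sample sampling probabilities. With abs_weight 2.0 vs 1.0
# and average lengths 200 vs 400 tokens, p is proportional to [2.0/200, 1.0/400] = [0.01, 0.0025],
# i.e. [0.8, 0.2] after normalisation, so the longer-sample task is drawn less often while the
# token-level mixture still matches 2:1.
_demo_w = np.array([2.0 / 200, 1.0 / 400], dtype=np.float32)
assert np.allclose(_demo_w / _demo_w.sum(), [0.8, 0.2])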
def __iter__(self):
for task in self.tasks:
task.iterate()
return self
def __next__(self):
step = 1
while True:
if self.remain == 0:
print("Rank {}, All task exhaust !!!!".format(bmt.rank()))
raise StopIteration
if self.vdc_sampling:
idx = next(self.vdc_gen)(self.weights)
else:
idx = np.random.choice(len(self.weights), p=self.weights)
data = next(self.tasks[idx])
if data is None:
if self.tasks[idx].allow_repeat:
# _runtime_ave = self.tasks[idx].ave_tokens
print("Rank {}, dataset {} exhaust, repeat...".format(bmt.rank(), self.tasks[idx].dataset_name))
# self.tasks[idx] = SegmentedDataset(
# self.tasks[idx].cfg, self.tokenizer, self.max_length, transform_func=self.tasks[idx].transform_func, nthreads=self.nthreads, prefetch_slice=self.prefetch_slice
# )
# self.tasks[idx].ave_tokens_update(_runtime_ave)
self.tasks[idx].reset()
else:
print("Rank {}, dataset {} exhaust, not repeat.".format(bmt.rank(), self.tasks[idx].dataset_name))
self.tasks[idx].exhaust = True
self.remain -= 1
continue
if step % self.update_weights_frequency == 0:
self.update_weights()
step += 1
return dict(
task_id=idx,
input=data[0],
target=data[1],
index=data[2],
is_long=self.tasks[idx].cfg.get("is_long", False),
)
class UnpadBatchedMixedDataset(torch.utils.data.IterableDataset):
def __init__(self, mixed_dataset, batch_size, max_length, pose_prob=0.0, pose_scaling_factor=1.0, compact=False):
self.max_total_length = batch_size * max_length
self.batch_size = 1
# setting compact=True concatenates segments originating from the same input
# into a long sequence. the relative order of segments should be preserved
# in mixed_dataset, e.g.,
# - ok: task1_seg1, task2_seg1, task1_seg2, task1_seg3
# - not_ok: task1_seg1, task1_seg3, task2_seg1, task1_seg2
self.compact = compact
self.total_length = 0
self.task2seqs = defaultdict(list)
self.mixed_dataset = mixed_dataset
self._max_length = max_length
self._pose_prob = pose_prob
self._pose_scaling_factor = pose_scaling_factor
if self._pose_prob > 0.0:
self._scaled_max_length = int(self.max_total_length * self._pose_scaling_factor)
else:
self._scaled_max_length = max_length
def put(self, sample):
self.total_length += len(sample["target"])
task_id = sample["task_id"]
if self.compact and self.task2seqs[task_id]:
last = self.task2seqs[task_id][-1]
if last["target"][-1] != self.mixed_dataset.eos_token_id:
# concatenate sequential segments from the same sample for longer-context modeling
last["input"].extend(sample["input"])
last["target"].extend(sample["target"])
return
self.task2seqs[task_id].append(sample)
def _pose_preprocess(
self,
input_ids: NDArray[np.int32],
) -> NDArray[np.int32]:
"""[PoSE](https://arxiv.org/abs/2309.10400v2)
GitHub implementation: https://github.com/dwzhu-pku/PoSE/blob/master/src/train_pose.py#L156
"""
len_chunk = min(len(input_ids), self._max_length)
len_input = len(input_ids)
# Chunk input randomly to fit max_length if needed
lt1 = 0
rt1 = random.randint(0, (len_chunk + 1) // 2)  # first chunk contains at most half of the tokens
rt2 = random.randint(lt1 + len_chunk, len_input)  # second chunk can shift randomly when the input exceeds max_length
lt2 = rt2 - (len_chunk - (rt1 - lt1)) # assure all tokens are used
chunked_input_ids = np.concatenate([input_ids[lt1:rt1], input_ids[lt2:rt2]], axis=-1)
# Generate PoSE position ids
position_ids = np.arange(len(chunked_input_ids), dtype=np.int32)
len_position_ids = len(position_ids)
lt = 0
rt = random.randint(lt, self._scaled_max_length - len_position_ids)
position_ids[: rt1 - lt1] += lt
position_ids[rt1 - lt1 :] += rt
return position_ids
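# Illustrative note (hypothetical numbers, not from the original): with a 6-token input,
# _max_length=4 and _scaled_max_length=8, one possible draw is rt1=2, rt2=6, rt=3. The kept chunks
# are tokens [0:2] and [4:6], and their position ids become [0, 1, 5, 6]: the second chunk is
# shifted so the model sees long-range position offsets while training on only max_length tokens,
# which is the point of PoSE.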
def pop(self):
indexes = defaultdict(list)
lengths = []
inputs = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
targets = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=IGNORE_TGT)
task_ids = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=-1)
position_ids = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
span_begin = 0
for samples in self.task2seqs.values():
while samples:
sample = samples.pop()
span_end = span_begin + len(sample["input"])
inputs[0, span_begin:span_end] = torch.tensor(sample["input"], dtype=torch.int32)
targets[0, span_begin:span_end] = torch.tensor(sample["target"], dtype=torch.int32)
task_ids[0, span_begin:span_end] = torch.tensor(sample["task_id"], dtype=torch.int32)
if not sample["is_long"] and self._pose_prob > 0.0 and random.uniform(0, 1) < self._pose_prob:
_span_position_ids = self._pose_preprocess(sample["input"])
else:
_span_position_ids = np.arange(len(sample["input"]), dtype=np.int32)
position_ids[0, span_begin:span_end] = torch.from_numpy(_span_position_ids)
# position_ids[0, span_begin:span_end] = torch.arange(len(sample["input"]), dtype=torch.int32)
lengths.append(len(sample["target"]))
indexes[int(sample["task_id"])].append(sample["index"])
self.total_length -= len(sample["target"])
span_begin = span_end
cu_seqlens = torch.cat(
[torch.tensor([0] + lengths).cumsum(dim=-1), torch.tensor([self.max_total_length], dtype=torch.int32)],
dim=0,
).int()
batch = {
"inputs": inputs,
"targets": targets,
"task_ids": task_ids,
"indexes": indexes,
# adhere to flash attention interface
"cu_seqlens": cu_seqlens,
"max_seqlen": int(torch.max(cu_seqlens[1:] - cu_seqlens[:-1])),
"lengths": torch.tensor(sum(lengths)).int(),
"task_names": self.mixed_dataset.names,
"position_ids": position_ids,
}
return batch
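# Illustrative note: pop() returns one packed row of shape (1, batch_size * max_length);
# "cu_seqlens" lists the cumulative boundaries of the packed sequences in the format the
# flash-attn varlen kernels expect. For example, three sequences of lengths 5, 3 and 4 packed
# into a 16-token buffer give cu_seqlens = [0, 5, 8, 12, 16] and max_seqlen = 5.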
def will_be_full(self, sample):
return self.total_length + len(sample["target"]) > self.max_total_length
def __iter__(self):
for sample in self.mixed_dataset:
if self.will_be_full(sample):
yield self.pop()
self.put(sample)
class CudaPrefetcher(Iterable):
"""
Wraps a batch iterator to copy batches to the GPU asynchronously, hiding memcpy latency.
"""
def __init__(self, loader, tp_size=1, tp_rank=0):
self.loader = iter(loader)
self.tp_size = tp_size
self.tp_rank = tp_rank
self.stream = torch.cuda.Stream()
self.preload()
def preload(self):
try:
if self.tp_size > 1:
if self.tp_rank == 0:
data = next(self.loader)
print("Rank {}, Preload data done.".format(bmt.rank()))
d = {}
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "wb") as fb:
for key in data.keys():
if isinstance(data[key], torch.Tensor):
np_cur_data = data[key].cpu().numpy()
bs = np_cur_data.tobytes()
fb.write(bs)
d[key] = ["TORCH", str(np_cur_data.dtype), len(bs)] + list(np_cur_data.shape)
elif isinstance(data[key], np.ndarray):
bs = data[key].tobytes()
fb.write(bs)
d[key] = ["NUMPY", str(data[key].dtype), len(bs)] + list(data[key].shape)
else:
d[key] = data[key]
try:
_ = json.dumps(d)
except TypeError:
print(d)
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "w") as f:
json.dump(d, f)
bmt.synchronize()
if self.tp_rank != 0:
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "r") as f:
data = json.load(f)
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "rb") as fb:
bs = fb.read()
offset = 0
for key in data.keys():
if isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "NUMPY":
nw_offset = offset + data[key][2]
data[key] = np.frombuffer(bs[offset:nw_offset], dtype=data[key][1]).reshape(
data[key][3:]
)
offset = nw_offset
elif isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "TORCH":
nw_offset = offset + data[key][2]
data[key] = torch.from_numpy(
np.frombuffer(bs[offset:nw_offset], dtype=data[key][1])
.reshape(data[key][3:])
.copy()
)
offset = nw_offset
self.data = data
else:
self.data = next(self.loader)
except StopIteration:
self.data = None
return
with torch.cuda.stream(self.stream):
for key in self.data.keys():
if isinstance(self.data[key], torch.Tensor):
self.data[key] = self.data[key].cuda(non_blocking=True)
def __next__(self):
torch.cuda.current_stream().wait_stream(self.stream)
for key in self.data.keys():
if isinstance(self.data[key], torch.Tensor):
self.data[key].record_stream(torch.cuda.current_stream())
data = copy.deepcopy(self.data)
self.preload()
return data
def __iter__(self):
return self
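# Illustrative sketch of how these classes are typically chained together (the tokenizer, the
# config path and all argument values below are placeholders, not taken from the original file):
#
#   mixed = MixedIndexedDataset(cfg_path="datasets.json", cfg_json_str=None,
#                               tokenizer=tokenizer, max_length=4096)
#   batched = UnpadBatchedMixedDataset(mixed, batch_size=1, max_length=4096)
#   loader = torch.utils.data.DataLoader(batched, batch_size=None)
#   prefetcher = CudaPrefetcher(loader, tp_size=1, tp_rank=0)
#   batch = next(prefetcher)  # tensors already on the GPU, with cu_seqlens/position_ids filled in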

View File

@ -1,827 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
#
# @author: ouzebin <ouzebin@zhihu.com>
# @date: 2023/09/27
import copy
import ctypes
import functools
import importlib
import importlib.util
import json
import logging
import multiprocessing as mp
import os
import random
from collections import defaultdict
from collections import OrderedDict
from multiprocessing import Lock
from multiprocessing import Process
from multiprocessing.shared_memory import SharedMemory
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Union
import bmtrain as bmt
import numpy as np
import requests
import torch
from numpy.typing import NDArray
from tenacity import retry
from tenacity import stop_after_attempt
from tenacity import wait_fixed
from tenacity import wait_random
from fm9g.dataset import PrefetchDecodeDataset
from fm9g.utils.bitset import BitSet
from fm9g.utils.vdc_sampling import van_der_corput
from fm9g.utils.vdc_sampling import van_der_corput_sampling_gen
from .flask_ps import app as flask_ps
logger = logging.getLogger(__name__)
IGNORE_TGT = -100
def load_dataset_cfgs(cfg_path, cfg_json_str=None):
if cfg_json_str is not None:
cfgs = json.loads(cfg_json_str)
else:
with open(cfg_path, "r", encoding="utf-8") as fin:
cfgs = json.load(fin)
transform_basedir = os.path.dirname(os.path.abspath(cfg_path))
path_dict = None
platform_config_path = os.getenv("PLATFORM_CONFIG_PATH")
try:
with open(platform_config_path, "r") as f:
platform_cfg = json.load(f)
path_dict = platform_cfg["dataset_map"]
if bmt.rank() == 0:
logger.info(f"Loaded jeeves platform config from '{platform_config_path}', update dataset paths...")
except Exception as e:
if bmt.rank() == 0:
logger.info(f"Failing to load jeeves platform config '{platform_config_path}', error message:\n{str(e)}")
task_name2dataset_name = dict()
for idx, cfg in enumerate(cfgs):
assert "dataset_name" in cfg and isinstance(cfg["dataset_name"], str)
assert "task_name" in cfg and isinstance(cfg["task_name"], str)
# to be deliberately annoying :)
if cfg["task_name"] in task_name2dataset_name:
raise ValueError(
f"task_name '{cfg['task_name']}' in dataset '{cfg['dataset_name']}'"
f"has already been used in '{task_name2dataset_name[cfg['task_name']]}'."
)
task_name2dataset_name[cfg["task_name"]] = cfg["dataset_name"]
assert "path" in cfg and isinstance(cfg["path"], str)
# if path_dict is not None:
# cfg["path"] = os.path.join(path_dict[cfg["dataset_name"]], cfg["path"])
# dealing with optional configs
if "weight" in cfg:
assert isinstance(cfg["weight"], (float, int))
else:
cfg["weight"] = 1.0
if "oversize_rule" in cfg:
assert cfg["oversize_rule"] in ("drop", "head", "segment")
else:
cfg["oversize_rule"] = "segment"
if "transforms" in cfg:
assert isinstance(cfg["transforms"], str)
# dealing with relative path
if not cfg["transforms"].startswith("/"):
cfg["transforms"] = os.path.join(transform_basedir, cfg["transforms"])
if not cfg["transforms"]:
cfg["transforms"] = None
else:
cfg["transforms"] = None
if "incontext_weight" in cfg:
assert isinstance(cfg["incontext_weight"], (list, tuple))
else:
cfg["incontext_weight"] = [1.0]
cfg["id"] = idx
# dataset and iterator will be built
return cfgs
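# Illustrative example (all values are placeholders, not from the original config): a minimal
# entry accepted by the checks above. Only dataset_name, task_name and path are mandatory;
# weight, oversize_rule, transforms and incontext_weight fall back to 1.0, "segment", None
# and [1.0] respectively.
_demo_cfgs = [
    {
        "dataset_name": "demo_set",
        "task_name": "demo_task",
        "path": "/path/to/demo_set",
        "abs_weight": 1.0,           # consumed later by MixedIndexedDataset.update_weights
        "oversize_rule": "segment",  # one of "drop", "head", "segment"
        "incontext_weight": [1.0],
    }
]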
def data2ids(data, tokenizer, max_length):
text = "\n".join(
[
data.get("title", "").strip(),
data.get("question", "").strip(),
data.get("answer", "").strip(),
data.get("abstract", "").strip(),
data.get("text", "").strip(),
data.get("code", "").strip(),
]
).strip()
if not text:
logger.warning(f"Warning: skip invalid sample without valid fields: {data}")
yield from ()
return
# suppress the annoying warning from tokenizer
ids = (
[tokenizer.bos_token_id]
+ tokenizer.encode(text, max_length=int(1e12), truncation=True)
+ [tokenizer.eos_token_id]
)
src_ids = ids[0:-1]
tgt_ids = ids[0:-1] # do not shift because it'll be shifted during loss calculation.
if len(src_ids) > max_length:
for st in range(0, len(src_ids), max_length):
yield src_ids[st : st + max_length], tgt_ids[st : st + max_length]
else:
yield src_ids, tgt_ids
def cricket_data2ids(data, tokenizer, max_length: int, oversize_rule="segment", do_compact=False):
assert oversize_rule in ("drop", "head", "segment")
if data is None:
yield from ()
return
if "output" not in data or not data["output"]:
yield from ()
return
if "input" not in data:
data["input"] = ""
src_ids = [tokenizer.bos_token_id]
tgt_ids = []
has_input = False
is_segment_reenter = False
# Use incremental tokenization to avoid waiting for a long document
MAX_CHUNK_LENGTH = max_length * 10
for part in ("input", "output"):
l, r = 0, min(MAX_CHUNK_LENGTH, len(data[part]))
while l < len(data[part]):
current_slice = data[part][l:r]
if not current_slice:
break
token_ids = tokenizer.encode(current_slice, add_special_tokens=False)
if part == "input":
if len(token_ids) > 0:
has_input = True
if len(token_ids) >= max_length - 2:  # input length must be < max_length
yield from ()
return
src_ids.extend(token_ids)
tgt_ids.extend([IGNORE_TGT] * len(token_ids))
l = r
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
else:
if len(token_ids) + len(tgt_ids) >= max_length:
if oversize_rule == "drop":
yield from ()
return
elif oversize_rule == "head":
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
src_ids.extend(selected_token_ids[:-1])
tgt_ids.extend(selected_token_ids)
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids[:max_length], tgt_ids[:max_length]
return
elif oversize_rule == "segment":
instruction_rest_space = max_length - 1 - len(token_ids)
if has_input: # is instruction data
if (
do_compact
and len(src_ids) >= 128  # avoid losing instruction info when it is already short
and instruction_rest_space / len(src_ids) > 0.8
): # can be squeezed into max length
inputs_len = len(src_ids)
keep_len = instruction_rest_space // 2
src_ids = src_ids[:keep_len] + src_ids[inputs_len - keep_len :]
tgt_ids = [IGNORE_TGT] * (len(src_ids) - 1)
src_ids.extend(token_ids)
tgt_ids.extend(token_ids)
tgt_ids.append(tokenizer.eos_token_id)
assert len(src_ids) < max_length, f"len src_ids: {len(src_ids)}"
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids, tgt_ids
else: # else use head rule
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
src_ids.extend(selected_token_ids[:-1])
tgt_ids.extend(selected_token_ids)
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids[:max_length], tgt_ids[:max_length]
return
else: # normal segment
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
src_ids.extend(selected_token_ids)
tgt_ids.extend(selected_token_ids)
assert len(src_ids) == max_length + 1, f"len src_ids: {len(src_ids)}"
assert len(tgt_ids) == max_length, f"len tgt_ids: {len(tgt_ids)}"
yield src_ids[:max_length], tgt_ids[:max_length]
src_ids = src_ids[max_length:]
tgt_ids = tgt_ids[max_length:]
# sliding input str window
consumed_str = tokenizer.decode(selected_token_ids)
l += len(consumed_str)
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
is_segment_reenter = True
else:
if (is_segment_reenter and len(token_ids) > 8) or (
not is_segment_reenter and len(token_ids) > 0
): # is segmented LM data
src_ids.extend(token_ids)
tgt_ids.extend(token_ids)
tgt_ids.append(tokenizer.eos_token_id)
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
yield src_ids, tgt_ids
else:
yield from ()
return
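# Summary note (illustrative only) of the three oversize_rule modes implemented above:
#   "drop"    - a sample whose tokens would exceed max_length is discarded entirely;
#   "head"    - the sample is truncated to its first max_length tokens and emitted once;
#   "segment" - the sample is split into consecutive max_length windows; for instruction data
#               with do_compact=True, a long prompt is squeezed to its head and tail halves so
#               that prompt plus answer can still fit into a single window.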
class SegmentedDataset(torch.utils.data.IterableDataset):
def __init__(
self,
cfg,
tokenizer,
max_length=1024,
transform_func=None,
nthreads=1,
prefetch_slice=3,
slice_size=500,
do_compact=False,
is_local=False,
):
def get_full_qualified_name(func):
module_name = func.__module__
qual_name = func.__qualname__
return f"{module_name}.{qual_name}"
super(SegmentedDataset, self).__init__()
self.segment = functools.partial(
cricket_data2ids, tokenizer=tokenizer, max_length=max_length, do_compact=do_compact
)
self.cfg = cfg
self.max_length = max_length
self.nthreads = nthreads
self.transform_func = transform_func
self.prefetch_slice = prefetch_slice
self.slice_size = slice_size
self.abs_weight = cfg.get("abs_weight", None)
self.task_name = cfg["task_name"]
self.dataset_name = cfg["dataset_name"]
self.oversize_rule = cfg["oversize_rule"]
self.dataset = PrefetchDecodeDataset(path=cfg["path"], allow_repeat=cfg.get("allow_repeat", False))
self.exhausted = False
self.iterator = None
self.counter = 0
self.allow_repeat = cfg.get("allow_repeat", True)
self.used = set()
self.is_local = is_local
self.port_offset = random.randint(1, 8000) + 1000
if is_local or bmt.rank() == 0:
addr = os.environ["MASTER_ADDR"]
port = int(os.environ["MASTER_PORT"]) + self.port_offset
_avg_tokens = cfg.get("ave_tokens_per_line", -1)
_avg_tokens = cfg.get("avg_tokens", _avg_tokens)
requests.post(f"http://{addr}:{port}/avg_tokens/{self.task_name}?action=set&length={_avg_tokens}")
@retry(stop=stop_after_attempt(3), wait=wait_random(5, 20))
def set_avg_tokens(self, avg_tokens):
addr = os.environ["MASTER_ADDR"]
port = int(os.environ["MASTER_PORT"]) + self.port_offset
url = f"http://{addr}:{port}/avg_tokens/{self.task_name}?action=set&length={avg_tokens}"
response = requests.post(url)
if response.status_code != 200:
self.reset_port_offset()
print(f"Failed to set avg_tokens for task {self.task_name}, request url: {url}")
raise RuntimeError(f"Failed to set avg_tokens for task {self.task_name}, request url: {url}")
@retry(stop=stop_after_attempt(3), wait=wait_random(5, 20))
def update_avg_tokens_by_ema(self, length):
addr = os.environ["MASTER_ADDR"]
port = int(os.environ["MASTER_PORT"]) + self.port_offset
url = f"http://{addr}:{port}/avg_tokens/{self.task_name}?action=update&length={length}"
response = requests.post(url)
if response.status_code != 200:
self.reset_port_offset()
print(f"Failed to update avg_tokens for task {self.task_name}, request url: {url}")
raise RuntimeError(f"Failed to update avg_tokens for task {self.task_name}, request url: {url}")
@property
@retry(stop=stop_after_attempt(3), wait=wait_random(5, 20))
def avg_tokens(self):
addr = os.environ["MASTER_ADDR"]
port = int(os.environ["MASTER_PORT"]) + self.port_offset
url = f"http://{addr}:{port}/avg_tokens/{self.task_name}"
response = requests.get(url)
if response.status_code != 200:
self.reset_port_offset()
print(f"Failed to get avg_tokens for task {self.task_name}, request url: {url}")
raise RuntimeError(f"Failed to get avg_tokens for task {self.task_name}, request url: {url}")
data = response.json()
return data["avg_tokens"]
def reset_port_offset(self):
self.port_offset = random.randint(1, 8000) + 1000
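# Sketch of the small REST protocol this class assumes (endpoint shapes taken from the URLs
# above; the server side lives in .flask_ps and is launched by MixedIndexedDataset on rank 0):
#   POST /avg_tokens/<task_name>?action=set&length=L     -> overwrite the stored value with L
#   POST /avg_tokens/<task_name>?action=update&length=L  -> EMA-update the stored value
#   GET  /avg_tokens/<task_name>                         -> JSON body {"avg_tokens": value}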
def size(self):
return self.dataset.size()
def __iter__(self):
self.iterate()
return self
def reset(self):
self.exhausted = False
if self.iterator is not None:
self.iterator.close()
self.iterator = None
self.used = BitSet()
print("Rank {}, Reset dataset:{} done.".format(bmt.rank(), self.dataset_name))
def transform(self, data: dict) -> dict:
weight = np.array(self.cfg["incontext_weight"], dtype=np.float32)
weight = weight / weight.sum()
num_incontext = np.random.choice(weight.shape[0], p=weight)
return self.transform_func(data, num_incontext, random.Random())
def segment_iterate(self, sample_iter):
for index, data in self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used):
for src_ids, tgt_ids in self.segment(self.transform(data)):
self.update_avg_tokens_by_ema(len(src_ids))  # EMA update with the source-token count of this segment
yield src_ids, tgt_ids, index
def iterate(self):
# make the dataset itself an iterator
sample_iter = self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used)
self.iterator = self.segment_iterate(sample_iter)
def __next__(self):
# advance the task iterator
if self.iterator is None:
self.iterate()
try:
return next(self.iterator)
except StopIteration:
self.exhausted = True
return None
def load_state_dict(self, state_dict):
if state_dict.get("exhausted", False):
self.exhausted = True
self.used = BitSet()
else:
used = state_dict.get("used", BitSet())
if len(used) == len(self.dataset):
self.exhausted = True
self.used = BitSet()
else:
self.exhausted = False
self.used = used
_avg_tokens = state_dict.get("ave_tokens", -1)
_avg_tokens = state_dict.get("avg_tokens", _avg_tokens)
if self.avg_tokens == -1 or self.avg_tokens < _avg_tokens:
self.set_avg_tokens(_avg_tokens)
def state_dict(self):
if len(self.used) == len(self.dataset):
return dict(exhausted=True, used=BitSet(), avg_tokens=self.avg_tokens)
else:
return dict(exhausted=False, used=self.used, avg_tokens=self.avg_tokens)
def update_state(self, indice):
self.used.update(indice)
class MixedIndexedDataset(torch.utils.data.IterableDataset):
def __init__(
self,
cfg_path: str,
cfg_json_str,
tokenizer,
max_length,
weight_by_size: bool = True,
nthreads=5,
prefetch_slice=100,
parallel_loading=False,
vdc_sampling=False,
update_weights_frequency=1,
seed=42,
):
if bmt.rank() == 0:
port = int(os.environ["MASTER_PORT"]) + 2188
self.flask_ps_proc = mp.Process(target=flask_ps.run, kwargs={"host": "0.0.0.0", "port": port})
self.flask_ps_proc.start()
super(MixedIndexedDataset, self).__init__()
self.set_seed(seed + bmt.rank())
self.weight_by_size = weight_by_size
self.tokenizer = tokenizer
self.eos_token_id = self.tokenizer.eos_token_id
self.bos_token_id = self.tokenizer.bos_token_id
self.path2transform = dict()
self.task_dict = OrderedDict()
self.nthreads = nthreads
self.prefetch_slice = prefetch_slice
# useful for indexing
self.tasks = []
self.names = []
# ending of iteration
self.remain = 0
self.max_length = max_length
self.vdc_sampling = vdc_sampling
if self.vdc_sampling:
self._vdc_values = [van_der_corput(i) for i in range(100000)]
self.vdc_gen = van_der_corput_sampling_gen(self._vdc_values)
self.update_weights_frequency = update_weights_frequency
self.path2transform = dict()
cfgs = load_dataset_cfgs(cfg_path, cfg_json_str)
_sum_weight = sum([cfg["abs_weight"] for cfg in cfgs])
_weights = {cfg["task_name"]: cfg["abs_weight"] / _sum_weight for cfg in cfgs}
bmt.print_rank("Absolute Weight of DataSet {}".format(_weights))
if parallel_loading:
self.parallel_load(cfgs, max_workers=None)
else:
self.sequential_load(cfgs)
self.weights = None
self.update_weights()
def set_seed(self, seed):
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
def load_task(self, cfg):
logger.info(f"Loading {cfg['path']}")
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
task = SegmentedDataset(
cfg,
self.tokenizer,
self.max_length,
transform_func=transform_func,
nthreads=self.nthreads,
prefetch_slice=self.prefetch_slice,
do_compact=cfg.get("do_compact", False), # dataset level do_compact
)
return task
def sequential_load(self, cfgs):
self.cfgs = cfgs
for cfg in cfgs:
# python3.7 and later preserves insertion order to dictionary
logger.info(f"Loading {cfg['path']}")
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
task = SegmentedDataset(
cfg,
self.tokenizer,
self.max_length,
transform_func=transform_func,
nthreads=self.nthreads,
prefetch_slice=self.prefetch_slice,
do_compact=cfg.get("do_compact", False), # dataset level do_compact
)
self.task_dict[task.task_name] = task
self.tasks.append(task)
self.names.append(task.task_name)
self.remain += 1
self.weights = None
self.update_weights()
def load_state_dict(self, state_dict):
missing_keys = []
for name, task in self.task_dict.items():
if name in state_dict:
task.load_state_dict(state_dict[name])
else:
missing_keys.append(name)
self.update_weights()
return missing_keys
def save_state_dict(self, path):
state_dict = {}
for name, task in self.task_dict.items():
_state_dict = task.state_dict()
if isinstance(_state_dict["used"], BitSet):
bitset = _state_dict["used"]
_file_name = bitset.save(path)
_state_dict["used"] = _file_name
state_dict[name] = _state_dict
else:
state_dict[name] = task.state_dict()
torch.save(state_dict, path)
logger.info("Dataset state saved")
def update_states(self, task_ids, indice):
is_dict = isinstance(indice, dict)
uniq = torch.unique(task_ids)
for idx in uniq:
idx = idx.item()
indexes = indice[idx] if is_dict else indice[task_ids == idx].tolist()
self.tasks[idx].update_state(indexes)
def get_transform_func(self, module_name: str, transform_script_path):
if transform_script_path is None:
# allow null transform
return lambda data, num_incontext, rand: data
if "/" in module_name:
module_name = "fm9g_live.transforms.{}".format(module_name.split("/")[-1])
else:
module_name = "fm9g_live.transforms.{}".format(module_name)
if transform_script_path not in self.path2transform:
# loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
# spec = importlib.util.spec_from_loader(loader.name, loader)
spec = importlib.util.spec_from_file_location(module_name, transform_script_path)
if spec is None:
raise RuntimeError("Spec is none! {}".format(module_name))
mod = importlib.util.module_from_spec(spec)
self.path2transform[transform_script_path] = {
"module": mod,
"last_mtime": 0,
}
transform_script_info = self.path2transform[transform_script_path]
curr_mtime = float(os.path.getmtime(transform_script_path))
if curr_mtime > transform_script_info["last_mtime"]:
transform_script_info["last_mtime"] = curr_mtime
# load module
spec.loader.exec_module(transform_script_info["module"])
transform_func = getattr(transform_script_info["module"], "transform", None)
if transform_func is None:
raise NotImplementedError("Find no transform funcion in script '{}'".format(transform_script_path))
return transform_func
def update_weights(self):
task0 = self.tasks[0]
if task0.abs_weight is not None:  # this config specifies absolute mixture ratios
weights = []
for task in self.tasks:
if task.exhausted:
weights.append(0)
else:
if task.avg_tokens == -1:
weights.append(task.abs_weight / self.max_length)
else:
weights.append(task.abs_weight / task.avg_tokens)
weights = np.array(weights)
else:
weights = np.array([0 if task.exhausted else task.weight for task in self.tasks])
if self.weight_by_size:
sizes = np.array([task.size() for task in self.tasks], dtype=np.float32)
weights *= sizes
self.weights = weights / weights.sum()
def __iter__(self):
for task in self.tasks:
task.iterate()
return self
def __next__(self):
step = 1
while True:
if self.remain == 0:
print("Rank {}, All task exhaust !!!!".format(bmt.rank()))
raise StopIteration
if self.vdc_sampling:
idx = next(self.vdc_gen)(self.weights)
else:
idx = np.random.choice(len(self.weights), p=self.weights)
data = next(self.tasks[idx])
if step % self.update_weights_frequency == 0:
self.update_weights()
if data is None:
if self.tasks[idx].allow_repeat:
# _runtime_ave = self.tasks[idx].avg_tokens
print("Rank {}, dataset {} exhaust, repeat...".format(bmt.rank(), self.tasks[idx].dataset_name))
# self.tasks[idx] = SegmentedDataset(
# self.tasks[idx].cfg, self.tokenizer, self.max_length, transform_func=self.tasks[idx].transform_func, nthreads=self.nthreads, prefetch_slice=self.prefetch_slice
# )
# self.tasks[idx].avg_tokens_update(_runtime_ave)
self.tasks[idx].reset()
else:
print("Rank {}, dataset {} exhaust, not repeat.".format(bmt.rank(), self.tasks[idx].dataset_name))
self.tasks[idx].exhaust = True
self.remain -= 1
continue
step += 1
return dict(
task_id=idx,
input=data[0],
target=data[1],
index=data[2],
is_long=self.tasks[idx].cfg.get("is_long", False),
)
class UnpadBatchedMixedDataset(torch.utils.data.IterableDataset):
def __init__(self, mixed_dataset, batch_size, max_length, pose_prob=0.0, pose_scaling_factor=1.0, compact=False):
self.max_total_length = batch_size * max_length
self.batch_size = 1
# setting compact=True concatenates segments originating from the same input
# into a long sequence. the relative order of segments should be preserved
# in mixed_dataset, e.g.,
# - ok: task1_seg1, task2_seg1, task1_seg2, task1_seg3
# - not_ok: task1_seg1, task1_seg3, task2_seg1, task1_seg2
self.compact = compact
self.total_length = 0
self.task2seqs = defaultdict(list)
self.mixed_dataset = mixed_dataset
self._max_length = max_length
self._pose_prob = pose_prob
self._pose_scaling_factor = pose_scaling_factor
if self._pose_prob > 0.0:
self._scaled_max_length = int(self.max_total_length * self._pose_scaling_factor)
else:
self._scaled_max_length = max_length
def put(self, sample):
self.total_length += len(sample["target"])
task_id = sample["task_id"]
if self.compact and self.task2seqs[task_id]:
last = self.task2seqs[task_id][-1]
if last["target"][-1] != self.mixed_dataset.eos_token_id:
# concatenate sequential segments from the same sample for longer-context modeling
last["input"].extend(sample["input"])
last["target"].extend(sample["target"])
return
self.task2seqs[task_id].append(sample)
def _pose_preprocess(
self,
input_ids: NDArray[np.int32],
) -> NDArray[np.int32]:
"""[PoSE](https://arxiv.org/abs/2309.10400v2)
GitHub implementation: https://github.com/dwzhu-pku/PoSE/blob/master/src/train_pose.py#L156
"""
len_chunk = min(len(input_ids), self._max_length)
len_input = len(input_ids)
# Chunk input randomly to fit max_length if needed
lt1 = 0
rt1 = random.randint(0, (len_chunk + 1) // 2)  # first chunk contains at most half of the tokens
rt2 = random.randint(lt1 + len_chunk, len_input)  # second chunk can shift randomly when the input exceeds max_length
lt2 = rt2 - (len_chunk - (rt1 - lt1)) # assure all tokens are used
chunked_input_ids = np.concatenate([input_ids[lt1:rt1], input_ids[lt2:rt2]], axis=-1)
# Generate PoSE position ids
position_ids = np.arange(len(chunked_input_ids), dtype=np.int32)
len_position_ids = len(position_ids)
lt = 0
rt = random.randint(lt, self._scaled_max_length - len_position_ids)
position_ids[: rt1 - lt1] += lt
position_ids[rt1 - lt1 :] += rt
return position_ids
def pop(self):
indexes = defaultdict(list)
lengths = []
inputs = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
targets = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=IGNORE_TGT)
task_ids = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=-1)
position_ids = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
span_begin = 0
for samples in self.task2seqs.values():
while samples:
sample = samples.pop()
span_end = span_begin + len(sample["input"])
inputs[0, span_begin:span_end] = torch.tensor(sample["input"], dtype=torch.int32)
targets[0, span_begin:span_end] = torch.tensor(sample["target"], dtype=torch.int32)
task_ids[0, span_begin:span_end] = torch.tensor(sample["task_id"], dtype=torch.int32)
if not sample["is_long"] and self._pose_prob > 0.0 and random.uniform(0, 1) < self._pose_prob:
_span_position_ids = self._pose_preprocess(sample["input"])
else:
_span_position_ids = np.arange(len(sample["input"]), dtype=np.int32)
position_ids[0, span_begin:span_end] = torch.from_numpy(_span_position_ids)
# position_ids[0, span_begin:span_end] = torch.arange(len(sample["input"]), dtype=torch.int32)
lengths.append(len(sample["target"]))
indexes[int(sample["task_id"])].append(sample["index"])
self.total_length -= len(sample["target"])
span_begin = span_end
cu_seqlens = torch.cat(
[torch.tensor([0] + lengths).cumsum(dim=-1), torch.tensor([self.max_total_length], dtype=torch.int32)],
dim=0,
).int()
batch = {
"inputs": inputs,
"targets": targets,
"task_ids": task_ids,
"indexes": indexes,
# adhere to flash attention interface
"cu_seqlens": cu_seqlens,
"max_seqlen": int(torch.max(cu_seqlens[1:] - cu_seqlens[:-1])),
"lengths": torch.tensor(sum(lengths)).int(),
"task_names": self.mixed_dataset.names,
"position_ids": position_ids,
}
return batch
def will_be_full(self, sample):
return self.total_length + len(sample["target"]) > self.max_total_length
def __iter__(self):
for sample in self.mixed_dataset:
if self.will_be_full(sample):
yield self.pop()
self.put(sample)
class CudaPrefetcher(Iterable):
"""
Wraps a batch iterator to copy batches to the GPU asynchronously, hiding memcpy latency.
"""
def __init__(self, loader, tp_size=1, tp_rank=0):
self.loader = iter(loader)
self.tp_size = tp_size
self.tp_rank = tp_rank
self.stream = torch.cuda.Stream()
self.preload()
def preload(self):
try:
if self.tp_size > 1:
if self.tp_rank == 0:
data = next(self.loader)
print("Rank {}, Preload data done.".format(bmt.rank()))
d = {}
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "wb") as fb:
for key in data.keys():
if isinstance(data[key], torch.Tensor):
np_cur_data = data[key].cpu().numpy()
bs = np_cur_data.tobytes()
fb.write(bs)
d[key] = ["TORCH", str(np_cur_data.dtype), len(bs)] + list(np_cur_data.shape)
elif isinstance(data[key], np.ndarray):
bs = data[key].tobytes()
fb.write(bs)
d[key] = ["NUMPY", str(data[key].dtype), len(bs)] + list(data[key].shape)
else:
d[key] = data[key]
try:
_ = json.dumps(d)
except TypeError:
print(d)
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "w") as f:
json.dump(d, f)
bmt.synchronize()
if self.tp_rank != 0:
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "r") as f:
data = json.load(f)
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "rb") as fb:
bs = fb.read()
offset = 0
for key in data.keys():
if isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "NUMPY":
nw_offset = offset + data[key][2]
data[key] = np.frombuffer(bs[offset:nw_offset], dtype=data[key][1]).reshape(
data[key][3:]
)
offset = nw_offset
elif isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "TORCH":
nw_offset = offset + data[key][2]
data[key] = torch.from_numpy(
np.frombuffer(bs[offset:nw_offset], dtype=data[key][1])
.reshape(data[key][3:])
.copy()
)
offset = nw_offset
self.data = data
else:
self.data = next(self.loader)
except StopIteration:
self.data = None
return
with torch.cuda.stream(self.stream):
for key in self.data.keys():
if isinstance(self.data[key], torch.Tensor):
self.data[key] = self.data[key].cuda(non_blocking=True)
def __next__(self):
torch.cuda.current_stream().wait_stream(self.stream)
data = copy.deepcopy(self.data)
self.preload()
return data
def __iter__(self):
return self

View File

@ -1,79 +0,0 @@
import torch
import torch.nn.functional as F
import bmtrain as bmt
# the varlen flash-attn kernels used below; guarded because flash_attn may be unavailable
try:
    from flash_attn.flash_attn_interface import _flash_attn_varlen_backward
    from flash_attn.flash_attn_interface import _flash_attn_varlen_forward
except ImportError:
    _flash_attn_varlen_forward = None
    _flash_attn_varlen_backward = None
def _linear_backward(grad_output, x, weight, bias):
grad_x = grad_weight = grad_bias = None
if x.requires_grad:
grad_x = grad_output.matmul(weight)
if weight.requires_grad:
grad_weight = grad_output.reshape(-1,
grad_output.shape[-1]).t().matmul(x.reshape(-1, x.shape[-1]))
if bias is not None and bias.requires_grad:
grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)
return grad_x, grad_weight, grad_bias
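# Illustrative sanity check (not part of the original file): _linear_backward reproduces
# autograd's gradients for y = x @ W^T + b given the upstream gradient dL/dy.
def _demo_check_linear_backward():
    x = torch.randn(2, 3, 5, requires_grad=True)
    w = torch.randn(4, 5, requires_grad=True)
    b = torch.randn(4, requires_grad=True)
    y = F.linear(x, w, b)
    g = torch.randn_like(y)
    y.backward(g)
    gx, gw, gb = _linear_backward(g, x, w, b)
    assert torch.allclose(gx, x.grad) and torch.allclose(gw, w.grad) and torch.allclose(gb, b.grad)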
class OpAttnPipeSP(torch.autograd.Function):
@staticmethod
def forward(ctx, q_w, k_w, v_w, q_b, k_b, v_b, x, cache_kv, cache_kv_inp, cu_seqlens_q, cu_seqlens_k, max_seqlen):
q = F.linear(x, q_w, q_b)
k = F.linear(x, k_w, k_b)
v = F.linear(x, v_w, v_b)
if cache_kv.numel() != 0:
# prepend cached key/value states along the sequence dimension
k = torch.cat([cache_kv[0], k], dim=1)
v = torch.cat([cache_kv[1], v], dim=1)
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen, max_seqlen, 0, causal=True, window_size=(-1,-1), alibi_slopes=None, deterministic=False, return_attn_probs=False
)
# save projection inputs/weights together with the attention state needed by backward
ctx.save_for_backward(
x, q_w, k_w, v_w, q_b, k_b, v_b, q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
)
ctx.max_seqlen_q = max_seqlen
ctx.max_seqlen_k = max_seqlen
ctx.dropout_p = 0.0
ctx.softmax_scale = q.shape[-1] ** (-0.5)
return out
@staticmethod
def backward(ctx, grad_output):
(
x, q_w, k_w, v_w, q_b, k_b, v_b,
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state,
) = ctx.saved_tensors
dout = grad_output
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
_flash_attn_varlen_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_k,
ctx.max_seqlen_q,
ctx.max_seqlen_k,
ctx.dropout_p,
ctx.softmax_scale,
False,
(-1,-1),
None,
False,
rng_state=rng_state,
)
dq = dq[..., : dout.shape[-1]]  # the head dimension may have been padded
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
# back-propagate through the q/k/v projections; flatten per-head grads to the linear output layout
d_xq, d_wq, d_bq = _linear_backward(dq.reshape(dq.shape[0], -1), x, q_w, q_b)
d_xk, d_wk, d_bk = _linear_backward(dk.reshape(dk.shape[0], -1), x, k_w, k_b)
d_xv, d_wv, d_bv = _linear_backward(dv.reshape(dv.shape[0], -1), x, v_w, v_b)
d_x = None
if d_xq is not None:
d_x = d_xq + d_xk + d_xv
# one gradient per forward input: q_w, k_w, v_w, q_b, k_b, v_b, x, then None for the non-tensor args
return d_wq, d_wk, d_wv, d_bq, d_bk, d_bv, d_x, None, None, None, None, None

View File

@ -1,332 +0,0 @@
import math
from typing import Optional
from typing import Tuple
try:
from .flash_triton import FlashAttnFunc
except:
FlashAttnFunc = None
import bmtrain as bmt
import torch
from einops import rearrange
from .linear import ColumnParallelLinear
from .linear import Linear
from .position_embedding import apply_chatglm_rotary_pos_emb
from flash_attn.flash_attn_interface import flash_attn_varlen_func
#try:
# from flash_attn.flash_attn_interface import _flash_attn_varlen_backward
# from flash_attn.flash_attn_interface import _flash_attn_varlen_forward
# from flash_attn.flash_attn_interface import flash_attn_varlen_func
#except:
# flash_attn_varlen_func = None
try:
from flash_attn.bert_padding import pad_input
from flash_attn.bert_padding import unpad_input
except:
pad_input = None
unpad_input = None
class OpFlash(torch.autograd.Function):
@staticmethod
def forward(ctx, self, record, q, k, v, cu_seqlens, max_seqlen, dropout_p, causal):
ctx.self = self
ctx.cu_seqlens = cu_seqlens
ctx.max_length = max_seqlen
ctx.dropout_p = dropout_p
ctx.causal = causal
ctx.softmax_scale = q.shape[-1] ** (-0.5)
if not record and "out" in self._layer_dict:
out = self._layer_dict.pop("out")
softmax_lse = self._layer_dict.pop("softmax_lse")
rng_state = self._layer_dict.pop("rng_state")
else:
out, _, _, _, _, softmax_lse, _, rng_state = _flash_attn_varlen_forward(
q,
k,
v,
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen,
dropout_p,
ctx.softmax_scale,
causal=causal,
window_size=(-1, -1),
alibi_slopes=None,
return_softmax=False,
)
if record:
self._layer_dict["out"] = out
self._layer_dict["softmax_lse"] = softmax_lse
self._layer_dict["rng_state"] = rng_state
ctx.save_for_backward(q, k, v, out, softmax_lse, rng_state)
return out
@staticmethod
def backward(ctx, dout):
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
_flash_attn_varlen_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
ctx.cu_seqlens,
ctx.cu_seqlens,
ctx.max_length,
ctx.max_length,
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
(-1,-1),
None,
False,
rng_state=rng_state,
)
return None, None, dq, dk, dv, None, None, None, None
class Attention(bmt.DistributedModule):
def __init__(
self,
dim_model: int,
num_heads: int,
num_kv_heads: int,
dim_head: int,
dtype: torch.dtype = torch.half,
dropout_p: Optional[float] = None,
scale: bool = True,
add_qkv_bias: bool = False,
use_flash_attn: bool = False,
tp: int = 0,
) -> None:
super().__init__()
self.dim_model = dim_model
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_groups = num_heads // num_kv_heads
self.dim_head = dim_head
self.project_q = Linear(
self.dim_model,
self.num_heads * self.dim_head,
bias=add_qkv_bias,
dtype=dtype,
scale=scale,
tp=tp,
)
self.project_k = Linear(
self.dim_model,
self.num_kv_heads * self.dim_head,
bias=add_qkv_bias,
dtype=dtype,
scale=scale,
tp=tp,
)
self.project_v = Linear(
self.dim_model,
self.num_kv_heads * self.dim_head,
bias=add_qkv_bias,
dtype=dtype,
scale=scale,
tp=tp,
)
self.attention_out = Linear(
self.num_heads * self.dim_head,
self.dim_model,
dtype=dtype,
scale=scale,
tp=tp * 2,
)
self.softmax = torch.nn.Softmax(dim=-1)
if dropout_p is not None:
self.dropout = torch.nn.Dropout(p=dropout_p)
self.dropout_p = dropout_p
else:
self.dropout = None
self.dropout_p = 0.0
self.use_flash_attn = use_flash_attn
self._layer_dict = {}
def forward(
self,
hidden_q: torch.Tensor,
hidden_kv: torch.Tensor,
attention_mask: torch.BoolTensor,
position_bias: torch.Tensor,
use_cache: bool = False,
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
pos_bias_type: Optional[str] = "relative",
length_mask: Optional[torch.Tensor] = None,
attention_mask_bias: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: int = None,
position_ids: Optional[torch.Tensor] = None,
):
"""This model inherits from bmt.DistributedModule.
Args:
hidden_q (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): Indices of input sequence tokens. It will be embedded by model's internal embedding lookup matrix.
hidden_kv (:obj:`torch.Tensor` of shape ``(batch, len_k, dim_model)``): Length of input sequence before padding.
attention_mask (:obj:`torch.Tensor` of shape ``(batch, len_q, len_k)``): Used to avoid performing attention on padding token indices.
position_bias(:obj:`torch.Tensor` of shape ``(num_heads, len_q, len_k)`` or ``(1, num_heads, len_k, len_q)``): Provide positional information about tensor `key_value` and `query`.
Return:
out (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): The attention output.
""" # noqa: E501
len_q = hidden_q.size(1)
len_k = hidden_kv.size(1)
if isinstance(self.project_q, ColumnParallelLinear):
assert hidden_q.data_ptr() == hidden_kv.data_ptr()
if self.project_q.scale and self.project_q.scale_before:
hidden_q = hidden_q / math.sqrt(self.project_q.dim_in)
hidden_q = bmt.nn.OpParallelLinear.apply(
hidden_q,
torch.cat([self.project_q.weight, self.project_k.weight, self.project_v.weight], dim=0),
torch.cat([self.project_q.bias, self.project_k.bias, self.project_v.bias], dim=0)
if self.project_q.bias is not None
else None,
True,
False,
False,
None,
)
if self.project_q.scale and not self.project_q.scale_before:
hidden_q = hidden_q / math.sqrt(self.project_q.dim_in)
block_size = hidden_q.shape[-1] // (self.head_groups + 1 + 1)
h_q = hidden_q[..., : block_size * self.head_groups]
h_k = hidden_q[..., block_size * self.head_groups : block_size * (self.head_groups + 1)]
h_v = hidden_q[..., block_size * (self.head_groups + 1) :]
else:
h_q = self.project_q(hidden_q)
h_k = self.project_k(hidden_kv)
h_v = self.project_v(hidden_kv)
batch_size = h_q.size(0)
if not self.use_flash_attn:
h_q = h_q / math.sqrt(math.sqrt(self.dim_head))
h_k = h_k / math.sqrt(math.sqrt(self.dim_head))
h_q = h_q.view(batch_size, len_q, -1, self.dim_head).permute(0, 2, 1, 3)
h_k = h_k.view(batch_size, len_k, -1, self.dim_head).permute(0, 2, 1, 3)
h_v = h_v.view(batch_size, len_k, -1, self.dim_head).permute(0, 2, 1, 3)
if pos_bias_type == "rotary":
# b h s d
h_q, h_k = position_bias(h_q, h_k, -2, offset=past_kv[0].size(-2) if past_kv is not None else 0)
elif pos_bias_type == "chatglm_rotary":
h_q = apply_chatglm_rotary_pos_emb(h_q, position_bias)
h_k = apply_chatglm_rotary_pos_emb(h_k, position_bias)
if past_kv is not None:
h_k = torch.cat([past_kv[0], h_k], dim=-2)
h_v = torch.cat([past_kv[1], h_v], dim=-2)
len_k = h_k.size(-2)
# (b, n_h, len_q, d_h) @ (b, n_h, d_h, len_k) -> (b, n_h, len_q, len_k)
# (b, n_kv_h, n_h_groups*len_q, d_h) @ (b, n_kv_h, d_h, len_k) -> (b, n_kv_h, n_h_groups*len_q, len_k) -> (b, n_h, len_q, len_k)
if self.head_groups == 1:
score = torch.matmul(h_q, h_k.transpose(-1, -2)) # / math.sqrt(self.dim_head) moved to line 75~76
else:
score = torch.matmul(
h_q.reshape(batch_size, -1, self.head_groups * len_q, self.dim_head),
h_k.transpose(-1, -2),
).view(
batch_size, -1, len_q, len_k
) # / math.sqrt(self.dim_head) moved to line 75~76
if pos_bias_type == "relative":
if len_q == 1: # inference with cache
if len(position_bias.size()) == 4:
position_bias = position_bias[:, :, -1:, :]
else:
position_bias = position_bias[:, -1:, :]
score = score + position_bias
score = torch.masked_fill(
score,
attention_mask.view(batch_size, 1, len_q, len_k) == False,
torch.scalar_tensor(float("-inf"), device=score.device, dtype=score.dtype),
)
score = self.softmax(score)
score = torch.masked_fill(
score,
attention_mask.view(batch_size, 1, len_q, len_k) == False,
torch.scalar_tensor(0, device=score.device, dtype=score.dtype),
)
if self.dropout is not None:
score = self.dropout(score)
# (b, n_h, len_q, len_k) @ (b, n_h, len_k, d_h) -> (b, n_h, len_q, d_h)
# (b, n_kv_h, n_h_groups*len_q, len_k) @ (b, n_kv_h, len_k, d_h) -> (b, n_kv_h, n_h_groups*len_q, d_h) -> (b, n_h, len_q, d_h)
score = torch.matmul(score.view(batch_size, -1, self.head_groups * len_q, len_k), h_v).view(
batch_size, -1, len_q, self.dim_head
)
score = score.view(batch_size, -1, len_q, self.dim_head).permute(0, 2, 1, 3)
score = score.contiguous().view(batch_size, len_q, -1)
else:
if attention_mask_bias is not None:
assert pos_bias_type == "rotary"
h_q = h_q.view(batch_size, len_q, -1, self.dim_head) # .permute(0, 2, 1, 3)
h_k = h_k.view(batch_size, len_k, -1, self.dim_head) # .permute(0, 2, 1, 3)
h_v = h_v.view(batch_size, len_k, -1, self.dim_head) # .permute(0, 2, 1, 3)
h_q, h_k = position_bias(h_q, h_k, -3)
score = FlashAttnFunc.apply(h_q, h_k, h_v, attention_mask_bias, False, None)
else:
if pos_bias_type == "chatglm_rotary":
raise NotImplemented("No FlashAttn version for ChatGLM at present!")
h_q = h_q.view(batch_size * len_q, -1, self.dim_head) # .permute(0, 2, 1, 3)
h_k = h_k.view(batch_size * len_k, -1, self.dim_head) # .permute(0, 2, 1, 3)
h_v = h_v.view(batch_size * len_k, -1, self.dim_head) # .permute(0, 2, 1, 3)
h_q, h_k = position_bias(
h_q, h_k, -3, cu_seqlens=cu_seqlens, max_length=max_seqlen, position_ids=position_ids
)
score = flash_attn_varlen_func(
h_q,
h_k,
h_v,
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen,
self.dropout_p,
causal=True,
deterministic=True,
)
# Rongqiao change
#score = OpFlash.apply(
# self, not torch.is_grad_enabled(), h_q, h_k, h_v, cu_seqlens, max_seqlen, self.dropout_p, True
#)
score = score.view(batch_size, len_q, -1)
score = self.attention_out(score)
if use_cache:
return score, (h_k, h_v)
else:
return score

View File

@ -1,265 +0,0 @@
import inspect
import math
import bmtrain as bmt
import torch
import torch.nn.functional as F
def Linear(*args, **kwargs):
tp = kwargs.pop("tp", 0)
if tp == 0:
return NormalLinear(*args, **kwargs)
if tp == 1:
return ColumnParallelLinear(*args, **kwargs)
if tp == 2:
return RowParallelLinear(*args, **kwargs)
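# Illustrative note: the tp argument selects the tensor-parallel layout of the projection.
#   tp=0 -> NormalLinear         (weight replicated on every rank)
#   tp=1 -> ColumnParallelLinear (weight split along dim_out across the TP group)
#   tp=2 -> RowParallelLinear    (weight split along dim_in, outputs reduced across the group)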
class OpLastLinear(torch.autograd.Function):
@staticmethod
def forward(ctx, self, record, x, weight, bias=None):
ctx.self = self
if not record and "r" in self._layer_dict:
ctx.save_for_backward(x, weight, bias)
self._layer_dict.pop("r")
return torch.zeros((*x.shape[:-1], self.out_features), device=x.device, dtype=x.dtype)
else:
ctx.save_for_backward(x, weight, bias)
if record:
self._layer_dict["r"] = True
return F.linear(x, weight, bias)
@staticmethod
def backward(ctx, grad_output):
x, weight, bias = ctx.saved_tensors
grad_x = grad_weight = grad_bias = None
if x.requires_grad:
grad_x = grad_output.matmul(weight)
if weight.requires_grad:
grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(x.reshape(-1, x.shape[-1]))
if bias is not None and bias.requires_grad:
grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)
return None, None, grad_x, grad_weight, grad_bias
class LastLinear(bmt.DistributedModule):
def __init__(
self,
dim_in: int,
dim_out: int,
bias: bool = False,
dtype: torch.dtype = torch.half,
init_mean: float = 0.0,
init_std: float = 1,
scale: bool = True,
scale_before: bool = False,
tp: int = 0,
):
super().__init__()
self.dim_in = self.in_features = dim_in
self.dim_out = self.out_features = dim_out
self.scale = scale
self.scale_before = scale_before
if not scale:
init_std = 1 / ((dim_in + dim_out) ** 0.5)
self.weight = bmt.DistributedParameter(
torch.empty((dim_out, dim_in), dtype=dtype),
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
)
self.bias = (
bmt.DistributedParameter(
torch.empty(dim_out, dtype=dtype),
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
)
if bias
else None
)
self._layer_dict = {}
def forward(self, x: torch.Tensor):
"""
Args:
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer
Returns:
:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y.
""" # noqa: E501
if self.scale and self.scale_before:
x = x / math.sqrt(self.dim_in)
x = OpLastLinear.apply(self, not torch.is_grad_enabled(), x, self.weight, self.bias)
if self.scale and not self.scale_before:
x = x / math.sqrt(self.dim_in)
return x
class NormalLinear(bmt.DistributedModule):
def __init__(
self,
dim_in: int,
dim_out: int,
bias: bool = False,
dtype: torch.dtype = torch.half,
init_mean: float = 0.0,
init_std: float = 1,
scale: bool = True,
scale_before: bool = False,
):
super().__init__()
self.dim_in = self.in_features = dim_in
self.dim_out = self.out_features = dim_out
self.scale = scale
self.scale_before = scale_before
if not scale:
init_std = 1 / ((dim_in + dim_out) ** 0.5)
self.weight = bmt.DistributedParameter(
torch.empty((dim_out, dim_in), dtype=dtype),
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
)
self.bias = (
bmt.DistributedParameter(
torch.empty(dim_out, dtype=dtype),
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
)
if bias
else None
)
def forward(self, x: torch.Tensor):
"""
Args:
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer
Returns:
:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y.
""" # noqa: E501
if self.scale and self.scale_before:
x = x / math.sqrt(self.dim_in)
if "tp_size" in inspect.signature(bmt.init_distributed).parameters:
x = bmt.nn.OpLinear.apply(x, self.weight, self.bias)
else:
x = F.linear(x, self.weight, self.bias)
if self.scale and not self.scale_before:
x = x / math.sqrt(self.dim_in)
return x
class ColumnParallelLinear(bmt.DistributedModule):
def __init__(
self,
dim_in: int,
dim_out: int,
bias: bool = False,
dtype: torch.dtype = torch.half,
init_mean: float = 0.0,
init_std: float = 1,
scale: bool = True,
scale_before: bool = False,
gather_output=False,
gather_input=True,
):
super().__init__()
assert dim_out % bmt.config["tp_size"] == 0
if not scale:
init_std = 1 / ((dim_in + dim_out) ** 0.5)
dim_out = dim_out // bmt.config["tp_size"]
self.dim_in = self.in_features = dim_in
self.dim_out = self.out_features = dim_out
self.scale = scale
self.scale_before = scale_before
self.gather_input = gather_input
self.gather_output = gather_output
self.weight = bmt.DistributedParameter(
torch.empty((dim_out, dim_in), dtype=dtype),
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
tp_split_dim=0,
tp_mode=True,
)
self.bias = (
bmt.DistributedParameter(
torch.empty(dim_out, dtype=dtype),
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
tp_split_dim=0,
tp_mode=True,
)
if bias
else None
)
def forward(self, x: torch.Tensor):
"""
Args:
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer
Returns:
:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y.
""" # noqa: E501
if self.scale and self.scale_before:
x = x / math.sqrt(self.dim_in)
x = bmt.nn.OpParallelLinear.apply(x, self.weight, self.bias, self.gather_input, self.gather_output, False, None)
if self.scale and not self.scale_before:
x = x / math.sqrt(self.dim_in)
return x
class RowParallelLinear(bmt.DistributedModule):
def __init__(
self,
dim_in: int,
dim_out: int,
bias: bool = False,
dtype: torch.dtype = torch.half,
init_mean: float = 0.0,
init_std: float = 1,
scale: bool = True,
scale_before: bool = False,
split_input=False,
all_reduce_output=False,
):
super().__init__()
assert dim_in % bmt.config["tp_size"] == 0
if not scale:
init_std = 1 / ((dim_in + dim_out) ** 0.5)
dim_in = dim_in // bmt.config["tp_size"]
self.dim_in = self.in_features = dim_in
self.dim_out = self.out_features = dim_out
self.scale = scale
self.scale_before = scale_before
self.split_input = split_input
self.all_reduce_output = all_reduce_output
self.weight = bmt.DistributedParameter(
torch.empty((dim_out, dim_in), dtype=dtype),
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
tp_split_dim=1,
tp_mode=True,
)
self.bias = (
bmt.DistributedParameter(
torch.empty(dim_out, dtype=dtype),
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
tp_split_dim=-1,
tp_mode=True,
)
if bias
else None
)
def forward(self, x: torch.Tensor):
"""
Args:
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer
Returns:
:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y.
""" # noqa: E501
if self.scale and self.scale_before:
x = x / math.sqrt(self.dim_in)
x = bmt.nn.OpParallelLinear.apply(
x, self.weight, None, self.split_input, False, self.split_input, 1 if self.all_reduce_output else 2
)
if self.bias is not None:
x = x + self.bias
if self.scale and not self.scale_before:
x = x / math.sqrt(self.dim_in)
return x
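# A brief illustrative note (not from the original file): ColumnParallelLinear and
# RowParallelLinear are normally paired Megatron-style, with the up-projection sharded
# over output columns and the down-projection sharded over input rows, so the block only
# needs one reduction of its final output. The flag choices below are assumptions made
# for illustration, not this repository's configuration:
#
#   w_in  = ColumnParallelLinear(dim_model, dim_ff, gather_output=False)   # activations stay sharded
#   w_out = RowParallelLinear(dim_ff, dim_model, all_reduce_output=True)   # sum the partial outputs
#   y = w_out(torch.relu(w_in(x)))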

View File

@ -1,22 +0,0 @@
from typing import Optional
class EMAValue(object):
def __init__(self, init_value: Optional[float] = None, decay_factor: float = 0.999) -> None:
super().__init__()
self._value = init_value
self._decay_factor = decay_factor
@property
def value(self) -> Optional[float]:
return self._value
def update(self, value: float) -> None:
if self._value is None:
self._value = value
else:
self._value = self._decay_factor * self._value + (1 - self._decay_factor) * value
def update_with_return(self, value: float) -> Optional[float]:
self.update(value)
return self._value
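# A minimal usage sketch (illustrative, not part of the original module): smoothing a
# per-step loss, using a decay factor of 0.9 so the effect is visible after a few updates.
if __name__ == "__main__":
    ema_loss = EMAValue(decay_factor=0.9)
    for step_loss in [2.0, 1.8, 1.9, 1.7]:
        ema_loss.update(step_loss)
    print(ema_loss.value)  # exponentially weighted average, still dominated by the first value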

View File

@ -1 +0,0 @@
from .fm9g import FM9GTokenizer

View File

@ -1,247 +0,0 @@
import bmtrain as bmt
import math
class Cosine(bmt.lr_scheduler.WarmupLRScheduler):
r"""
    After a warmup period during which the learning rate increases linearly from 0 to start_lr,
    the decay period follows :math:`\text{lr}=\max\left(0.1\,\text{start_lr},\ \text{start_lr}\cdot\left(0.1+0.45\left(1+\cos\left(\pi\cdot\tfrac{\text{num_iter}-\text{warmup_iter}}{\text{end_iter}-\text{warmup_iter}}\right)\right)\right)\right)`
"""
def __init__(
self, optimizer, start_lr, warmup_iter, end_iter, num_iter=0, lr_end_restart=0, resume_no_optimze=0
) -> None:
self.warmup_iter = warmup_iter
self.end_iter = end_iter
self.optimizer = optimizer
self.num_iter = num_iter
self._current_lr = None
self._start_lr = start_lr
self.start_lr = []
self.lr_end_restart = lr_end_restart
self.resume_step = num_iter
self.resume_no_optimze = resume_no_optimze
for group in self.optimizer.param_groups:
self.start_lr.append(group["lr"])
self.step(self.num_iter)
def get_lr_warmup(self, num_iter, base_lr) -> float:
return base_lr * num_iter / self.warmup_iter
def get_lr_decay(self, num_iter, base_lr) -> float:
progress = (num_iter - self.warmup_iter) / max(1, (self.end_iter - self.warmup_iter))
if progress > 1:
if self.lr_end_restart == 0:
progress = 1
elif self.lr_end_restart == 1:
progress = progress
elif self.lr_end_restart == 2:
progress = int(progress) * 2 + (progress - 1)
return max(base_lr * 0.1, base_lr * (0.1 + 0.45 * (1.0 + math.cos(progress * math.pi))))
def get_lr(self, base_lr):
assert self.num_iter >= 0
if self.resume_step + self.resume_no_optimze > self.num_iter:
bmt.print_rank("resume no optimize")
return 0
if self.num_iter < self.warmup_iter:
return self.get_lr_warmup(self.num_iter, base_lr)
else:
return self.get_lr_decay(self.num_iter, base_lr)
@property
def current_lr(self):
return self._current_lr
def step(self, num_iter=None) -> None:
if num_iter is None:
num_iter = self.num_iter + 1
self.num_iter = num_iter
self._current_lr = self.get_lr(self._start_lr)
for group, base_lr in zip(self.optimizer.param_groups, self.start_lr):
group["lr"] = self.get_lr(base_lr)
def state_dict(self):
return {
"_start_lr": self._start_lr,
"start_lr": self.start_lr,
"warmup_iter": self.warmup_iter,
"end_iter": self.end_iter,
"num_iter": self.num_iter,
}
def load_state_dict(self, state_dict):
self._start_lr = state_dict["_start_lr"]
self.start_lr = state_dict["start_lr"]
self.warmup_iter = state_dict["warmup_iter"]
self.end_iter = state_dict["end_iter"]
self.num_iter = state_dict["num_iter"]
self.step(self.num_iter)
class WarmupStableDrop(bmt.lr_scheduler.WarmupLRScheduler):
r"""
    After a warmup period during which the learning rate increases linearly from 0 to start_lr,
    the learning rate is held constant at start_lr and, over the final ``drop_iter`` iterations, drops linearly from start_lr down to :math:`0.1\,\text{start_lr}`.
"""
def __init__(
self, optimizer, start_lr, warmup_iter, end_iter, drop_iter=0, num_iter=0, resume_no_optimze=0
) -> None:
self.warmup_iter = warmup_iter
self.end_iter = end_iter
self.drop_iter = drop_iter
self.optimizer = optimizer
self.num_iter = num_iter
self._current_lr = None
self._start_lr = start_lr
self.start_lr = []
self.resume_step = num_iter
self.resume_no_optimze = resume_no_optimze
for group in self.optimizer.param_groups:
self.start_lr.append(group["lr"])
self.step(self.num_iter)
def get_lr_warmup(self, num_iter, base_lr, warmup_iter) -> float:
return base_lr * num_iter / warmup_iter
def get_lr_stable(self, num_iter, base_lr):
return base_lr
def get_lr_drop(self, num_iter, base_lr):
progress = (self.end_iter - num_iter) / self.drop_iter
return base_lr * (0.1 + max(0.9 * (self.end_iter - num_iter) / self.drop_iter, 0))
def get_lr(self, base_lr):
assert self.num_iter >= 0
if self.resume_step + self.resume_no_optimze > self.num_iter:
return self.get_lr_warmup(self.num_iter - self.resume_step, base_lr, self.resume_no_optimze)
if self.num_iter < self.warmup_iter:
return self.get_lr_warmup(self.num_iter, base_lr, self.warmup_iter)
if self.num_iter > self.end_iter - self.drop_iter:
return self.get_lr_drop(self.num_iter, base_lr)
return self.get_lr_stable(self.num_iter, base_lr)
@property
def current_lr(self):
return self._current_lr
def step(self, num_iter=None) -> None:
if num_iter is None:
num_iter = self.num_iter + 1
self.num_iter = num_iter
self._current_lr = self.get_lr(self._start_lr)
for group, base_lr in zip(self.optimizer.param_groups, self.start_lr):
group["lr"] = self.get_lr(base_lr)
def state_dict(self):
return {
"_start_lr": self._start_lr,
"start_lr": self.start_lr,
"warmup_iter": self.warmup_iter,
"end_iter": self.end_iter,
"num_iter": self.num_iter,
}
def load_state_dict(self, state_dict):
self._start_lr = state_dict["_start_lr"]
self.start_lr = state_dict["start_lr"]
self.warmup_iter = state_dict["warmup_iter"]
self.end_iter = state_dict["end_iter"]
self.num_iter = state_dict["num_iter"]
self.step(self.num_iter)
class WarmupStableExp(bmt.lr_scheduler.WarmupLRScheduler):
r"""
    After a warmup period during which the learning rate increases linearly from 0 to start_lr,
    the learning rate is held constant at start_lr; after ``drop_begin`` iterations it decays exponentially, being multiplied by ``drop_rate`` every ``drop_iter`` iterations.
"""
def __init__(
self, optimizer, start_lr, warmup_iter, drop_begin=-1, drop_rate=0.5, drop_iter=0, num_iter=0, resume_no_optimze=0
) -> None:
self.warmup_iter = warmup_iter
self.drop_iter = drop_iter
self.optimizer = optimizer
self.num_iter = num_iter
self._current_lr = None
self._start_lr = start_lr
self.start_lr = []
self.resume_step = num_iter
self.resume_no_optimze = resume_no_optimze
self.drop_begin = drop_begin
self.drop_iter = drop_iter # here drop_iter is half-life
self.drop_rate = drop_rate
for group in self.optimizer.param_groups:
self.start_lr.append(group["lr"])
self.step(self.num_iter)
def get_lr_warmup(self, num_iter, base_lr, warmup_iter) -> float:
return base_lr * num_iter / warmup_iter
def get_lr_stable(self, num_iter, base_lr):
return base_lr
def get_lr_drop(self, num_iter, base_lr):
        if self.drop_iter <= 0:
return base_lr
progress = (num_iter - self.drop_begin) / self.drop_iter
return base_lr * (self.drop_rate ** progress)
def get_lr(self, base_lr):
assert self.num_iter >= 0
if self.resume_step + self.resume_no_optimze > self.num_iter:
return self.get_lr_warmup(self.num_iter - self.resume_step, base_lr, self.resume_no_optimze)
if self.num_iter < self.warmup_iter:
return self.get_lr_warmup(self.num_iter, base_lr, self.warmup_iter)
if self.num_iter > self.drop_begin:
return self.get_lr_drop(self.num_iter, base_lr)
return self.get_lr_stable(self.num_iter, base_lr)
@property
def current_lr(self):
return self._current_lr
def step(self, num_iter=None) -> None:
if num_iter is None:
num_iter = self.num_iter + 1
self.num_iter = num_iter
self._current_lr = self.get_lr(self._start_lr)
for group, base_lr in zip(self.optimizer.param_groups, self.start_lr):
group["lr"] = self.get_lr(base_lr)
def state_dict(self):
return {
"_start_lr": self._start_lr,
"start_lr": self.start_lr,
"warmup_iter": self.warmup_iter,
"drop_begin": self.drop_begin,
"num_iter": self.num_iter,
}
def load_state_dict(self, state_dict):
self._start_lr = state_dict["_start_lr"]
self.start_lr = state_dict["start_lr"]
self.warmup_iter = state_dict["warmup_iter"]
self.drop_begin = state_dict["drop_begin"]
self.num_iter = state_dict["num_iter"]
self.step(self.num_iter)
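# A rough usage sketch (illustrative; `optimizer` stands for a hypothetical bmtrain optimizer
# that already holds the model parameters):
#
#   lr_scheduler = Cosine(optimizer, start_lr=1e-4, warmup_iter=2000, end_iter=100000)
#   for iteration in range(1, 100001):
#       ...                      # forward / backward / optimizer step
#       lr_scheduler.step()      # advances num_iter and writes the new lr into optimizer.param_groups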

View File

@ -1,5 +0,0 @@
from .log import logger
from .log import LogManager
from .object import allgather_objects
from .config import Config
from .gradient_shrink import gradient_shrink

View File

@ -1,159 +0,0 @@
import os
import random
import bmtrain as bmt
import numpy as np
class BitSet:
def __init__(self, size=1024**2):
self.size = size
self.bitset = np.zeros(self.size, dtype=bool)
def _ensure_capacity(self, num):
"""确保bitset有足够的容量来存储指定的数字"""
if num >= self.size:
            # grow the bitset
new_size = max(num + 1, self.size * 2)
new_bitset = np.zeros(new_size, dtype=bool)
new_bitset[: self.size] = self.bitset
self.bitset = new_bitset
self.size = new_size
bmt.print_rank("enlarge size to {}".format(self.size))
def add(self, num):
"""向bitset中添加一个数字"""
self._ensure_capacity(num)
self.bitset[num] = True
def remove(self, num):
"""从bitset中移除一个数字"""
if num < self.size:
self.bitset[num] = False
def contains(self, num):
"""检查bitset是否包含某个数字"""
return num < self.size and self.bitset[num]
def __contains__(self, num):
return self.contains(num)
def update(self, iterable_or_bitset):
"""使用可迭代对象或另一个BitSet中的元素更新当前bitset"""
if isinstance(iterable_or_bitset, BitSet):
            # if the argument is a BitSet, update with a vectorized numpy OR
self._ensure_capacity(iterable_or_bitset.size)
self.bitset[: iterable_or_bitset.size] |= iterable_or_bitset.bitset
else:
            # otherwise iterate over the argument and add each element
for num in iterable_or_bitset:
self.add(num)
def __sub__(self, other):
"""实现减法运算符使用numpy向量化操作来高效地创建一个新的bitset"""
# 创建一个新的bitset实例
result = BitSet(max(self.size, other.size))
        # vectorized numpy logic operation
result.bitset[: self.size] = self.bitset & ~other.bitset[: self.size]
return result
def __isub__(self, other):
"""实现就地减法运算符利用numpy向量化操作进行高效的元素移除"""
# 首先确保other的大小不超过当前bitset的大小
min_size = min(self.size, other.size)
        # remove the elements with a vectorized numpy operation
self.bitset[:min_size] &= ~other.bitset[:min_size]
return self
def __str__(self):
"""返回bitset的字符串表示列出所有为真的位的索引"""
# 找出所有为真的位的索引
true_indices = np.where(self.bitset)[0]
        # join the indices into a comma-separated string
indices_str = ", ".join(map(str, true_indices))
return f"BitSet({indices_str})"
def __len__(self):
"""返回bitset中为真的元素个数"""
return self.bitset.sum()
def capacity(self):
return self.size
def density(self):
return len(self) / self.size
def memory_usage(self):
"""返回bitset所占用的内存大小以KB、MB或GB为单位"""
bytes_usage = self.bitset.nbytes
if bytes_usage < 1024:
return f"{bytes_usage} B"
elif bytes_usage < 1024**2:
return f"{bytes_usage / 1024:.2f} KB"
elif bytes_usage < 1024**3:
return f"{bytes_usage / 1024**2:.2f} MB"
else:
return f"{bytes_usage / 1024**3:.2f} GB"
def to_list(self):
"""返回一个包含所有为真位索引的列表"""
return list(np.where(self.bitset)[0])
def save(self, filename):
"""将bitset保存到文件"""
def random_hash():
"""返回一个随机哈希值"""
return random.randint(0, 2**64 - 1)
filename_with_suffix = filename + ".{}.npy".format(random_hash())
dirname = os.path.dirname(filename_with_suffix)
os.makedirs(dirname, exist_ok=True)
np.save(filename_with_suffix, self.bitset)
        return os.path.basename(filename_with_suffix)  # return only the basename (no directory prefix) to support the transfer project
@classmethod
def load(cls, filename_with_suffix):
"""从文件加载bitset并创建一个新的BitSet实例"""
bitset_array = np.load(filename_with_suffix)
bitset = cls(bitset_array.size)
bitset.bitset = bitset_array
return bitset
def bitset_diff(normal_set, bitset):
"""返回存在于normal_set中但不在bitset中的元素集合"""
ret = {elem for elem in normal_set if not bitset.contains(elem)}
return ret
if __name__ == "__main__":
    # example usage
bitset1 = BitSet(1024)
bitset1.update([100, 200, 300, 1023])
bitset2 = BitSet(1024)
bitset2.update([100, 400, 1023])
result_bitset = bitset1 - bitset2
    print(100 in result_bitset)  # expected: False
    print(200 in result_bitset)  # expected: True
    print(300 in result_bitset)  # expected: True
    print(1023 in result_bitset)  # expected: False
bitset1 -= bitset2
print(result_bitset) # BitSet(200, 300)
print(bitset1) # BitSet(200, 300)
print(bitset2) # BitSet(100, 400, 1023)
bitsetlarge = BitSet(1024**3)
print(len(bitsetlarge), bitsetlarge.capacity(), bitsetlarge.density(), bitset1.density())
print("BitSet memory usage:", bitsetlarge.memory_usage())
print(bitset_diff({100, 200}, bitset2))
bitset1.update(bitset2)
bitsetlarge.add(52260134)
bitset2.update(bitsetlarge)
print(bitset1) # BitSet(100, 200, 300, 400, 1023)
print(bitset2) # BitSet(100, 400, 1023, 52260134)

View File

@ -1,531 +0,0 @@
# COPIED from exporter, for scaling project special use.
# Author: shengdinghu@gmail.com
import functools
import gc
import glob
import json
import multiprocessing as mp
import os
import re
from copy import deepcopy
from collections import defaultdict
from itertools import chain
import shutil
import threading
import time
import hashlib
from concurrent.futures import ThreadPoolExecutor
from typing import List
import bmtrain as bmt
from bmtrain.distributed import all_reduce, all_gather
import torch
from fm9g.utils import allgather_objects
from fm9g.utils.bitset import BitSet
from .log import logger
lock = threading.Lock()
def _save_artifacts(
model_state, dataloader, tokenizer, opt_state, global_step, args, model_config, log_ckpt=None, final_save=False
):
"""
Export model artifacts. Mainly for the purpose of asynchronous saving.
"""
try:
# platform_cfg = get_platform_cfg()
raise ValueError("No platform_cfg")
except ValueError:
platform_cfg = None
export_model_dir = os.path.join(args.save, str(global_step))
else:
export_model_dir = (
args.save if final_save else platform_cfg.gen_export_ckpt_dir_for_step(args.save, global_step)
)
os.makedirs(export_model_dir, exist_ok=True)
base_file_name = f"{args.save_name}-{global_step}" if global_step > -1 else args.save_name
bmt.print_rank(f"start to export ckpt, save_dir={export_model_dir}, file prefix={base_file_name}")
# model checkpoint
ckpt_path = os.path.join(export_model_dir, base_file_name + ".pt")
    # the .opt files are only used to resume training inside the project; they are not exported as part of a model release
opt_path = os.path.join(
export_model_dir,
args.save_name + ("-%d.rank-%d.opt" % (global_step, bmt.rank())),
)
if bmt.rank() == 0:
torch.save(model_state, ckpt_path)
bmt.print_rank(f"Save checkpoint successfully, ckpt file path: {ckpt_path}")
torch.save(opt_state, opt_path)
print(f"Save optimizer state successfully, opt file path: {opt_path}")
del model_state
del opt_state
    # save training statistics
if log_ckpt is not None:
bmt.print_rank("save log ckpt ...")
with open(os.path.join(export_model_dir, base_file_name + ".log_ckpt"), "w") as log_ckpt_file:
json.dump(log_ckpt, log_ckpt_file)
logger.info(f"Starting saving dataset state. ")
dataset_ckpt_path = os.path.join(export_model_dir, "dataset_ckpt")
os.makedirs(dataset_ckpt_path, exist_ok=True)
if bmt.config["tp_rank"] == 0:
p_dataset = os.path.join(dataset_ckpt_path, f"dataset_{bmt.rank()}.data")
dataloader.save_state_dict(p_dataset)
if bmt.rank() == 0:
        # the config and vocab files are stored alongside the model weights
model_config.save_pretrained(export_model_dir)
try:
tokenizer.save_pretrained(export_model_dir)
except:
bmt.print_rank("No save pretrained method for tokenizer")
shutil.copy(args.tokenizer_path, export_model_dir)
        # called after all files have been written
if platform_cfg is not None:
platform_cfg.finalize_model_save(export_model_dir, base_file_name)
else:
bmt.print_rank("No platform_cfg, skip finalize_model_save, may be not have .success file")
logger.info(f"Successfully save model files: {os.listdir(export_model_dir)}")
    # every rank writes a .save_done file under export_model_dir so we can tell when all ranks have finished saving
    # plain file I/O is sufficient here
os.makedirs(os.path.join(export_model_dir, "save_status"), exist_ok=True)
with open(os.path.join(export_model_dir, f"save_status/{bmt.rank()}.save_done"), "w") as f:
f.write("done")
    # wait until every rank has finished saving; bmt.synchronize() cannot be used here
if bmt.rank() == 0:
while True:
if len(os.listdir(os.path.join(export_model_dir, "save_status"))) == bmt.world_size():
break
time.sleep(1)
bmt.print_rank(f"All saved! Rank 0 Begin to merge dataset ckpt to {dataset_ckpt_path}/dataset_.data")
merge_dataset_ckpts(export_model_dir, args.parallel_load_datastate//2)
else:
bmt.print_rank(f"rank-{bmt.rank()} done, wait for rank0 to merge dataset ckpt")
def export(
model: torch.nn.Module,
dataloader,
tokenizer,
optimizer: bmt.optim.AdamOffloadOptimizer,
global_step,
args,
log_ckpt=None,
final_save=False,
async_save=False,
):
"""
    Save one checkpoint. Directory layout:
    /{args.save}/
        job_{job_id}_ckpt_{global_step}/  # when the checkpoint is exported as a model version, all files under job_{job_id}_ckpt_{global_step}/ are exported together, creating one model-group version
config.json
vocabs.txt
{args.save_name}-{global_step}.rank-0.opt
{args.save_name}-{global_step}.rank-n.opt
{args.save_name}-{global_step}.pt
{args.save_name}-{global_step}.data
{args.save_name}-{global_step}.data.json
{args.save_name}-{global_step}.success
{args.save_name}-{global_step}.log_ckpt
"""
bmt.synchronize()
model_state = bmt.store._save_to_rank0(model)
opt_state = deepcopy(optimizer.state_dict())
model_config = model.config
if async_save:
# Save artifacts asynchronously
save_proc = mp.Process(
target=_save_artifacts,
args=(model_state, dataloader, tokenizer, opt_state, global_step, args, model_config, log_ckpt, final_save),
)
save_proc.start()
else:
_save_artifacts(
model_state, dataloader, tokenizer, opt_state, global_step, args, model_config, log_ckpt, final_save
)
def load_model_ckpt(args, model):
"""args.load 是一个到/{args.save}/job_{job_id}_ckpt_{global_step}/ 的路径"""
if args.load.endswith(".pt"):
checkpoint_file = args.load
else:
# a directory
load_path = args.load
checkpoint_files = [file for file in os.listdir(load_path) if file.endswith(".pt")]
assert len(checkpoint_files) == 1, "None or multiple .pt found in {}".format(load_path)
checkpoint_file = os.path.join(load_path, checkpoint_files[0])
bmt.print_rank("args.load is not None, start to load checkpoints from" + checkpoint_file)
bmt.load(model, checkpoint_file)
return model
def _legacy_load_optimizer_ckpt(args, optimizer):
bmt.print_rank("Use legacy optimizer ckpt!")
if args.load.endswith(".pt"):
optimizer_path = os.path.dirname(os.path.dirname(args.load))
else:
optimizer_path = os.path.dirname(args.load)
bmt.print_rank(os.listdir(optimizer_path))
start = time.time()
bmt.print_rank(
"{}".format(
sum(
[
1
if i.find(".opt") != -1 and i.find("-{}.rank".format(args.start_step % (args.save_iters * 5))) != -1
else 0
for i in os.listdir(optimizer_path)
]
)
)
)
if (
sum(
[
1
if i.find(".opt") != -1 and i.find("-{}.rank".format(args.start_step % (args.save_iters * 5))) != -1
else 0
for i in os.listdir(optimizer_path)
]
)
== bmt.world_size()
):
pattern = "-{}.rank-{}.opt".format(args.start_step % (args.save_iters * 5), bmt.rank())
bmt.print_rank("Will load opt that matches pattern: {}".format(pattern))
for file_name in os.listdir(optimizer_path):
if file_name.find(pattern) != -1:
bmt.print_rank("start to load grad ckpt {}".format(file_name))
states = torch.load(os.path.join(optimizer_path, file_name))
optimizer.load_state_dict(states)
logger.info("load grad in {:.2f}s".format(time.time() - start))
return optimizer
def load_optimizer_ckpt(args, optimizer):
if args.load.endswith(".pt"):
optimizer_path = os.path.dirname(args.load)
else:
# a directory
optimizer_path = args.load
start = time.time()
opt_num = sum(
[1 if re.search(r"-{}.rank-\d+.opt".format(args.start_step), i) else 0 for i in os.listdir(optimizer_path)]
)
bmt.print_rank(f"Opt file num: {opt_num}")
if opt_num == 0:
return _legacy_load_optimizer_ckpt(args, optimizer)
if opt_num == bmt.world_size():
file_name = os.path.join(
optimizer_path,
args.save_name + "-{}.rank-{}.opt".format(args.start_step, bmt.rank()),
)
if os.path.exists(file_name):
print("rank {} start to load grad ckpt {}".format(bmt.rank(), file_name))
states = torch.load(file_name)
optimizer.load_state_dict(states)
# optimizer_external_load_state_dict(optimizer, states, args.grad_ckpt_num)
logger.info("load grad in {:.2f}s".format(time.time() - start))
return optimizer
def load_dataloader_ckpt(args, mixed_dataset):
load_success = _load_distributed_dataset_state(args, mixed_dataset)
if not load_success:
logger.info("load from distributed data state dict fail, try to load from single data state_dict")
_load_dataloader_ckpt(args, mixed_dataset)
def _load_dataloader_ckpt(args, mixed_dataset):
"""args.load 是一个到/{args.save}/job_{job_id}_ckpt_{global_step}/ 的路径"""
if args.load.endswith(".pt"):
load_path = os.path.dirname(args.load)
else:
load_path = args.load
dataset_states_path = [file for file in os.listdir(load_path) if file.endswith(".data")]
assert len(dataset_states_path) == 1, "None or multiple .data found in {}, file list: {}".format(
load_path, dataset_states_path
)
dataset_states_path = dataset_states_path[0]
dataset_states_path = os.path.join(load_path, dataset_states_path)
bmt.print_rank("args.load is not None, start to load data ckpt from " + dataset_states_path)
dataset_states = torch.load(dataset_states_path)
missing = mixed_dataset.load_state_dict(dataset_states)
if len(missing) > 0:
bmt.print_rank("missing keys when loading dataset states: ", missing)
def load_trace_ckpt(args, dataloader):
"""args.load 是一个到/{args.save}/job_{job_id}_ckpt_{global_step}/ 的路径"""
return dataloader
def load_log_ckpt(args):
if args.load.endswith(".pt"):
load_path = os.path.dirname(args.load)
else:
load_path = args.load
log_ckpt_paths = [file for file in os.listdir(load_path) if file.endswith(".log_ckpt")]
assert len(log_ckpt_paths) <= 1, "Multiple .data found in {}".format(load_path)
if len(log_ckpt_paths) == 0:
bmt.print_rank("No log ckpt is found in {}".format(load_path))
return {}
log_ckpt_path = os.path.join(load_path, log_ckpt_paths[0])
bmt.print_rank("args.load is not None, start to load log ckpt from " + log_ckpt_path)
with open(log_ckpt_path, "r") as log_ckpt_file:
log_ckpt = json.load(log_ckpt_file)
return log_ckpt
def _load_distributed_dataset_state(args, mixed_dataset):
rank = bmt.rank()
logger.info(f"rank-{rank} -> [start]loading dataset states")
if args.load.endswith(".pt"):
load_dir = os.path.dirname(args.load)
else:
load_dir = args.load
p_datasets = sorted(glob.glob(os.path.join(load_dir, "dataset_ckpt/dataset_*.data")))
    if len(p_datasets) == 0:  # backward compatibility
        bmt.print_rank("load_from_original_dataset_ckpt_folder")
p_datasets = sorted(glob.glob(os.path.join(load_dir, "dataset_*.data")))
all_state_dict = dict()
def load_and_aggregate(p_dataset):
"""Map func for loading and aggregating dataset states to all_state_dict"""
def new_key_init(all_state_dict, key, state_dict):
all_state_dict[key] = {}
for second_key in state_dict[key]:
if second_key == "used":
all_state_dict[key]["used"] = BitSet(1024)
all_state_dict[key]["used"].update(state_dict[key]["used"])
else:
all_state_dict[key][second_key] = state_dict[key][second_key]
def load_state_dict_with_npy(p_dataset):
state_dict = torch.load(p_dataset)
for key in state_dict:
if isinstance(state_dict[key]["used"], str):
                    if os.path.basename(state_dict[key]["used"]) == state_dict[key]["used"]:  # only a bare file name was stored
state_dict[key]["used"] = BitSet.load(os.path.join(load_dir, "dataset_ckpt", os.path.basename(state_dict[key]["used"])))
                    else:  # a full path was stored (backward compatibility)
state_dict[key]["used"] = BitSet.load(state_dict[key]["used"])
return state_dict
print(f"Loading {p_dataset}...")
state_dict = load_state_dict_with_npy(p_dataset)
dataset_locks = {}
for key in state_dict.keys():
if key not in dataset_locks:
with lock:
if key not in dataset_locks:
dataset_locks[key] = threading.Lock()
with dataset_locks[key]:
if key in all_state_dict:
if all_state_dict[key]["exhausted"]:
continue
elif state_dict[key]["exhausted"]:
all_state_dict[key]["exhausted"] = True
all_state_dict[key]["used"] = BitSet(1024)
else:
all_state_dict[key]["used"].update(state_dict[key]["used"])
else:
new_key_init(all_state_dict, key, state_dict)
bmt.print_rank(f"[done]loaded dataset states: {p_dataset}")
if p_datasets:
rank = bmt.rank()
lst_time = time.time()
if rank == 0:
with ThreadPoolExecutor(max_workers=args.parallel_load_datastate) as executor:
executor.map(load_and_aggregate, p_datasets)
logger.info(
f"rank-{rank} -> load dataset from {len(p_datasets)} .data files. Time: {time.time() - lst_time:.2f}s"
)
# Broadcast the tensor to other process
lst_time = time.time()
all_state_dict = bmt.store.broadcast_object(all_state_dict, comm=bmt.config["comm"], src=0)
logger.info(f"rank-{rank} -> broadcast dataset from rank-0 to other ranks. Time: {time.time() - lst_time:.2f}s")
lst_time = time.time()
missing_keys = mixed_dataset.load_state_dict(all_state_dict)
logger.info(f"rank-{rank} -> load mixed dataset state dict. Time: {time.time() - lst_time:.2f}s")
if missing_keys:
logger.info(
f"rank-{rank} -> load dataset from {len(p_datasets)} .data files with {os.path.join(load_dir, 'dataset_ckpt', 'dataset*.pt')} : {p_datasets}, missing tasks: {missing_keys}"
)
else:
state_info = {
k: {
"ave_tokens": s.get("ave_tokens", -1),
"set_info": "mem:{}|density:{:.4f}|len:{}".format(
s["used"].memory_usage(), s["used"].density(), len(s["used"])
),
}
for k, s in all_state_dict.items()
}
logger.info(
f"rank-{rank} -> load dataset from {len(p_datasets)} files with {os.path.join(load_dir, 'dataset*.pt')}. Info: {state_info}"
)
return True
else:
logger.info(f"No dataset*.data found. p_datasets: {p_datasets}")
return False
import json
import os
from collections import OrderedDict
def flatten_stats(stats, parent_key="", separator="/"):
items = []
for key, value in stats.items():
new_key = f"{parent_key}{separator}{key}" if parent_key else key
if isinstance(value, dict):
items.extend(flatten_stats(value, new_key, separator).items())
        elif isinstance(value, list):
items.append((new_key, json.dumps(value)))
else:
items.append((new_key, value))
return OrderedDict(items)
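# A minimal sketch of the flattening behaviour (illustrative only):
#   flatten_stats({"loss": {"train": 2.1, "task": [1.9, 2.3]}, "lr": 1e-4})
#   -> OrderedDict([('loss/train', 2.1), ('loss/task', '[1.9, 2.3]'), ('lr', 0.0001)])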
def save_every_step_stats(stats, path):
flattened_stats = flatten_stats(stats)
os.makedirs(os.path.join(path, "train_stats/"), exist_ok=True)
# Function to get the current file ID and size
def get_current_file_id_and_size(path):
id = 0
while True:
file_path = os.path.join(path, f"train_stats/{id}.jsonl")
if not os.path.exists(file_path):
return id, 0
else:
size = os.path.getsize(file_path)
                if size > 19 * 1024 * 1024:  # roll over to a new file once it exceeds ~19 MB
id += 1
else:
return id, size
# Get the current file id and its size
current_id, file_size = get_current_file_id_and_size(path)
# Generate the file path
file_path = os.path.join(path, f"train_stats/{current_id}.jsonl")
# Write the flattened stats to the file
with open(file_path, "a") as json_file:
json_file.write(json.dumps(flattened_stats) + "\n")
def merge_dataset_ckpts(load_dir, parallel_load_datastate):
p_datasets = sorted(glob.glob(os.path.join(load_dir, "dataset_ckpt/dataset_*.data")))
p_datasets= [x for x in p_datasets if "dataset_ckpt/dataset_.data" not in x]
bmt.print_rank(f"Files before merge (total num {len(p_datasets)}): {p_datasets}")
all_state_dict = dict()
def load_and_aggregate(p_dataset):
"""Map func for loading and aggregating dataset states to all_state_dict"""
def new_key_init(all_state_dict, key, state_dict):
all_state_dict[key] = {}
for second_key in state_dict[key]:
if second_key == "used":
all_state_dict[key]["used"] = BitSet(1024)
all_state_dict[key]["used"].update(state_dict[key]["used"])
else:
all_state_dict[key][second_key] = state_dict[key][second_key]
def load_state_dict_with_npy(p_dataset):
state_dict = torch.load(p_dataset)
for key in state_dict:
if isinstance(state_dict[key]["used"], str):
state_dict[key]["used"] = BitSet.load(os.path.join(load_dir, "dataset_ckpt", os.path.basename(state_dict[key]["used"])))
return state_dict
print(f"Loading {p_dataset}...")
state_dict = load_state_dict_with_npy(p_dataset)
dataset_locks = {}
for key in state_dict.keys():
if key not in dataset_locks:
with lock:
if key not in dataset_locks:
dataset_locks[key] = threading.Lock()
with dataset_locks[key]:
if key in all_state_dict:
if all_state_dict[key]["exhausted"]:
continue
elif state_dict[key]["exhausted"]:
all_state_dict[key]["exhausted"] = True
all_state_dict[key]["used"] = BitSet(1024)
else:
all_state_dict[key]["used"].update(state_dict[key]["used"])
else:
new_key_init(all_state_dict, key, state_dict)
del state_dict
print(f"Merged {p_dataset}...")
if p_datasets:
# with ThreadPoolExecutor(max_workers=args.parallel_load_datastate) as executor:
with ThreadPoolExecutor(max_workers=parallel_load_datastate) as executor: # smaller than normal load to avoid OOM
executor.map(load_and_aggregate, p_datasets)
# load_and_aggregate(p_datasets[0])
# Broadcast the tensor to other process
save_path = os.path.join(load_dir, "dataset_ckpt", "dataset_.data")
# save_path = os.path.join("dataset_.data")
for key in all_state_dict:
npy_path = all_state_dict[key]["used"].save(save_path)
all_state_dict[key]["used"] = npy_path
bmt.print_rank(f"All state_dict after merge {all_state_dict}")
torch.save(all_state_dict, save_path)
# Find all files that match the pattern
files_to_remove = glob.glob(os.path.join(load_dir, "dataset_ckpt", "dataset_*.data*"))
# Remove the files
for file in files_to_remove:
if "dataset_.data" not in file:
os.remove(file)
files_after_merge = os.listdir(os.path.join(load_dir, "dataset_ckpt"))
bmt.print_rank(f"Files after merge: {files_after_merge}")

View File

@ -1,16 +0,0 @@
import torch
class OpGradientShrink(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor, alpha: float):
ctx.alpha = alpha
return x
@staticmethod
def backward(ctx, grad_output):
return grad_output * ctx.alpha, None
def gradient_shrink(x: torch.Tensor, alpha: float = 0.1):
return OpGradientShrink.apply(x, alpha)
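# A minimal sketch (illustrative, not part of the original module): the forward value is
# unchanged while the gradient flowing back through the op is scaled by alpha.
if __name__ == "__main__":
    x = torch.randn(4, 8, requires_grad=True)
    gradient_shrink(x, alpha=0.1).sum().backward()
    print(x.grad)  # equals 0.1 * torch.ones_like(x)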

View File

@ -1,117 +0,0 @@
import datetime
import json
import logging
import os
import sys
import time as time_
from typing import Any
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import Union
def _get_logger():
log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
log = logging.getLogger("__name__")
log.setLevel(log_level)
log.propagate = False
node_name = os.getenv("NODE_NAME", "jeeves-hpc-gpu00")
fmt = f"[%(levelname)s][%(asctime)s][{node_name}][%(filename)s:%(lineno)d:%(process)d] - %(message)s"
formatter = logging.Formatter(fmt, datefmt="%Y-%m-%d %H:%M:%S")
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(log_level)
handler.setFormatter(formatter)
log.addHandler(handler)
return log
# module-level logger handle
logger = _get_logger()
class LogManager:
def __init__(self, path: str):
if not os.path.exists(path):
os.makedirs(path)
self.path = path
now = self.get_log_time()
latest_log: Union[Dict[str, Any], None] = None
for _ in range(15):
log_name = self.get_log_name(now)
if os.path.exists(log_name):
with open(log_name, "r") as flog:
lines = flog.readlines()
if lines:
latest_log = json.loads(lines[-1])
break
now -= datetime.timedelta(days=1)
if latest_log is None:
self.global_token_pass = 0
else:
self.global_token_pass = latest_log["token pass"]
def get_log_time(self) -> datetime.datetime:
return datetime.datetime.utcnow() + datetime.timedelta(hours=16)
def get_log_name(self, now: Optional[datetime.datetime] = None):
if now is None:
now = self.get_log_time()
return os.path.join(self.path, "log.%s.txt" % now.strftime("%Y%m%d"))
def write(
self,
time: float,
iteration: int,
loss: float,
lr: float,
lr_scale: float,
time_usage: Dict[str, float],
mem_usage: Dict[str, Tuple[float, float]],
avg_time: float,
token_max: float,
token_pass: float,
throughout: float,
grad_norm: float,
mask_max: float,
num_gpus: int,
task_loss: Dict[str, float],
model_inspect: Optional[Any] = None,
):
with open(self.get_log_name(), "a") as fp:
while True:
try:
ret = {
"time": time,
"iter": iteration,
"loss": loss,
"lr": lr,
"lr scale": int(lr_scale),
"time usage": time_usage,
"mem usage": mem_usage,
"avg time (s)": avg_time,
"token/max": token_max,
"token pass": token_pass + self.global_token_pass,
"throughout (token/s)": throughout,
"grad_norm": grad_norm,
"mask/max": mask_max,
"num_gpus": num_gpus,
"task_loss": task_loss,
}
if model_inspect is not None:
ret["model_inspect"] = model_inspect
print(ret)
fp.write(json.dumps(ret) + "\n")
break
except Exception as e:
print(e)
print("Error: writing info list!")
time_.sleep(10)

View File

@ -1,29 +0,0 @@
import pickle
import bmtrain as bmt
import torch
def allgather_objects(obj):
if bmt.world_size() == 1:
return [obj]
with torch.no_grad():
data_bytes: bytes = pickle.dumps(obj)
data_length: int = len(data_bytes)
gpu_data_length = torch.tensor([data_length], device="cuda", dtype=torch.long)
gathered_length = bmt.distributed.all_gather(gpu_data_length).view(-1).cpu()
max_data_length = gathered_length.max().item()
gpu_data_bytes = torch.zeros(max_data_length, dtype=torch.uint8, device="cuda")
byte_storage = torch.ByteStorage.from_buffer(data_bytes)
gpu_data_bytes[:data_length] = torch.ByteTensor(byte_storage)
gathered_data = bmt.distributed.all_gather(gpu_data_bytes).cpu()
ret = []
for i in range(gathered_data.size(0)):
data_bytes = gathered_data[i, : gathered_length[i].item()].numpy().tobytes()
ret.append(pickle.loads(data_bytes))
return ret
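# A rough usage sketch (illustrative; assumes bmtrain has been initialised across ranks):
# every rank passes its own picklable object and receives the per-rank list, e.g.
#   allgather_objects({"rank": bmt.rank(), "loss": 2.1})
#   -> [{"rank": 0, "loss": ...}, {"rank": 1, "loss": ...}, ...]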

View File

@ -1,72 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
#
# @author: hsd9026 <shengdinghu@gmail.com>
# @date: 2023/07/07
#
import copy
from .log import logger
def num_parameters(model):
"""Return the number of parameters of a model"""
total_params = 0
for param_name, param in model.state_dict().items():
# print(param_name, param.numel())
total_params += param.numel()
return total_params
def num_non_embedding_parameters(model):
"""Return the number of parameters of a model"""
total_params = 0
for param_name, param in model.state_dict().items():
if ("embed" in param_name) or ("lm_head" in param_name):
continue
# print(param_name, param.numel())
total_params += param.numel()
return total_params
def estimate_parameters(config):
"""Estimate the number of parameters of a model given its config, should be equal to `num_parameters(model)`"""
# embedding parameters
embedding_params = config.vocab_size * config.dim_model
self_attn_params = 4 * config.dim_model * config.dim_head * config.num_heads * config.num_layers
ff_params = 3 * config.dim_model * config.dim_ff * config.num_layers
layernorm_parameters = 2 * config.dim_model * config.num_layers
output_layernorm_parameters = config.dim_model
positional_bias = 32
total_params = (
embedding_params
+ self_attn_params
+ ff_params
+ layernorm_parameters
+ output_layernorm_parameters
+ positional_bias
)
params_without_embeddings = total_params - embedding_params
return total_params, params_without_embeddings
def get_flops_per_token(config):
"""An estimated version of pfdays per token, i.e., the 6N in equation: Computation = 6 N * D."""
_, N = estimate_parameters(config)
logger.info(">>>>>> pfdays_per_token >>>> {:,.0f}".format(N))
# evaluating a forward pass
C_forward = 6 * N # + 2 * n_layer * n_ctx * d_model
# unit = 10**15 * 3600 * 24
# C_forward = C_forward / unit
return C_forward
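# A worked sketch (illustrative; the config values below are hypothetical and only loosely
# follow a 7B-scale model, they are not taken from this repository):
if __name__ == "__main__":
    from types import SimpleNamespace
    cfg = SimpleNamespace(vocab_size=120000, dim_model=4096, dim_head=128,
                          num_heads=32, num_layers=32, dim_ff=11008)
    total, non_embedding = estimate_parameters(cfg)
    print(total, non_embedding)  # roughly 6.97e9 and 6.48e9 parameters
    # get_flops_per_token(cfg) would then report 6 * non_embedding, i.e. about 3.9e10 FLOPs per token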

View File

@ -1,34 +0,0 @@
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
def van_der_corput(n, base=2):
"""Generate the n-th value in the Van der Corput sequence."""
vdc, denom = 0, 1
while n:
denom *= base
n, remainder = divmod(n, base)
vdc += remainder / denom
return vdc
def van_der_corput_sampling_gen(vdc_values):
"""Generator function for sampling indices based on weights using the Van der Corput sequence."""
def gen(weights, vdc_value):
cdf = np.cumsum(weights)
sample = np.searchsorted(cdf, vdc_value)
return sample
sample_index = 0
# Pre-generate Van der Corput sequence
max_samples = 100000 # or any number that you find suitable
while True:
# Generate the next value in the Van der Corput sequence
vdc_value = vdc_values[sample_index % max_samples]
# Generate a sample index based on the Van der Corput value and the CDF
yield partial(gen, vdc_value=vdc_value)
sample_index += 1
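# A minimal usage sketch (illustrative, not part of the original module): draw weighted
# category indices from the low-discrepancy sequence.
if __name__ == "__main__":
    vdc_values = [van_der_corput(i) for i in range(100000)]
    sampler = van_der_corput_sampling_gen(vdc_values)
    weights = np.array([0.7, 0.2, 0.1])       # must sum to 1 to act as a valid distribution
    draws = [next(sampler)(weights) for _ in range(10)]
    print(draws)                               # indices in {0, 1, 2}, mostly 0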

3
FM_9G/readme.txt Normal file
View File

@ -0,0 +1,3 @@
cd FM9G-V
pip install -r requirements.txt
python chat_model.py

21
FM_9G/requirements.txt Normal file
View File

@ -0,0 +1,21 @@
torch==2.0.1
torchvision==0.15.2
transformers==4.31.0
tokenizers>=0.12.1,<0.14
sentencepiece==0.1.99
shortuuid
peft==0.4.0
bitsandbytes==0.41.0
pydantic<2,>=1
markdown2[all]
numpy
scikit-learn==1.2.2
gradio==3.35.2
gradio_client==0.2.9
requests
httpx==0.24.0
uvicorn
fastapi
einops==0.6.1
einops-exts==0.0.4
timm==0.9.8

BIN
FM_9G/test.jpg Normal file

Binary file not shown.


BIN
FM_9G/vis_fm9g/.DS_Store vendored Normal file

Binary file not shown.


View File

@ -0,0 +1,7 @@
[
{ "data_source_name": "laion_coco", "data_source_weight": 200 },
{ "data_source_name": "cc12m", "data_source_weight": 15 },
{ "data_source_name": "cc3m", "data_source_weight": 30 },
{ "data_source_name": "coco", "data_source_weight": 5 },
{ "data_source_name": "vg", "data_source_weight": 5 }
]

View File

@ -0,0 +1,3 @@
[
{ "data_source_name": "pretrain_eval_eval", "data_source_weight": 1 }
]

View File

@ -0,0 +1,3 @@
[
{ "data_source_name": "pretrain_eval_train", "data_source_weight": 1 }
]

View File

@ -0,0 +1,33 @@
{
"train_micro_batch_size_per_gpu": 16,
"gradient_accumulation_steps": 16,
"optimizer": {
"type": "AdamW",
"params": {
"lr": 5e-5,
"betas": [
0.9,
0.98
],
"weight_decay": 0.01
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 1e-6,
"warmup_max_lr": 1e-5,
"warmup_num_steps": 500
}
},
"fp16": {
"enabled": true,
"initial_scale_power": 10,
"auto_cast": true
},
"zero_optimization": {
"stage": 2
},
"steps_per_print": 50,
"gradient_clipping": 1.0
}

View File

@ -0,0 +1,33 @@
{
"train_micro_batch_size_per_gpu": 4,
"gradient_accumulation_steps": 16,
"optimizer": {
"type": "AdamW",
"params": {
"lr": 1e-5,
"betas": [
0.9,
0.98
],
"weight_decay": 0.01
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 1e-6,
"warmup_max_lr": 1e-5,
"warmup_num_steps": 500
}
},
"fp16": {
"enabled": true,
"initial_scale_power": 10,
"auto_cast": true
},
"zero_optimization": {
"stage": 2
},
"steps_per_print": 50,
"gradient_clipping": 1.0
}

View File

@ -3,25 +3,14 @@
"dropout_p": 0.0,
"eps": 1e-05,
"half": true,
"half_type": "bf16",
"use_flash_attn": true,
"flash_attn_mask_shape": "2d",
"dim_model": 4096,
"dim_ff": 14336,
"dim_ff": 11008,
"dim_head": 128,
"num_heads": 32,
"num_kv_heads": 32,
"num_layers": 32,
"activate_fn": "silu",
"init_std": 0.10,
"scale": false,
"scale_emb": 12,
"scale_depth": -1,
"model_type": "fm9g",
"architectures": [
"FM9GForCausalLM"
],
"qk_norm": false,
"tie_lm_head": false,
"ffn_gated": true
"scale": false
}

View File

@ -119687,10 +119687,10 @@
"𠳐"
"𥻗"
"𬉼"
"<|im_start|>"
"<|im_end|>"
"<pad_2>"
"<pad_3>"
"<pad_4>"
"<pad_5>"
"<pad_6>"
"<image>"
"</image>"
"<ref>"
"</ref>"
"<box>"
"</box>"
"<quad>"

View File

@ -0,0 +1,231 @@
import io
import json
import logging
import random
import numpy
import base64
import os.path as op
import torch.utils.data as torch_data
from PIL import Image
from typing import List, Iterator
from muffin.data.tsv_file import TSVFile
from muffin.data.data_processors import register_data_processor
from vis_fm9g.dataset.itembuilder import ItemBuilder
logger = logging.getLogger(__file__)
class MultimodalQADataset(torch_data.Dataset):
def __init__(self, qa_file, question_process):
'''
qa_file: jsonl file that each line is a dict like {
'image': b64img,
'question': question_text
}
'''
super().__init__()
self.qa_file = qa_file
self.qa_data = [json.loads(line) for line in open(self.qa_file)]
if isinstance(self.qa_data[0], list):
self.qa_data = self.qa_data[0] # unwrap one-line json question file
self.question_process = question_process
def __getitem__(self, index):
item = self.qa_data[index]
img_b64 = item['image']
image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert('RGB')
raw_question = item['question']
question_text = self.question_process(raw_question)
return {
'image': image,
'raw_question': raw_question,
'question': question_text
}
def __len__(self):
return len(self.qa_data)
class SingleDataSourceDataset(torch_data.Dataset):
def __init__(self, ds_name, item_builder: ItemBuilder, data_dir, tsv_filenames: List[str], intent='sft') -> None:
super().__init__()
self.data_dir = data_dir
self.filenames = tsv_filenames
self.ds_name = ds_name
self.sizes = []
for filename in self.filenames:
try:
size = int(filename[:-4].split('-')[-1])
except:
                raise ValueError(f'TSV Data File {filename} is not valid: the last component separated by `-` must be the number of samples in this file')
self.sizes.append(size)
self.file_border_index = []
self.prepare_border_index()
self.item_builder = item_builder
self.files = self.filenames[:]
self.intent = intent
def prepare_border_index(self):
self.file_border_index = [0]
temp_sum = 0
for size in self.sizes:
temp_sum += size
self.file_border_index.append(temp_sum)
def get_file_idx_and_row_idx(self, index):
found = False
file_idx = -1
for border_idx, border in enumerate(self.file_border_index):
if index < border:
file_idx = border_idx - 1
found = True
break
if not found:
            raise ValueError(f'Index {index} out of range for dataset of size {self.file_border_index[-1]}')
offset = self.file_border_index[file_idx]
row_idx = index - offset
return file_idx, row_idx
def __len__(self):
return self.file_border_index[-1]
def __getitem__(self, index):
file_idx, row_idx = self.get_file_idx_and_row_idx(index)
try:
sample = self.fetch_sample(file_idx, row_idx)
item = self.item_builder.build_item(sample)
except:
logger.warning(f"data fetch error")
            return self.__getitem__(random.randint(0, len(self) - 1))
return item
def fetch_sample(self, file_idx, row_idx):
file = self.files[file_idx]
if isinstance(file, str):
self.prepare_file(file_idx)
file = self.files[file_idx]
assert isinstance(file, TSVFile), f'Expecting TSVFile but get {file} as {type(file)}'
# tsv line as tuple
sample = file[row_idx]
ds_name, *values = sample
# data dict
sample = register_data_processor[self.ds_name](*values, intent=self.intent)
if row_idx + 1 == len(file):
del file
self.files[file_idx] = self.filenames[file_idx]
return sample
def prepare_file(self, idx):
filename = self.filenames[idx]
file = TSVFile(op.join(self.data_dir, filename))
self.files[idx] = file
class IterableSingleDataSourceDataset(torch_data.IterableDataset):
def __init__(self) -> None:
super().__init__()
        raise NotImplementedError
class MultiDataSourceDataset(torch_data.Dataset):
def __init__(self, data_sources: List[SingleDataSourceDataset], data_source_weights: List[int]):
super().__init__()
self.ds_list = data_sources
self.sum_weight = sum(data_source_weights)
self.ds_weights = data_source_weights
for weight in self.ds_weights:
assert isinstance(weight, int), 'weight must be integer'
self.offset2ds = {}
self.offset2wt = {}
self.offset2pd = {}
self.prepare_offset2ds()
ds_loops = []
for ds, wt in zip(self.ds_list, self.ds_weights):
ds_loop = len(ds) // wt
ds_loops.append(ds_loop)
max_loop = max(ds_loops)
self.size = max_loop * self.sum_weight
def prepare_offset2ds(self):
offset = 0
for ds, weight in zip(self.ds_list, self.ds_weights):
pd = offset
for _ in range(weight):
self.offset2ds[offset] = ds
self.offset2wt[offset] = weight
self.offset2pd[offset] = pd
offset += 1
def __getitem__(self, index):
n_loop = index // self.sum_weight
offset = index % self.sum_weight
ds = self.offset2ds[offset]
ds_inner_idx = n_loop * self.offset2wt[offset] + offset - self.offset2pd[offset]
ds_inner_idx = ds_inner_idx % len(ds)
return ds[ds_inner_idx]
def __len__(self):
return self.size
class IterableMultiDataSourceDataset(torch_data.IterableDataset):
def __init__(self, data_sources, data_source_weights):
super().__init__()
self.ds_list = data_sources
sum_weight = sum(data_source_weights)
self.ds_weights = [x / sum_weight for x in data_source_weights]
        self.ds_consumption = [0 for _ in self.ds_list]
self.ds_sizes = [len(ds) for ds in self.ds_list]
def __next__(self):
ds_idx = numpy.random.choice(range(len(self.ds_list)), 1, p=self.ds_weights)[0]
data_source = self.ds_list[ds_idx]
self.ds_consumption[ds_idx] += 1
if self.ds_consumption[ds_idx] % self.ds_sizes[ds_idx] == 0:
self.report_consumption()
sample = next(data_source)
return sample
def __iter__(self) -> Iterator:
return self
def __len__(self):
return sum(self.ds_sizes)
def report_consumption(self):
for ds, consumption, size in zip(self.ds_list, self.ds_consumption, self.ds_sizes):
print(f'Data {ds} consumption: {consumption / size:.2f} epoch', flush=True)
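# A brief illustrative note (not part of the original module) on MultiDataSourceDataset
# indexing: with data_source_weights = [2, 1], global indices 0 and 1 fall into the first
# source, index 2 into the second, and the pattern then repeats, so the first source is
# drawn twice as often; each source wraps around modulo its own length, and the overall
# length is chosen from the source that needs the most passes for its weight.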

View File

@ -0,0 +1,309 @@
import io
import json
from typing import Dict, Tuple, List, Any
import torch
import pandas as pd
import numpy as np
from PIL import Image, PngImagePlugin
from torch.utils.data import default_collate
from utils.logger import init_logger
from vis_fm9g.dataset.utils import convert_data_to_id
from vis_fm9g.dataset.utils import convert_conversation_data_to_id
from vis_fm9g.dataset.utils import pad
import random
from vis_fm9g.tokenizer.fm9g_tokenizer import FM9GTokenizer
from vis_fm9g.utils.constants import usr_indicator, bot_indicator
from vis_fm9g.dataset.prompts import caption_zh, caption_en
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
logger = init_logger()
def is_contain_chinese(check_str):
"""
    Check whether a string contains Chinese characters.
    :param check_str: {str} the string to check
    :return: {bool} True if it contains Chinese characters, otherwise False
"""
for ch in check_str:
        if u'\u4e00' <= ch <= u'\u9fff':
            return True
    return False
def maybe_select_text(raw_text):
candidates = raw_text.split('<cap_sep>')
return random.choice(candidates)
def maybe_parse_json(raw_text: str):
# VG raw
if raw_text.startswith('[{') and raw_text.endswith('}]'):
try:
data = json.loads(raw_text)
text_list = [x['phrase'] for x in data if x['height'] > 160 and x['width'] > 160]
if len(text_list) == 0:
return max(data, key=lambda x: len(x['phrase'].split()))['phrase']
else:
return random.choice(text_list)
except:
return raw_text
else:
return raw_text
def clean_text(raw_text):
text = raw_text.replace('<PERSON>', '')
text = maybe_parse_json(maybe_select_text(text))
return text
def check_text_valid(raw_text):
if pd.isna(raw_text):
return False
if not is_contain_chinese(raw_text) and len(raw_text.split()) <= 3:
return False
if '<img' in raw_text or '<a href' in raw_text:
return False
return True
def get_image_placeholder(tokenizer, query_len, use_im_start_end=False):
if use_im_start_end:
return tokenizer.im_start + tokenizer.unk_token * query_len + tokenizer.im_end
else:
return tokenizer.unk_token * query_len
class ItemBuilder():
def __init__(self, transform=None):
self.transform = transform
def build_item(self, data):
if self.transform is not None:
return self.transform(data)
return data
# --------------------- FM9G ---------------------
class FM9GBuilder(ItemBuilder):
def __init__(self, tokenizer: FM9GTokenizer, max_len, transform=None, skip_overlength=False):
super().__init__(transform)
self.tokenizer = tokenizer
self.max_len = max_len
self.skip_overlength = skip_overlength
def convert_data(self, inp_dicts: List[Dict], raw_data):
res = []
for inp_dict in inp_dicts:
input_ids, context = convert_data_to_id(self.tokenizer, data=inp_dict)
if len(input_ids) > self.max_len:
if self.skip_overlength:
if random.random() > 0.95:
logger.warn(f"overlength={len(input_ids)}, raw_inp={inp_dict}, skip data")
else:
logger.warn(f"overlength={len(input_ids)}, skip data")
continue
input_ids = input_ids[: self.max_len]
context = context[: self.max_len]
res.append({
'input_ids': torch.from_numpy(input_ids).unsqueeze(0),
'context': torch.from_numpy(context).unsqueeze(0),
'raw_data': raw_data,
})
return res
def convert_conversation_data(self, conversation_list: List[List]):
res = []
for conversation in conversation_list:
input_ids, context, raw = convert_conversation_data_to_id(self.tokenizer, data=conversation, predict_roles={bot_indicator})
if len(input_ids) > self.max_len:
if self.skip_overlength:
if random.random() > 0.95:
logger.warn(f"overlength={len(input_ids)}, raw_inp={conversation}, skip data")
else:
logger.warn(f"overlength={len(input_ids)}, skip data")
continue
input_ids = input_ids[: self.max_len]
context = context[: self.max_len]
res.append({
'input_ids': torch.from_numpy(input_ids).unsqueeze(0),
'context': torch.from_numpy(context).unsqueeze(0),
'raw_data': raw,
})
return res
def build_image_bound(self, res, images):
return_res = []
if isinstance(images, List) and len(images) > 0:
images = torch.stack(images)
for r in res:
# r['input_ids'] (1, len)
image_start_tokens = torch.where(r['input_ids'][0] == self.tokenizer.encoder[self.tokenizer.im_start])[0]
            # skip past the im_start token itself
image_start_tokens += 1
image_end_tokens = torch.where(r['input_ids'][0] == self.tokenizer.encoder[self.tokenizer.im_end])[0]
if len(image_start_tokens) != len(image_end_tokens) or len(image_start_tokens) > len(images):
continue
image_bound = torch.hstack([image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)])
r['pixel_values'] = images[:len(image_start_tokens)]
r['image_bound'] = image_bound
return_res.append(r)
return return_res
def build_item(self, data):
NotImplementedError("build_item is not implemented.")
class FM9GImageTextBuilder(FM9GBuilder):
def __init__(self, tokenizer: FM9GTokenizer, max_len, transform=None, query_len=64, min_resolution=0, skip_overlength=False):
super().__init__(tokenizer, max_len, transform, skip_overlength)
self.query_len = query_len
self.min_resolution = min_resolution
def build_item(self, data):
text = data['conversations']
image = data['image']
source = data.get('metainfo', {}).get('origin_dataset', 'unk')
image = self.transform(image)
raw_data = {'text': text}
image_placeholder = get_image_placeholder(self.tokenizer, self.query_len, use_im_start_end=True)
messages = []
for i in range(len(text)):
role = text[i]['from']
role = usr_indicator if role == 'human' else bot_indicator
value = self.tokenizer.escape(text[i]['value'])
if '<image>' in value:
value = value.replace('<image>', image_placeholder)
messages.append((role, value))
res = self.convert_conversation_data([messages])
self.build_image_bound(res, images=[image])
for r in res:
r['source'] = source
return res[0]
class FM9GCollater:
def __init__(self, tokenizer: FM9GTokenizer, max_len: int, unpad: bool = False, unilm: bool = False):
self.tokenizer = tokenizer
self._max_length = max_len
self._unpad = unpad
self._unilm = unilm
self.pad_keys = ['input_ids', 'context']
def __call__(self, batch):
batch_cnt = len(batch)
if self._unpad: # for flash_attention cuda
max_length = self._max_length * batch_cnt
batch_size = 1
else:
max_length = self._max_length
batch_size = batch_cnt
inputs = np.zeros((batch_size, max_length), dtype=np.int32)
context_origin = np.zeros((batch_size, max_length), dtype=np.int8)
context = np.zeros((batch_size, max_length), dtype=np.int8)
tgt = np.full((batch_size, max_length), -100, dtype=np.int32)
spans = np.zeros((batch_size, max_length), dtype=np.int32)
length = np.zeros((batch_size,), dtype=np.int32)
position_ids = np.zeros((batch_size, max_length), dtype=np.int32)
if self._unpad: # for flash_attention cuda force batch_size=1
flatten_input_ids = np.concatenate([batch[i]['input_ids'][0] for i in range(batch_cnt)], axis=0)
flatten_context = np.concatenate([batch[i]['context'][0] for i in range(batch_cnt)], axis=0)
instance_length = flatten_input_ids.shape[0]
inputs[0, : instance_length] = flatten_input_ids
context_origin[0, : instance_length] = flatten_context
length[0] = instance_length
if self._unilm:
context[0, : instance_length] = flatten_context
# flatten batch
_spans = [list(np.cumsum([batch[i]['input_ids'][0].shape[0] for i in range(batch_cnt)]))]
else:
for i in range(batch_cnt):
instance_length = batch[i]['input_ids'][0].shape[0]
inputs[i, :instance_length] = batch[i]['input_ids'][0]
context_origin[i, : instance_length] = batch[i]['context'][0]
length[i] = instance_length
if self._unilm:
context[i, :instance_length] = batch[i]['context'][0]
_spans = [[batch[i]['input_ids'][0].shape[0]] for i in range(batch_cnt)]
        # cu_seqlens and max_seqlen are required by the flash_attention cuda (varlen) kernels
if _spans[0][-1] != max_length:
cu_seqlens = np.array([0] + _spans[0] + [max_length], dtype=np.int32)
else:
cu_seqlens = np.array([0] + _spans[0], dtype=np.int32)
max_seqlen = int(np.max(cu_seqlens[1:] - cu_seqlens[:-1]))
raw_data_list: List[Any] = [batch[i]['raw_data'] for i in range(batch_cnt)]
source_list: List[Any] = [batch[i].get('source', 'unk') for i in range(batch_cnt)]
for i in range(batch_size):
instance_length = length[i]
span_begin = 0
for span_id, span_end in enumerate(_spans[i]):
spans[i, span_begin: span_end] = span_id
position_ids[i, span_begin:span_end] = np.arange(span_end - span_begin)
span_begin = span_end
for j in range(instance_length):
idx = inputs[i][j]
if j > 1:
if context_origin[i][j] == 0:
if idx != self.tokenizer.bos_id and inputs[i][j - 1] != self.tokenizer.eos_id:
tgt[i, j - 1] = idx
if context_origin[i][j] == 1 and context_origin[i][j-1] == 0:
if idx != self.tokenizer.bos_id and inputs[i][j - 1] != self.tokenizer.eos_id:
tgt[i, j - 1] = self.tokenizer.eos_id
data = {}
# image
if 'pixel_values' in batch[0]:
if self._unpad:
data['pixel_values'] = [torch.vstack([i['pixel_values'] for i in batch])]
else:
data['pixel_values'] = [i['pixel_values'] for i in batch]
# image_bound
if 'image_bound' in batch[0]:
if self._unpad:
image_bounds = []
for i in range(batch_cnt):
offset = _spans[0][i-1] if i > 0 else 0
image_bounds.append(batch[i]['image_bound'] + offset)
data['image_bound'] = [torch.vstack(image_bounds)]
else:
data['image_bound'] = [i['image_bound'] for i in batch]
data['input_ids'] = torch.from_numpy(inputs)
data['context'] = torch.from_numpy(context) > 0
data['length'] = torch.from_numpy(length)
data['spans'] = torch.from_numpy(spans)
data['cu_seqlens'] = torch.from_numpy(cu_seqlens)
data['max_seqlen'] = max_seqlen
data['position_ids'] = torch.from_numpy(position_ids)
data['target'] = torch.from_numpy(tgt)
data['raw_data'] = raw_data_list
data['source'] = source_list
return data

View File

@ -0,0 +1,24 @@
caption_en = [
'Describe the image concisely',
'Provide a brief description of the given image',
'Offer a succinct explanation of the picture presented',
'Summarize the visual content of the image',
    'Share a concise interpretation of the image provided',
    "Present a compact description of the photo's key features",
'Relay a brief and clear account of the picture shown',
'Render a clear and concise summary of the photo',
'Write a terse but informative summary of the picture',
'Create a compact narrative representing the image presented',
]
caption_zh = [
'简明扼要地描述图像',
'提供给定图像的简短描述',
'对所示的图片进行简要的解释',
'总结图像的视觉内容',
'对所提供的图像进行简要的解释',
'简明扼要并清楚地说明所示图片',
'对这张照片作一个简明扼要的总结',
'写一篇简洁但内容丰富的图片摘要',
'创造一个紧凑的叙事来代表所呈现的图像',
]

View File

@ -0,0 +1,170 @@
import importlib.machinery
import importlib.util
import types
from typing import Any, Set
from typing import Dict
from typing import List
from typing import Union
import torch
import numpy as np
from numpy.typing import NDArray
from torch.utils.data import BatchSampler
from typing_extensions import TypedDict
from vis_fm9g.tokenizer.fm9g_tokenizer import FM9GTokenizer
from vis_fm9g.utils.constants import SYSTEM
FM9GInputType = Union[str, Dict[str, "FM9GInputType"]]
class _TransformFuncDict(TypedDict):
loader: importlib.machinery.SourceFileLoader
module: types.ModuleType
last_m: float
class FM9GBatch(TypedDict):
inputs: NDArray[np.int32]
length: NDArray[np.int32]
context: NDArray[np.bool_]
sample_ids: NDArray[np.int32]
spans: NDArray[np.int32]
target: NDArray[np.int32]
task_ids: NDArray[np.int32]
task_names: List[str]
raw_data: List[Any]
def convert_data_to_id(tokenizer: FM9GTokenizer, data: Any):
"""
data: {
'input': xxx,
'output': xxx
}
"""
input_ids = tokenizer.encode(data["input"])
output_ids = tokenizer.encode(data["output"])
ids = [tokenizer.bos_id] + input_ids + output_ids + [tokenizer.eos_id]
ids = np.array(ids, dtype=np.int32)
context = np.zeros((ids.shape[0],), dtype=np.int8)
    # positions where context == 0 are the ones that need a target (loss is computed there)
context[: len(input_ids) + 1] = 1
return ids, context
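# Illustrative sketch (token ids are hypothetical): for data = {"input": "<用户>hi<AI>", "output": "hello"},
#   ids     = [bos] + encode(input) + encode(output) + [eos]
#   context = 1 for bos and every input token (no loss there), 0 for the output tokens and eos,
#             which are the positions the model is trained to predict.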
def convert_conversation_data_to_id(tokenizer: FM9GTokenizer, data: Any, predict_roles: Set):
"""
predict_roles: {'<AI>'}
data: [
('<用户>', xxxx),
('<AI>', xxxx)
]
"""
assert (set([i[0] for i in data]) & predict_roles)
if SYSTEM:
system = tokenizer.bos_token + SYSTEM + '\n'
else:
system = tokenizer.bos_token
sys_idx = tokenizer.encode(system)
ret = system
input_ids = [sys_idx] if sys_idx else []
context = [np.ones((len(sys_idx),), dtype=np.int8)]
for idx, (role, message) in enumerate(data):
prefix = role
# append eos to the last message
if idx == len(data)-1:
message = message + tokenizer.eos_token
prefix_ids = tokenizer.encode(prefix)
message_ids = tokenizer.encode(message)
input_ids.append(prefix_ids)
input_ids.append(message_ids)
context.append(np.ones((len(prefix_ids),), dtype=np.int8))
if role in predict_roles:
context.append(np.zeros((len(message_ids),), dtype=np.int8))
else:
context.append(np.ones((len(message_ids),), dtype=np.int8))
ret += (prefix + message)
ids = np.hstack(input_ids)
context = np.hstack(context)
return ids, context, ret
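# Corresponding sketch for the conversation variant; the role markers follow
# the docstring, while the message contents are invented.
def demo_convert_conversation(tokenizer: FM9GTokenizer):
    data = [("<用户>", "Introduce Beijing."), ("<AI>", "Beijing is the capital of China.")]
    ids, context, text = convert_conversation_data_to_id(tokenizer, data, predict_roles={"<AI>"})
    assert ids.shape == context.shape
    assert (context == 0).any()  # only the <AI> reply contributes to the loss
    return ids, context, text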
def pad(orig_items, key, max_length=None, padding_value=0, padding_side="left"):
items = []
if isinstance(orig_items[0][key], list):
assert isinstance(orig_items[0][key][0], torch.Tensor)
for it in orig_items:
for tr in it[key]:
items.append({key: tr})
else:
assert isinstance(orig_items[0][key], torch.Tensor)
items = orig_items
batch_size = len(items)
shape = items[0][key].shape
dim = len(shape)
assert dim <= 3
if max_length is None:
max_length = 0
max_length = max(max_length, max(item[key].shape[-1] for item in items))
min_length = min(item[key].shape[-1] for item in items)
dtype = items[0][key].dtype
if dim == 1:
return torch.cat([item[key] for item in items], dim=0)
elif dim == 2:
if max_length == min_length:
return torch.cat([item[key] for item in items], dim=0)
tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
else:
tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
for i, item in enumerate(items):
if dim == 2:
if padding_side == "left":
tensor[i, -len(item[key][0]):] = item[key][0].clone()
else:
tensor[i, : len(item[key][0])] = item[key][0].clone()
elif dim == 3:
if padding_side == "left":
tensor[i, -len(item[key][0]):, :] = item[key][0].clone()
else:
tensor[i, : len(item[key][0]), :] = item[key][0].clone()
return tensor
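# Small sketch of pad() with 2-D inputs of unequal length; values are illustrative.
def demo_pad():
    items = [
        {"input_ids": torch.tensor([[1, 2, 3]])},
        {"input_ids": torch.tensor([[4, 5]])},
    ]
    out = pad(items, "input_ids", padding_value=0, padding_side="left")
    # out has shape (2, 3); the shorter row is left-padded to [0, 4, 5]
    return out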
class SkipBatchSampler(BatchSampler):
"""
A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
"""
def __init__(self, batch_sampler, skip_batches=0):
self.batch_sampler = batch_sampler
self.skip_batches = skip_batches
self.first_epoch = True
def __iter__(self):
for index, samples in enumerate(self.batch_sampler):
# skip only during the first pass; later passes should yield every batch
if index >= self.skip_batches or not self.first_epoch:
yield samples
self.first_epoch = False
@property
def total_length(self):
return len(self.batch_sampler)
def __len__(self):
return len(self.batch_sampler) - self.skip_batches
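# Illustrative use of SkipBatchSampler for resuming mid-epoch; the dataset and
# batch sizes are made up, and it assumes skipping applies only to the first pass.
def demo_skip_batch_sampler():
    from torch.utils.data import SequentialSampler
    base = BatchSampler(SequentialSampler(range(10)), batch_size=2, drop_last=False)
    resumed = SkipBatchSampler(base, skip_batches=3)
    first_pass = list(resumed)   # 2 batches: the first 3 of 5 are skipped
    second_pass = list(resumed)  # all 5 batches on later passes
    return first_pass, second_pass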

View File

View File

@ -0,0 +1,112 @@
import torch
import torch.nn.functional as F
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
# This function has been mostly taken from huggingface conversational ai code at
# https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
if top_k > 0:
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
batch_size = logits.size()[0]
if top_p > 0.0:
logits = logits.view(batch_size, -1).contiguous()
for index in range(len(logits)):
sorted_logits, sorted_indices = torch.sort(logits[index].view(-1), descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[index][indices_to_remove] = filter_value
logits = logits.view(batch_size, -1).contiguous()
return logits
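# Minimal sampling sketch using the filter above; vocabulary size and seed are arbitrary.
def demo_top_k_top_p_sampling():
    torch.manual_seed(0)
    logits = torch.randn(2, 32)                     # (batch, vocab) dummy scores
    filtered = top_k_top_p_filtering(logits.clone(), top_k=5, top_p=0.9)
    probs = F.softmax(filtered, dim=-1)             # filtered entries get probability 0
    return torch.multinomial(probs, num_samples=1)  # one sampled token id per row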
def apply_repetition_penalty(
logits,
batch_size,
num_beams,
prev_output_tokens,
repetition_penalty,
start_idx=None,
end_idx=None,
window_size=None,
):
# only conduct repetition penalty for the output
assert repetition_penalty >= 1, "repetition penalty coefficient should >= 1"
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
for i in range(batch_size * num_beams):
if start_idx is None or end_idx is None:
output_tokens = prev_output_tokens[i].tolist()
else:
if end_idx >= start_idx:
if window_size:
output_tokens = prev_output_tokens[i][
max(start_idx, end_idx + 1 - window_size): end_idx + 1
].tolist()
else:
output_tokens = prev_output_tokens[i][start_idx: end_idx + 1].tolist()
else:
output_tokens = []
for previous_token in set(output_tokens):
# if score < 0 then the repetition penalty has to be
# multiplied to reduce the previous token's probability
if logits[i, previous_token] < 0:
logits[i, previous_token] *= repetition_penalty
else:
logits[i, previous_token] /= repetition_penalty
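# Sketch of applying the penalty in-place before sampling; shapes are arbitrary.
def demo_repetition_penalty():
    logits = torch.randn(2, 32)                     # (batch_size * num_beams, vocab)
    prev_output_tokens = torch.tensor([[3, 7, 7], [1, 2, 3]])
    apply_repetition_penalty(
        logits, batch_size=2, num_beams=1,
        prev_output_tokens=prev_output_tokens,
        repetition_penalty=1.2,
    )
    return logits  # scores of already generated tokens are pushed down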
class BeamHypotheses:
def __init__(self, n_hyp, max_len, length_penalty, early_stopping):
"""
Initialize n-best list of hypotheses.
"""
self.max_len = max_len
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.n_hyp = n_hyp
self.hyp = []
self.worst_score = 1e9
def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.hyp)
def add(self, hyp, sum_logprobs):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / len(hyp) ** self.length_penalty
if len(self) < self.n_hyp or score > self.worst_score:
self.hyp.append((score, hyp))
if len(self) > self.n_hyp:
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
del self.hyp[sorted_scores[0][1]]
self.worst_score = sorted_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs, cur_len):
"""
If there are enough hypotheses and none of the hypotheses being generated
can become better than the worst one in the heap, we are done with this sentence.
"""
if len(self) < self.n_hyp:
return False
elif self.early_stopping:
return True
else:
return self.worst_score >= best_sum_logprobs / cur_len**self.length_penalty
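# Short behavioural sketch of BeamHypotheses; token lists and log-probs are invented.
def demo_beam_hypotheses():
    beams = BeamHypotheses(n_hyp=2, max_len=16, length_penalty=1.0, early_stopping=False)
    beams.add([5, 6, 7], sum_logprobs=-2.0)
    beams.add([5, 6, 8, 9], sum_logprobs=-2.5)
    beams.add([5, 6], sum_logprobs=-0.5)   # higher score evicts the current worst hypothesis
    done = beams.is_done(best_sum_logprobs=-3.0, cur_len=4)
    return len(beams), done                # -> (2, True)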

Some files were not shown because too many files have changed in this diff.