Compare commits
7 Commits
Author | SHA1 | Date |
---|---|---|
carboncoo | c89395164e | |
carboncoo | 8e693d5876 | |
carboncoo | 415c624322 | |
carboncoo | 4139ba5dfe | |
carboncoo | a8d431c14f | |
carboncoo | 1857f60d1e | |
carboncoo | a041469104 |
|
@ -1,4 +0,0 @@
|
|||
# !/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2024, QiYuan Inc
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -1,20 +0,0 @@
|
|||
|
||||
import random
|
||||
|
||||
|
||||
def rand(n: int, r: random.Random):
|
||||
return int(r.random() * n)
|
||||
|
||||
def transform(data, num_sample: int, r: random.Random):
|
||||
if 'input' in data:
|
||||
_input = "<用户>"+data['input']+"<AI>"
|
||||
else:
|
||||
_input = ""
|
||||
|
||||
if 'output' in data:
|
||||
_output = data['output']
|
||||
else:
|
||||
_output = ""
|
||||
return {"input": _input,
|
||||
"output": _output,
|
||||
}
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -1,20 +0,0 @@
|
|||
|
||||
import random
|
||||
|
||||
|
||||
def rand(n: int, r: random.Random):
|
||||
return int(r.random() * n)
|
||||
|
||||
def transform(data, num_sample: int, r: random.Random):
|
||||
if 'input' in data:
|
||||
_input = data['input']
|
||||
else:
|
||||
_input = ""
|
||||
|
||||
if 'output' in data:
|
||||
_output = data['output']
|
||||
else:
|
||||
_output = ""
|
||||
return {"input": _input,
|
||||
"output": _output,
|
||||
}
|
|
@ -1,134 +0,0 @@
|
|||
[
|
||||
{
|
||||
"dataset_name": "humanevallike_clean_dedup",
|
||||
"task_name": "humanevallike_clean_dedup",
|
||||
"abs_weight": 0.2,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/humanevallike_clean_dedup",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 995339,
|
||||
"ave_tokens_per_line": 100,
|
||||
"total_tokens": 0.1
|
||||
},
|
||||
{
|
||||
"dataset_name": "leetcode_pass_code_0125",
|
||||
"task_name": "leetcode_pass_code_0125",
|
||||
"abs_weight": 0.006,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/leetcode_pass_code_0125",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 10724,
|
||||
"ave_tokens_per_line": 200,
|
||||
"total_tokens": 0.002
|
||||
},
|
||||
{
|
||||
"dataset_name": "logiv2Annotate",
|
||||
"task_name": "logiv2Annotate",
|
||||
"abs_weight": 0.004,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logiv2Annotate",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 12566,
|
||||
"ave_tokens_per_line": 512,
|
||||
"total_tokens": 0.006
|
||||
},
|
||||
{
|
||||
"dataset_name": "mmlu_enhance",
|
||||
"task_name": "mmlu_enhance",
|
||||
"abs_weight": 0.1,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mmlu_enhance",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 169771,
|
||||
"ave_tokens_per_line": 300,
|
||||
"total_tokens": 0.05
|
||||
},
|
||||
{
|
||||
"dataset_name": "mtbench_like",
|
||||
"task_name": "mtbench_like",
|
||||
"abs_weight": 0.2,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mtbench_like",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 319080,
|
||||
"ave_tokens_per_line": 500,
|
||||
"total_tokens": 0.15
|
||||
},
|
||||
{
|
||||
"dataset_name": "ultra_dataset_new",
|
||||
"task_name": "ultra_dataset_new",
|
||||
"abs_weight": 2.0,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/ultra_dataset_new",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 385045,
|
||||
"ave_tokens_per_line": 200.296266559615,
|
||||
"total_tokens": 2.0
|
||||
},
|
||||
{
|
||||
"dataset_name": "sft_data_zh_wowru",
|
||||
"task_name": "sft_data_zh_wowru",
|
||||
"abs_weight": 1.0,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/sft_data_zh_wowru",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 2963260,
|
||||
"ave_tokens_per_line": 200.296266559615,
|
||||
"total_tokens": 1
|
||||
},
|
||||
{
|
||||
"dataset_name": "math_data",
|
||||
"task_name": "math_data",
|
||||
"abs_weight": 0.003,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/math_data",
|
||||
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 2963260,
|
||||
"ave_tokens_per_line": 200.296266559615,
|
||||
"total_tokens": 0.005
|
||||
},
|
||||
{
|
||||
"dataset_name": "t0",
|
||||
"task_name": "t0",
|
||||
"abs_weight": 0.1,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/t0",
|
||||
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 1650309,
|
||||
"ave_tokens_per_line": 500.296266559615,
|
||||
"total_tokens": 0.82
|
||||
},
|
||||
{
|
||||
"dataset_name": "wikihow",
|
||||
"task_name": "wikihow",
|
||||
"abs_weight": 0.1,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/wikihow",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 180128,
|
||||
"ave_tokens_per_line": 900.296266559615,
|
||||
"total_tokens": 0.16
|
||||
},
|
||||
{
|
||||
"dataset_name": "reclor",
|
||||
"task_name": "reclor",
|
||||
"abs_weight": 0.002,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/reclor",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 4174,
|
||||
"ave_tokens_per_line": 700.296266559615,
|
||||
"total_tokens": 0.003
|
||||
},
|
||||
{
|
||||
"dataset_name": "logic_test_lx_0127",
|
||||
"task_name": "logic_test_lx_0127",
|
||||
"abs_weight": 0.001,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logic_test_lx_0127",
|
||||
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 2800,
|
||||
"ave_tokens_per_line": 200.96266559615,
|
||||
"total_tokens": 0.0004
|
||||
}
|
||||
]
|
|
@ -1,28 +0,0 @@
|
|||
{
|
||||
"vocab_size": 122753,
|
||||
"dropout_p": 0.0,
|
||||
"eps": 1e-05,
|
||||
"half": true,
|
||||
"half_type": "bf16",
|
||||
"use_flash_attn": true,
|
||||
"flash_attn_mask_shape": "2d",
|
||||
"dim_model": 2304,
|
||||
"dim_ff": 5760,
|
||||
"dim_head": 64,
|
||||
"num_heads": 36,
|
||||
"num_kv_heads": 36,
|
||||
"num_layers": 40,
|
||||
"activate_fn": "silu",
|
||||
"init_std": 0.10,
|
||||
"scale": true,
|
||||
"scale_emb": 12,
|
||||
"scale_depth": 1.4,
|
||||
"dim_model_base": 256,
|
||||
"model_type": "fm9g",
|
||||
"architectures": [
|
||||
"FM9GForCausalLM"
|
||||
],
|
||||
"qk_norm": false,
|
||||
"tie_lm_head": true,
|
||||
"ffn_gated": true
|
||||
}
|
|
@ -1,548 +0,0 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2024 QiYuan Inc.
|
||||
import inspect
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Union
|
||||
|
||||
import bmtrain as bmt
|
||||
import numpy as np
|
||||
import torch
|
||||
from bmtrain import nccl
|
||||
from bmtrain.global_var import config as bmt_config
|
||||
|
||||
sys.path.append("../../")
|
||||
from fm9g.arguments import get_args
|
||||
from fm9g.dragonfly.modeling_dragonfly import Dragonfly
|
||||
from fm9g.dragonfly.modeling_dragonfly import DragonflyConfig
|
||||
from fm9g.dragonfly.training_tasks.pretrain_indexed import CudaPrefetcher
|
||||
from fm9g.dragonfly.training_tasks.pretrain_indexed import MixedIndexedDataset
|
||||
from fm9g.dragonfly.training_tasks.pretrain_indexed import UnpadBatchedMixedDataset
|
||||
from fm9g.utils import exporter
|
||||
from fm9g.utils import logger
|
||||
from fm9g.utils.exporter import save_every_step_stats
|
||||
from fm9g.utils.training_stats import num_non_embedding_parameters
|
||||
from fm9g.utils.training_stats import num_parameters
|
||||
|
||||
|
||||
def get_tokenizer(args):
|
||||
from transformers import LlamaTokenizerFast
|
||||
|
||||
tokenizer = LlamaTokenizerFast(vocab_file=args.tokenizer_path)
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_model(args):
|
||||
config = DragonflyConfig.from_json_file(args.model_config)
|
||||
config.tp = 1 if args.tp_size != 1 else 0 # TODO
|
||||
config.pose_prob = args.pose_prob
|
||||
config.pose_scaling_factor = args.pose_scaling_factor
|
||||
config.rope_scaling_type = args.rope_scaling_type
|
||||
config.rope_scaling_factor = args.rope_scaling_factor
|
||||
config.orig_max_length = args.orig_max_length
|
||||
|
||||
bmt.print_rank("model config: {}".format(config))
|
||||
bmt.print_rank("bmt config: {}".format(bmt.config))
|
||||
|
||||
model = Dragonfly(config)
|
||||
if args.load is not None:
|
||||
bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
|
||||
exporter.load_model_ckpt(args, model)
|
||||
else:
|
||||
bmt.print_rank("args.load is None, start to initialize parameters")
|
||||
bmt.init_parameters(model)
|
||||
return model
|
||||
|
||||
|
||||
def get_optimizer(args, model):
|
||||
scale_lr_group = []
|
||||
normal_group = []
|
||||
scale_lr_group_name, normal_group_name = [], []
|
||||
for n, p in model.named_parameters():
|
||||
if n.endswith(".weight") and "layernorm" not in n and "embedding" not in n and "lm_head" not in n:
|
||||
scale_lr_group.append(p)
|
||||
scale_lr_group_name.append(n)
|
||||
else:
|
||||
normal_group.append(p)
|
||||
normal_group_name.append(n)
|
||||
bmt.print_rank(scale_lr_group_name, normal_group_name)
|
||||
param_groups = [
|
||||
{"params": scale_lr_group, "lr": args.lr / model.config.scale_width},
|
||||
{"params": normal_group, "lr": args.lr},
|
||||
]
|
||||
|
||||
if args.offload:
|
||||
optimizer = bmt.optim.AdamOffloadOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
||||
else:
|
||||
optimizer = bmt.optim.AdamOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
||||
if args.load is not None and args.load_grad:
|
||||
exporter.load_optimizer_ckpt(args, optimizer)
|
||||
bmt.print_rank("optimizer is loaded!")
|
||||
return optimizer
|
||||
|
||||
|
||||
def get_learning_rate_scheduler(args, optimizer):
|
||||
from fm9g.training_utils.lr_scheduler import Cosine
|
||||
from fm9g.training_utils.lr_scheduler import WarmupStableDrop
|
||||
|
||||
end_iter = args.train_iters
|
||||
if 0 < args.warmup_iters < 1: # 需要支持按固定比例step用来做warmup的
|
||||
warmup_iters = int(end_iter * args.warmup_iters)
|
||||
else:
|
||||
warmup_iters = int(args.warmup_iters)
|
||||
|
||||
if 0 < args.drop_iters < 1: # 需要支持按固定比例step用来做drop的
|
||||
drop_iters = int(end_iter * args.drop_iters)
|
||||
else:
|
||||
drop_iters = int(args.drop_iters)
|
||||
|
||||
if args.lr_scheduler == "cosine":
|
||||
lr_scheduler = Cosine(
|
||||
optimizer,
|
||||
start_lr=args.lr,
|
||||
warmup_iter=warmup_iters,
|
||||
end_iter=end_iter, # 原来是lr_decay_iter
|
||||
num_iter=args.start_step,
|
||||
#lr_end_restart=args.lr_end_restart,
|
||||
#resume_no_optimze=args.resume_no_optimze,
|
||||
)
|
||||
elif args.lr_scheduler == "warmupstabledrop":
|
||||
lr_scheduler = WarmupStableDrop(
|
||||
optimizer,
|
||||
start_lr=args.lr,
|
||||
warmup_iter=warmup_iters,
|
||||
end_iter=end_iter, # 原来是lr_decay_iter
|
||||
drop_iter=drop_iters,
|
||||
num_iter=args.start_step,
|
||||
resume_no_optimze=args.resume_no_optimze,
|
||||
)
|
||||
return lr_scheduler
|
||||
|
||||
|
||||
def setup_model_and_optimizer(args):
|
||||
start = time.time()
|
||||
tokenizer = get_tokenizer(args)
|
||||
bmt.synchronize()
|
||||
logger.info("load tokenizer in {:.2f}s".format(time.time() - start))
|
||||
|
||||
start = time.time()
|
||||
model = get_model(args)
|
||||
logger.info("load model in {:.2f}s".format(time.time() - start))
|
||||
|
||||
start = time.time()
|
||||
optimizer = get_optimizer(args, model)
|
||||
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
|
||||
bmt.synchronize()
|
||||
logger.info("load lr_scheduler in {:.2f}s".format(time.time() - start))
|
||||
|
||||
return tokenizer, model, optimizer, lr_scheduler
|
||||
|
||||
|
||||
def resume_training(args):
|
||||
ckpts = sorted(
|
||||
[z for z in chain(*[[os.path.join(x[0], y) for y in x[2]] for x in os.walk(args.save)]) if z.endswith(".pt")],
|
||||
reverse=True,
|
||||
key=lambda x: (int)(re.search("(\d+).pt", x)[1]),
|
||||
)
|
||||
# find newest job
|
||||
ckpts = sorted(
|
||||
ckpts,
|
||||
reverse=True,
|
||||
key=lambda x: (int)(re.search("job_(\d+)_ckpt", x)[1]),
|
||||
)
|
||||
|
||||
if len(ckpts) > 0:
|
||||
bmt.print_rank(f"resuming with last checkpoint: {ckpts[0]}")
|
||||
args.load = ckpts[0]
|
||||
# by default, do not load grad file
|
||||
args.load_grad = False
|
||||
args.start_step = 0
|
||||
else:
|
||||
# no ckpts, nothing we can do
|
||||
os._exit(1)
|
||||
|
||||
|
||||
def initialize():
|
||||
args = get_args(pretrain=True)
|
||||
bmt.init_distributed(seed=args.seed, tp_size=args.tp_size)
|
||||
|
||||
if args.save is not None:
|
||||
os.makedirs(args.save, exist_ok=True)
|
||||
if args.load is not None:
|
||||
if args.only_load_model == 0:
|
||||
if args.start_step == 0:
|
||||
log_ckpt = exporter.load_log_ckpt(args)
|
||||
if "iteration" in log_ckpt:
|
||||
args.start_step = log_ckpt["iteration"]
|
||||
else:
|
||||
args.start_step = (int)(re.findall("(\d+)", args.load)[-1])
|
||||
logger.info("Start from step {}".format(args.start_step))
|
||||
elif args.only_load_model == 1:
|
||||
logger.info("You load model ckpt, and choose to completely start the 0 step.")
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
logger.info("You do not load model")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def see_memory(detail=False):
|
||||
if detail:
|
||||
res = torch.cuda.memory_summary()
|
||||
else:
|
||||
res = (
|
||||
round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2),
|
||||
round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024), 2),
|
||||
round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024), 2),
|
||||
)
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
return res
|
||||
|
||||
|
||||
def add_mem_time(info, mem_usage, tim_usage):
|
||||
torch.cuda.synchronize()
|
||||
bmt.synchronize()
|
||||
mem_usage[info] = see_memory()
|
||||
tim_usage[info] = time.time()
|
||||
return mem_usage, tim_usage
|
||||
|
||||
|
||||
def get_task_loss_and_token(loss, task_ids, task_num, targets):
|
||||
# task_ids 可能有-1 来代表无效token
|
||||
_task_num = task_num + 1
|
||||
_task_ids = (task_ids.clone() + 1).to(torch.int64) # [batch_size, seq_len]
|
||||
# gen masks
|
||||
_task_mask = torch.zeros((_task_num, *_task_ids.shape), device=_task_ids.device)
|
||||
_task_mask.scatter_(0, _task_ids.unsqueeze(0), 1) # [task_num, batch_size, seq_len]
|
||||
_loss_mask = torch.ne(targets, -100).to(torch.int32)
|
||||
_mask = _task_mask * _loss_mask.unsqueeze(0) # [task_num, batch_size, seq_len]
|
||||
# calc loss and tokens
|
||||
_task_losses = (loss.unsqueeze(0) * _mask).view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
|
||||
_task_tokens = _mask.view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
|
||||
# return token-wise avg losses and tokens
|
||||
return torch.nan_to_num(_task_losses / _task_tokens, nan=0.0), _task_tokens
|
||||
|
||||
|
||||
class ChunkAve:
|
||||
def __init__(self, chunk_size=100):
|
||||
self.ave_list = []
|
||||
self.chunk_size = chunk_size
|
||||
|
||||
def record(self, time):
|
||||
self.ave_list.append(time)
|
||||
self.ave_list = self.ave_list[-self.chunk_size :]
|
||||
|
||||
def get(self):
|
||||
return sum(self.ave_list) / len(self.ave_list)
|
||||
|
||||
|
||||
def pretrain(
|
||||
args,
|
||||
tokenizer,
|
||||
model: Dragonfly,
|
||||
optimizer,
|
||||
lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
|
||||
):
|
||||
ave_model_time = ChunkAve(chunk_size=100)
|
||||
ave_iter_time = ChunkAve(chunk_size=100)
|
||||
|
||||
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, reduction="none")
|
||||
optim_manager = bmt.optim.OptimManager(
|
||||
loss_scale=None,
|
||||
loss_scale_steps=args.loss_scale_steps,
|
||||
loss_scale_factor=2,
|
||||
max_loss_scale=args.max_loss_scale,
|
||||
min_loss_scale=args.min_loss_scale,
|
||||
)
|
||||
optim_manager.add_optimizer(optimizer, lr_scheduler)
|
||||
|
||||
start_step = args.start_step
|
||||
|
||||
if args.tensorboard is not None and bmt.rank() == 0:
|
||||
import distutils.version # noqa: F401
|
||||
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
if not os.path.exists(args.tensorboard):
|
||||
os.makedirs(args.tensorboard)
|
||||
writer = SummaryWriter(log_dir=args.tensorboard)
|
||||
|
||||
if args.load is not None:
|
||||
log_ckpt = exporter.load_log_ckpt(args)
|
||||
else:
|
||||
log_ckpt = {}
|
||||
global_token_pass = log_ckpt.get("global_token_pass", 0.0)
|
||||
global_total_task_token = defaultdict(int, log_ckpt.get("global_total_task_token", {})) # token by task
|
||||
|
||||
global_world_size = bmt.world_size()
|
||||
bmt.print_rank("Begin preparing dataset")
|
||||
if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
|
||||
mixed_indexed_dataset = MixedIndexedDataset(
|
||||
cfg_path=args.dataset,
|
||||
cfg_json_str=None,
|
||||
tokenizer=tokenizer,
|
||||
max_length=args.max_length,
|
||||
nthreads=args.dataloader_num_threads,
|
||||
prefetch_slice=args.dataloader_prefetch,
|
||||
weight_by_size=True,
|
||||
)
|
||||
|
||||
if args.load is not None and args.only_load_model == 0 and args.load_dataloader_ckpt == 1:
|
||||
exporter.load_dataloader_ckpt(args, mixed_indexed_dataset)
|
||||
|
||||
batched_dataset = UnpadBatchedMixedDataset(mixed_indexed_dataset, args.batch_size, args.max_length)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
batched_dataset,
|
||||
batch_size=None,
|
||||
collate_fn=lambda x: x,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
prefetch_factor=args.dataloader_prefetch_factor,
|
||||
)
|
||||
else:
|
||||
|
||||
def dummy_generator():
|
||||
while True:
|
||||
yield None
|
||||
|
||||
mixed_indexed_dataset = dummy_generator()
|
||||
dataloader = mixed_indexed_dataset
|
||||
|
||||
DataIterator = CudaPrefetcher(dataloader, tp_size=args.tp_size, tp_rank=bmt.config["tp_rank"])
|
||||
|
||||
bmt.print_rank("Preparing dataset done.")
|
||||
|
||||
# inspect at init
|
||||
model_inspect = bmt.inspect.inspect_model(model, "*")
|
||||
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
||||
|
||||
try:
|
||||
mem_usage, tim_usage = {}, {}
|
||||
mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)
|
||||
|
||||
for iteration, data in enumerate(DataIterator, start=start_step + 1):
|
||||
if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
|
||||
mixed_indexed_dataset.update_states(data["task_ids"], data["indexes"])
|
||||
|
||||
mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)
|
||||
|
||||
logits = model(
|
||||
input=data["inputs"],
|
||||
cu_seqlens=data["cu_seqlens"],
|
||||
max_seqlen=data["max_seqlen"],
|
||||
position_ids=data["position_ids"],
|
||||
)
|
||||
|
||||
# chunk targets and task_ids
|
||||
data["targets"] = (
|
||||
data["targets"]
|
||||
.view(-1)
|
||||
.chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
|
||||
.view(data["targets"].shape[0], -1)
|
||||
)
|
||||
data["task_ids"] = (
|
||||
data["task_ids"]
|
||||
.view(-1)
|
||||
.chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
|
||||
.view(data["task_ids"].shape[0], -1)
|
||||
)
|
||||
|
||||
_target = data["targets"].view(-1)
|
||||
non_reduced_loss = loss_func(logits.view(-1, logits.size(-1)), _target)
|
||||
_w = (_target != -100).int()
|
||||
loss = non_reduced_loss.sum() / _w.sum().float()
|
||||
|
||||
global_loss = bmt.sum_loss(loss).item()
|
||||
mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)
|
||||
|
||||
optim_manager.backward(loss)
|
||||
mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)
|
||||
|
||||
if iteration % args.grad_accum == 0 or iteration == args.train_iters:
|
||||
grad_accum_init_time = tim_usage["init"]
|
||||
|
||||
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
|
||||
optim_manager.step()
|
||||
optim_manager.zero_grad()
|
||||
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
|
||||
model_time = tim_usage["optim"] - grad_accum_init_time
|
||||
ave_model_time.record(model_time)
|
||||
else:
|
||||
# dummy optim step
|
||||
grad_norm = torch.Tensor([0.0]).cuda()
|
||||
tim_usage["optim"] = tim_usage["backward"]
|
||||
mem_usage["optim"] = mem_usage["backward"]
|
||||
|
||||
with torch.no_grad():
|
||||
task_num = len(data["task_names"])
|
||||
task_loss, task_token = get_task_loss_and_token(
|
||||
non_reduced_loss, data["task_ids"], task_num, data["targets"]
|
||||
)
|
||||
task_loss_map: Dict[str, float] = {}
|
||||
gatherd_task_loss_map = bmt.distributed.all_gather(task_loss)
|
||||
gatherd_task_token_map = bmt.distributed.all_gather(task_token)
|
||||
gatherd_task_loss_token_map = gatherd_task_loss_map * gatherd_task_token_map
|
||||
sum_task_loss = gatherd_task_loss_token_map.sum(dim=0)
|
||||
tot_task_token = gatherd_task_token_map.sum(dim=0)
|
||||
ave_task_loss = sum_task_loss / tot_task_token
|
||||
for i in range(task_num):
|
||||
task_loss_map[data["task_names"][i]] = ave_task_loss[i].item()
|
||||
global_total_task_token[data["task_names"][i]] += tot_task_token[i].item()
|
||||
|
||||
local_total_rate = torch.Tensor([data["lengths"].float().mean() / args.max_length]).cuda()
|
||||
local_total_rate = bmt.sum_loss(local_total_rate).item()
|
||||
global_token_pass += (
|
||||
(global_world_size // args.tp_size) * local_total_rate * args.max_length * args.batch_size
|
||||
)
|
||||
|
||||
bmt.print_rank(
|
||||
"=========================================" + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
)
|
||||
last_before_log_time = tim_usage["before_log"]
|
||||
mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)
|
||||
|
||||
iter_time = tim_usage["before_log"] - last_before_log_time
|
||||
|
||||
ave_iter_time.record(iter_time)
|
||||
|
||||
train_info = {
|
||||
"time": iter_time,
|
||||
"iteration": iteration,
|
||||
"loss": global_loss,
|
||||
"lr": lr_scheduler.current_lr,
|
||||
"token_max": local_total_rate,
|
||||
"token_pass": global_token_pass,
|
||||
"throughout": args.max_length * args.batch_size * local_total_rate / ave_iter_time.get() / args.tp_size,
|
||||
"grad_norm": grad_norm.item(),
|
||||
"mask_max": ((data["targets"] >= 0).sum(-1).float().mean() / args.max_length).item(),
|
||||
"task_loss": task_loss_map,
|
||||
"total_task_token": global_total_task_token,
|
||||
}
|
||||
global_token_pass_str = convert_to_k_and_b(global_token_pass)
|
||||
|
||||
bmt.print_rank(
|
||||
(
|
||||
"| Iter: {iteration:6d} | loss: {loss:.4f} | lr: {lr:.4e} | model_time: {model_time:.2f} | iter_time: {iter_time:.2f}| chunk_ave_time: {chunk_ave_time:.2f}"
|
||||
+ " token/max: {tokenrate:.4f} | mask/max: {maskrate:.4f} | grad_norm: {grad_norm:.4f} | global_token_pass (B):"
|
||||
+ "{global_token_pass} | mem_usage {mem_usage} | "
|
||||
).format(
|
||||
iteration=iteration,
|
||||
loss=global_loss,
|
||||
lr=lr_scheduler.current_lr,
|
||||
model_time=model_time,
|
||||
iter_time=iter_time,
|
||||
chunk_ave_time=ave_iter_time.get(),
|
||||
tokenrate=data["lengths"].float().mean() / args.max_length / args.batch_size,
|
||||
maskrate=(data["targets"] >= 0).sum(-1).float().mean() / args.max_length / args.batch_size,
|
||||
grad_norm=grad_norm.item(),
|
||||
global_token_pass=global_token_pass_str,
|
||||
mem_usage=max([value for key, value in mem_usage.items()]),
|
||||
)
|
||||
)
|
||||
|
||||
bmt.print_rank(
|
||||
"task_loss:\t| "
|
||||
+ " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
|
||||
+ " |"
|
||||
)
|
||||
|
||||
if iteration % 10 == 0:
|
||||
bmt.print_rank(
|
||||
"task_tokens (B):\t| "
|
||||
+ " | ".join(
|
||||
[
|
||||
"{}: {:.4f}".format(task_name, task_token / 10**9)
|
||||
for task_name, task_token in global_total_task_token.items()
|
||||
]
|
||||
)
|
||||
+ " |"
|
||||
)
|
||||
|
||||
if iteration % args.inspect_iters == 0:
|
||||
model_inspect = bmt.inspect.inspect_model(model, "*")
|
||||
bmt.print_rank(bmt.inspect.format_summary(model_inspect))
|
||||
|
||||
if args.log_dir is not None and bmt.rank() == 0:
|
||||
if args.save is not None:
|
||||
save_every_step_stats(train_info, args.save)
|
||||
|
||||
if args.tensorboard is not None and bmt.rank() == 0:
|
||||
writer.add_scalar("Loss/train", global_loss, iteration)
|
||||
writer.add_scalar("Optimizer/lr", lr_scheduler.current_lr, iteration)
|
||||
writer.add_scalar("Optimizer/scale", optim_manager.loss_scale, iteration)
|
||||
writer.add_scalar("Optimizer/grad_norm", grad_norm.item(), iteration)
|
||||
for task_name, loss in task_loss_map.items():
|
||||
if not math.isnan(loss):
|
||||
writer.add_scalar("Loss/train/{}".format(task_name), loss, iteration)
|
||||
|
||||
# -------- save file. If need to backup by Klara platform, use export.xx_save --------
|
||||
log_ckpt = {
|
||||
"global_total_task_token": global_total_task_token,
|
||||
"global_token_pass": global_token_pass,
|
||||
"iteration": iteration,
|
||||
}
|
||||
|
||||
if args.save is not None and iteration % args.save_iters == 0:
|
||||
exporter.export(
|
||||
model,
|
||||
mixed_indexed_dataset,
|
||||
tokenizer,
|
||||
optimizer,
|
||||
iteration,
|
||||
args,
|
||||
log_ckpt=log_ckpt,
|
||||
final_save=False,
|
||||
)
|
||||
|
||||
if iteration == args.train_iters and args.stop_when_end == 1:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print(f"train loop err: {e}")
|
||||
raise e
|
||||
finally:
|
||||
pass
|
||||
|
||||
exporter.export(model, mixed_indexed_dataset, tokenizer, optimizer, -1, args, final_save=False)
|
||||
|
||||
|
||||
def convert_to_k_and_b(number):
|
||||
if number >= 1e9: # 大于或等于10亿
|
||||
b_number = number / 1e9
|
||||
return f"{b_number:.2f}B"
|
||||
elif number >= 1e6: # 大于或等于1百万
|
||||
k_number = number / 1e6
|
||||
return f"{k_number:.2f}M"
|
||||
elif number >= 1e3:
|
||||
k_number = number / 1e3
|
||||
return f"{k_number:.2f}K"
|
||||
else:
|
||||
return str(number)
|
||||
|
||||
|
||||
def main():
|
||||
args = initialize()
|
||||
bmt.synchronize()
|
||||
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
|
||||
bmt.print_rank("finish loading")
|
||||
bmt.print_rank(
|
||||
"Number of parameter {}, Number of non-e parameter {}".format(
|
||||
num_parameters(model), num_non_embedding_parameters(model)
|
||||
)
|
||||
)
|
||||
bmt.print_rank("args: {}".format(args))
|
||||
|
||||
pretrain(args, tokenizer, model, optimizer, lr_scheduler)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,234 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
#export OMP_NUM_THREADS=16
|
||||
|
||||
declare -A args # Declare an associative array to store arguments and values
|
||||
|
||||
args["model_unique"]="2b_0701"
|
||||
args["resume_ckpt"]=""
|
||||
args["config"]="2.4b"
|
||||
args["flash"]="cuda"
|
||||
args["batch_size"]="1"
|
||||
args["max_length"]="4096"
|
||||
args["save_iters"]="500"
|
||||
args["train_iters"]="10"
|
||||
args["dataset_config"]="fm9g_sft"
|
||||
args["local"]="False"
|
||||
args["dataloader"]="indexed"
|
||||
args["save"]="True"
|
||||
args["dataloader_num_threads"]=1
|
||||
args["dataloader_prefetch"]=1
|
||||
args["dataloader_prefetch_factor"]=1
|
||||
args["dataloader_num_workers"]=1
|
||||
args["lr"]="1e-5"
|
||||
args["warmup_iters"]="20"
|
||||
args["drop_iters"]="0.1"
|
||||
args["tokenizer_path"]="./tokenizer/tokenizer.model" # /user/tc_agi/klara/baichuan2/baichuan2.tokenizer.model
|
||||
args["load_grad"]="False"
|
||||
args["grad_ckpt_num"]="160"
|
||||
args["exp_group"]=""
|
||||
args["ignore_cuda_oom"]="1"
|
||||
args["tensorboard_all_tasks"]="0"
|
||||
args["stop_when_end"]="0"
|
||||
args["only_run_dataloader"]="0"
|
||||
args["eps"]="1e-6"
|
||||
args["inspect_iters"]="100"
|
||||
args["strict_state_dict"]="1"
|
||||
args["only_load_model"]="1"
|
||||
args["lr_scheduler"]="cosine"
|
||||
args["resume_no_optimze"]="0"
|
||||
args["tp_size"]="1"
|
||||
args["parallel_load_datastate"]="8"
|
||||
args["async_save"]="False"
|
||||
args["load_dataloader_ckpt"]="0"
|
||||
args["drop_begin"]="-1"
|
||||
args["drop_rate"]="0.5"
|
||||
args["use_checkpoint"]="0"
|
||||
|
||||
|
||||
# Loop through the arguments
|
||||
for ((i=1; i<=$#; i++)); do
|
||||
arg="${!i}"
|
||||
# Check if the argument starts with "--"
|
||||
if [[ "$arg" == --* ]]; then
|
||||
arg_name="${arg:2}" # Remove leading "--"
|
||||
valueid=$((i+1))
|
||||
# Get the value of the argument if it exists
|
||||
if ((i+1 <= $#)); then
|
||||
args["$arg_name"]="${!valueid}"
|
||||
i=$((i+1)) # Skip the next argument (its value)
|
||||
else
|
||||
args["$arg_name"]="" # Set empty value if no value provided
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
# 使用 Python 读取 JSON 文件并更新 Bash 字典
|
||||
while read -r key value; do
|
||||
args["$key"]="$value"
|
||||
done < <(python -c 'import json, sys; obj = json.load(open("train_configs/'${args['config']}'.json"))["pretrain"]; print("\n".join(["{} {}".format(k, v) for k, v in obj.items()]))')
|
||||
|
||||
|
||||
|
||||
# 用cmd arg 再更新一次
|
||||
# Loop through the arguments
|
||||
for ((i=1; i<=$#; i++)); do
|
||||
arg="${!i}"
|
||||
# Check if the argument starts with "--"
|
||||
if [[ "$arg" == --* ]]; then
|
||||
arg_name="${arg:2}" # Remove leading "--"
|
||||
valueid=$((i+1))
|
||||
|
||||
# Get the value of the argument if it exists
|
||||
if ((i+1 <= $#)); then
|
||||
args["$arg_name"]="${!valueid}"
|
||||
i=$((i+1)) # Skip the next argument (its value)
|
||||
else
|
||||
args["$arg_name"]="" # Set empty value if no value provided
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
# Print the values of the arguments
|
||||
echo "----------- CMD args ----------"
|
||||
for key in "${!args[@]}"; do
|
||||
echo "$key: ${args[$key]}"
|
||||
done
|
||||
echo "--------- END CMD args --------"
|
||||
|
||||
|
||||
if [[ ${args["flash"]} == "triton" ]]; then
|
||||
sudo cp /usr/local/cuda-11.6/compat/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03
|
||||
sudo ln /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so
|
||||
echo "triton flash"
|
||||
fi
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
GPUS_PER_NODE=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader | wc -l)
|
||||
# GPUS_PER_NODE=1
|
||||
echo "Using ${GPUS_PER_NODE} GPU each machine"
|
||||
|
||||
|
||||
if [[ ${args["model_unique"]} == "" ]]; then
|
||||
MODEL_UNIQUE=${JEEVES_JOB_ID} # 写入的位置,没传的话自动构造
|
||||
# JOBID+CreateTime, 本次run的唯一标识符。在白箱里可以通过/projects/${PROJECTID}-${PROJECTNAME}/checkpoints/${MODEL_UNIQUE} 拿到 checkpoint
|
||||
# 通过/projects/${PROJECTID}-${PROJECTNAME}/tensorboard/${MODEL_UNIQUE} 拿到 tensorboard
|
||||
else
|
||||
MODEL_UNIQUE=${args["model_unique"]} # 给了写入的位置
|
||||
fi
|
||||
echo "model_unique: "$MODEL_UNIQUE
|
||||
|
||||
# --------------- 运行参数 ---------------
|
||||
|
||||
OPTS+=" --model-config model_configs/"${args['config']}".json" # [CHANGE]
|
||||
OPTS+=" --batch-size ${args["batch_size"]}"
|
||||
OPTS+=" --train-iters ${args["train_iters"]}"
|
||||
OPTS+=" --save-iters ${args["save_iters"]}"
|
||||
OPTS+=" --save-name fm9g_live_checkpoint"
|
||||
OPTS+=" --max-length ${args["max_length"]}"
|
||||
OPTS+=" --lr ${args["lr"]}"
|
||||
OPTS+=" --inspect-iters ${args["inspect_iters"]}"
|
||||
OPTS+=" --warmup-iters ${args["warmup_iters"]}"
|
||||
OPTS+=" --drop-iters ${args["drop_iters"]}"
|
||||
OPTS+=" --lr_scheduler ${args["lr_scheduler"]}"
|
||||
OPTS+=" --offload"
|
||||
#OPTS+=" --vocab ./tokenizer/vocab.txt"
|
||||
OPTS+=" --flash ${args["flash"]}"
|
||||
OPTS+=" --tensorboard_all_tasks ${args["tensorboard_all_tasks"]}"
|
||||
OPTS+=" --ignore_cuda_oom ${args["ignore_cuda_oom"]}"
|
||||
OPTS+=" --stop_when_end ${args["stop_when_end"]}"
|
||||
OPTS+=" --only_run_dataloader ${args["only_run_dataloader"]}"
|
||||
OPTS+=" --eps ${args["eps"]}"
|
||||
OPTS+=" --strict_state_dict ${args["strict_state_dict"]}"
|
||||
OPTS+=" --only_load_model ${args["only_load_model"]}"
|
||||
OPTS+=" --resume_no_optimze ${args["resume_no_optimze"]}"
|
||||
OPTS+=" --tokenizer_path ${args["tokenizer_path"]}"
|
||||
OPTS+=" --weight-decay 0.1"
|
||||
OPTS+=" --tp-size ${args["tp_size"]}"
|
||||
OPTS+=" --parallel_load_datastate ${args["parallel_load_datastate"]}"
|
||||
OPTS+=" --load_dataloader_ckpt ${args["load_dataloader_ckpt"]}"
|
||||
OPTS+=" --drop_begin ${args["drop_begin"]}"
|
||||
OPTS+=" --drop_rate ${args["drop_rate"]}"
|
||||
OPTS+=" --use_checkpoint ${args["use_checkpoint"]}"
|
||||
|
||||
if [[ ${args["load_grad"]} == "True" ]]; then
|
||||
OPTS+=" --load-grad"
|
||||
OPTS+=" --grad-ckpt-num ${args["grad_ckpt_num"]}"
|
||||
fi
|
||||
|
||||
|
||||
if [[ ${args["async_save"]} == "True" ]]; then
|
||||
OPTS+=" --async_save"
|
||||
fi
|
||||
|
||||
|
||||
if [[ ${args["dataloader"]} == "indexed" ]]; then
|
||||
OPTS+=" --dataloader_num_threads ${args["dataloader_num_threads"]}"
|
||||
OPTS+=" --dataloader_prefetch ${args["dataloader_prefetch"]}"
|
||||
OPTS+=" --dataloader_num_workers ${args["dataloader_num_workers"]}"
|
||||
OPTS+=" --dataloader_prefetch_factor ${args["dataloader_prefetch_factor"]}"
|
||||
fi
|
||||
|
||||
|
||||
# --------------- 写文件路径 ---------------
|
||||
## checkpoint
|
||||
if [[ ${args["save"]} == "True" ]]; then
|
||||
|
||||
OPTS+=" --save ./data/checkpoints/${MODEL_UNIQUE}/"
|
||||
OPTS+=" --save-model ./not_exist/${MODEL_UNIQUE}/"
|
||||
else
|
||||
echo "won't save model"
|
||||
fi
|
||||
|
||||
|
||||
## logs,/local/logs 等价于 ./datalogs(软链)
|
||||
mkdir -p ./data/checkpoints/logs/${MODEL_UNIQUE}
|
||||
OPTS+=" --log-dir ./data/checkpoints/logs/${MODEL_UNIQUE}"
|
||||
OPTS+=" --tensorboard ./data/tensorboard/${args["exp_group"]}${MODEL_UNIQUE}/"
|
||||
|
||||
|
||||
|
||||
if [[ ${args["local"]} == "True" ]]; then
|
||||
current_dir=$(pwd)
|
||||
OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
|
||||
else
|
||||
current_dir=$(pwd)
|
||||
OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
|
||||
echo "Platform config:"${PLATFORM_CONFIG_PATH}
|
||||
fi
|
||||
|
||||
|
||||
## checkpoint,兼容 CHECKPOINT 和 LATEST_CHECKPOINT。debug 时建议不加载 checkpoint,启动会比较快
|
||||
if [ "${args["resume_ckpt"]}" != "" ]; then
|
||||
OPTS+=" --load ./data/checkpoints/${MODEL_UNIQUE}/${args["resume_ckpt"]}"
|
||||
else
|
||||
echo "No checkpoint to load"
|
||||
fi
|
||||
|
||||
|
||||
filename="pretrain_dragonfly"
|
||||
|
||||
if [[ ${args["local"]} == "True" ]]; then
|
||||
PRETRAIN_ENTRY="$filename.py"
|
||||
else
|
||||
PRETRAIN_ENTRY="$filename.py"
|
||||
fi
|
||||
|
||||
|
||||
GPUS_PER_NODE=1
|
||||
NNODES=1
|
||||
RANK=0
|
||||
MASTER_ENDPOINT=g3006
|
||||
MASTER_PORT=23456
|
||||
#CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --master_addr=${MASTER_ENDPOINT} --master_port=${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
|
||||
CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ENDPOINT}:${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
|
||||
|
||||
echo "-------final CMD is------"
|
||||
echo "${CMD}"
|
||||
echo "-------final CMD end------"
|
||||
|
||||
$CMD
|
File diff suppressed because it is too large
Load Diff
Binary file not shown.
|
@ -1,9 +0,0 @@
|
|||
{
|
||||
"pretrain": {
|
||||
"train_iters": 1000000000,
|
||||
"batch_size": 1,
|
||||
"max_length": 4096,
|
||||
"n_gpus": 8,
|
||||
"lr": 0.01
|
||||
}
|
||||
}
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -1,20 +0,0 @@
|
|||
|
||||
import random
|
||||
|
||||
|
||||
def rand(n: int, r: random.Random):
|
||||
return int(r.random() * n)
|
||||
|
||||
def transform(data, num_sample: int, r: random.Random):
|
||||
if 'input' in data:
|
||||
_input = "<用户>"+data['input']+"<AI>"
|
||||
else:
|
||||
_input = ""
|
||||
|
||||
if 'output' in data:
|
||||
_output = data['output']
|
||||
else:
|
||||
_output = ""
|
||||
return {"input": _input,
|
||||
"output": _output,
|
||||
}
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -1,20 +0,0 @@
|
|||
|
||||
import random
|
||||
|
||||
|
||||
def rand(n: int, r: random.Random):
|
||||
return int(r.random() * n)
|
||||
|
||||
def transform(data, num_sample: int, r: random.Random):
|
||||
if 'input' in data:
|
||||
_input = data['input']
|
||||
else:
|
||||
_input = ""
|
||||
|
||||
if 'output' in data:
|
||||
_output = data['output']
|
||||
else:
|
||||
_output = ""
|
||||
return {"input": _input,
|
||||
"output": _output,
|
||||
}
|
|
@ -1,134 +0,0 @@
|
|||
[
|
||||
{
|
||||
"dataset_name": "humanevallike_clean_dedup",
|
||||
"task_name": "humanevallike_clean_dedup",
|
||||
"abs_weight": 0.2,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/humanevallike_clean_dedup",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 995339,
|
||||
"ave_tokens_per_line": 100,
|
||||
"total_tokens": 0.1
|
||||
},
|
||||
{
|
||||
"dataset_name": "leetcode_pass_code_0125",
|
||||
"task_name": "leetcode_pass_code_0125",
|
||||
"abs_weight": 0.006,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/leetcode_pass_code_0125",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 10724,
|
||||
"ave_tokens_per_line": 200,
|
||||
"total_tokens": 0.002
|
||||
},
|
||||
{
|
||||
"dataset_name": "logiv2Annotate",
|
||||
"task_name": "logiv2Annotate",
|
||||
"abs_weight": 0.004,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logiv2Annotate",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 12566,
|
||||
"ave_tokens_per_line": 512,
|
||||
"total_tokens": 0.006
|
||||
},
|
||||
{
|
||||
"dataset_name": "mmlu_enhance",
|
||||
"task_name": "mmlu_enhance",
|
||||
"abs_weight": 0.1,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mmlu_enhance",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 169771,
|
||||
"ave_tokens_per_line": 300,
|
||||
"total_tokens": 0.05
|
||||
},
|
||||
{
|
||||
"dataset_name": "mtbench_like",
|
||||
"task_name": "mtbench_like",
|
||||
"abs_weight": 0.2,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/mtbench_like",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 319080,
|
||||
"ave_tokens_per_line": 500,
|
||||
"total_tokens": 0.15
|
||||
},
|
||||
{
|
||||
"dataset_name": "ultra_dataset_new",
|
||||
"task_name": "ultra_dataset_new",
|
||||
"abs_weight": 2.0,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/ultra_dataset_new",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 385045,
|
||||
"ave_tokens_per_line": 200.296266559615,
|
||||
"total_tokens": 2.0
|
||||
},
|
||||
{
|
||||
"dataset_name": "sft_data_zh_wowru",
|
||||
"task_name": "sft_data_zh_wowru",
|
||||
"abs_weight": 1.0,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/sft_data_zh_wowru",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 2963260,
|
||||
"ave_tokens_per_line": 200.296266559615,
|
||||
"total_tokens": 1
|
||||
},
|
||||
{
|
||||
"dataset_name": "math_data",
|
||||
"task_name": "math_data",
|
||||
"abs_weight": 0.003,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/math_data",
|
||||
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 2963260,
|
||||
"ave_tokens_per_line": 200.296266559615,
|
||||
"total_tokens": 0.005
|
||||
},
|
||||
{
|
||||
"dataset_name": "t0",
|
||||
"task_name": "t0",
|
||||
"abs_weight": 0.1,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/t0",
|
||||
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 1650309,
|
||||
"ave_tokens_per_line": 500.296266559615,
|
||||
"total_tokens": 0.82
|
||||
},
|
||||
{
|
||||
"dataset_name": "wikihow",
|
||||
"task_name": "wikihow",
|
||||
"abs_weight": 0.1,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/wikihow",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 180128,
|
||||
"ave_tokens_per_line": 900.296266559615,
|
||||
"total_tokens": 0.16
|
||||
},
|
||||
{
|
||||
"dataset_name": "reclor",
|
||||
"task_name": "reclor",
|
||||
"abs_weight": 0.002,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/reclor",
|
||||
"transforms": "0124_hq_data/general/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 4174,
|
||||
"ave_tokens_per_line": 700.296266559615,
|
||||
"total_tokens": 0.003
|
||||
},
|
||||
{
|
||||
"dataset_name": "logic_test_lx_0127",
|
||||
"task_name": "logic_test_lx_0127",
|
||||
"abs_weight": 0.001,
|
||||
"path": "/data/groups/QY_LLM_Core/sa_data/sft_data/0124_hq_data/logic_test_lx_0127",
|
||||
"transforms": "0124_hq_data/add_userai/script_cpmc.py",
|
||||
"allow_repeat": true,
|
||||
"nlines": 2800,
|
||||
"ave_tokens_per_line": 200.96266559615,
|
||||
"total_tokens": 0.0004
|
||||
}
|
||||
]
|
|
@ -1,568 +0,0 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 ModelBest Inc.
|
||||
import inspect
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Union
|
||||
|
||||
import bmtrain as bmt
|
||||
import numpy as np
|
||||
import torch
|
||||
from bmtrain import nccl
|
||||
from bmtrain.global_var import config as bmt_config
|
||||
|
||||
sys.path.append("../../")
|
||||
from fm9g.arguments import get_args
|
||||
from fm9g.dragonfly.modeling_dragonfly import Dragonfly
|
||||
from fm9g.dragonfly.modeling_dragonfly import DragonflyConfig
|
||||
from fm9g.dragonfly.training_tasks.pretrain_indexed_9g import CudaPrefetcher
|
||||
from fm9g.dragonfly.training_tasks.pretrain_indexed_9g import MixedIndexedDataset
|
||||
from fm9g.dragonfly.training_tasks.pretrain_indexed_9g import UnpadBatchedMixedDataset
|
||||
from fm9g.utils import exporter
|
||||
from fm9g.utils import logger
|
||||
from fm9g.utils.exporter import save_every_step_stats
|
||||
from fm9g.utils.training_stats import num_non_embedding_parameters
|
||||
from fm9g.utils.training_stats import num_parameters
|
||||
|
||||
|
||||
def get_tokenizer(args):
|
||||
from fm9g.tokenizer import FM9GTokenizer
|
||||
tokenizer = FM9GTokenizer(path=args.vocab)
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_model(args):
|
||||
config = DragonflyConfig.from_json_file(args.model_config)
|
||||
config.tp = 1 if args.tp_size != 1 else 0 # TODO
|
||||
config.pose_prob = args.pose_prob
|
||||
config.pose_scaling_factor = args.pose_scaling_factor
|
||||
config.rope_scaling_type = args.rope_scaling_type
|
||||
config.rope_scaling_factor = args.rope_scaling_factor
|
||||
config.orig_max_length = args.orig_max_length
|
||||
config.use_checkpoint = True if args.use_checkpoint == 1 else False
|
||||
|
||||
bmt.print_rank("model config: {}".format(config))
|
||||
bmt.print_rank("bmt config: {}".format(bmt.config))
|
||||
|
||||
model = Dragonfly(config)
|
||||
if args.load is not None:
|
||||
bmt.print_rank("args.load is not None, start to load checkpoints" + args.load)
|
||||
exporter.load_model_ckpt(args, model)
|
||||
else:
|
||||
bmt.print_rank("args.load is None, start to initialize parameters")
|
||||
bmt.init_parameters(model)
|
||||
return model
|
||||
|
||||
|
||||
def get_optimizer(args, model):
|
||||
scale_lr_group = []
|
||||
normal_group = []
|
||||
scale_lr_group_name, normal_group_name = [], []
|
||||
for n, p in model.named_parameters():
|
||||
if n.endswith(".weight") and "layernorm" not in n and "embedding" not in n and "lm_head" not in n:
|
||||
scale_lr_group.append(p)
|
||||
scale_lr_group_name.append(n)
|
||||
else:
|
||||
normal_group.append(p)
|
||||
normal_group_name.append(n)
|
||||
bmt.print_rank(scale_lr_group_name, normal_group_name)
|
||||
param_groups = [
|
||||
{"params": scale_lr_group, "lr": args.lr / model.config.scale_width},
|
||||
{"params": normal_group, "lr": args.lr},
|
||||
]
|
||||
|
||||
if args.offload:
|
||||
optimizer = bmt.optim.AdamOffloadOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
||||
else:
|
||||
optimizer = bmt.optim.AdamOptimizer(param_groups, betas=(0.9, 0.95), weight_decay=args.weight_decay)
|
||||
if args.load is not None and args.load_grad:
|
||||
exporter.load_optimizer_ckpt(args, optimizer)
|
||||
bmt.print_rank("optimizer is loaded!")
|
||||
return optimizer
|
||||
|
||||
|
||||
def get_learning_rate_scheduler(args, optimizer):
|
||||
from fm9g.training_utils.lr_scheduler import Cosine
|
||||
from fm9g.training_utils.lr_scheduler import WarmupStableDrop
|
||||
from fm9g.training_utils.lr_scheduler import WarmupStableExp
|
||||
|
||||
end_iter = args.train_iters
|
||||
if 0 < args.warmup_iters < 1: # 需要支持按固定比例step用来做warmup的
|
||||
warmup_iters = int(end_iter * args.warmup_iters)
|
||||
else:
|
||||
warmup_iters = int(args.warmup_iters)
|
||||
|
||||
if 0 < args.drop_iters < 1: # 需要支持按固定比例step用来做drop的
|
||||
drop_iters = int(end_iter * args.drop_iters)
|
||||
else:
|
||||
drop_iters = int(args.drop_iters)
|
||||
|
||||
if args.lr_scheduler == "cosine":
|
||||
lr_scheduler = Cosine(
|
||||
optimizer,
|
||||
start_lr=args.lr,
|
||||
warmup_iter=warmup_iters,
|
||||
end_iter=end_iter, # 原来是lr_decay_iter
|
||||
num_iter=args.start_step,
|
||||
)
|
||||
# lr_end_restart=args.lr_end_restart,
|
||||
# resume_no_optimze=args.resume_no_optimze,
|
||||
#)
|
||||
elif args.lr_scheduler == "warmupstabledrop":
|
||||
lr_scheduler = WarmupStableDrop(
|
||||
optimizer,
|
||||
start_lr=args.lr,
|
||||
warmup_iter=warmup_iters,
|
||||
end_iter=end_iter, # 原来是lr_decay_iter
|
||||
drop_iter=drop_iters,
|
||||
num_iter=args.start_step,
|
||||
resume_no_optimze=args.resume_no_optimze,
|
||||
)
|
||||
elif args.lr_scheduler == "warmupstableexp":
|
||||
lr_scheduler = WarmupStableExp(
|
||||
optimizer,
|
||||
start_lr=args.lr,
|
||||
warmup_iter=warmup_iters,
|
||||
drop_begin=args.drop_begin, # 原来是lr_decay_iter
|
||||
drop_iter=drop_iters,
|
||||
drop_rate=args.drop_rate,
|
||||
num_iter=args.start_step,
|
||||
resume_no_optimze=args.resume_no_optimze,
|
||||
)
|
||||
return lr_scheduler
|
||||
|
||||
|
||||
def setup_model_and_optimizer(args):
|
||||
start = time.time()
|
||||
tokenizer = get_tokenizer(args)
|
||||
bmt.synchronize()
|
||||
logger.info("load tokenizer in {:.2f}s".format(time.time() - start))
|
||||
|
||||
start = time.time()
|
||||
model = get_model(args)
|
||||
logger.info("load model in {:.2f}s".format(time.time() - start))
|
||||
|
||||
start = time.time()
|
||||
optimizer = get_optimizer(args, model)
|
||||
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
|
||||
bmt.synchronize()
|
||||
logger.info("load lr_scheduler in {:.2f}s".format(time.time() - start))
|
||||
|
||||
return tokenizer, model, optimizer, lr_scheduler
|
||||
|
||||
|
||||
def resume_training(args):
|
||||
ckpts = sorted(
|
||||
[z for z in chain(*[[os.path.join(x[0], y) for y in x[2]] for x in os.walk(args.save)]) if z.endswith(".pt")],
|
||||
reverse=True,
|
||||
key=lambda x: (int)(re.search("(\d+).pt", x)[1]),
|
||||
)
|
||||
# find newest job
|
||||
ckpts = sorted(
|
||||
ckpts,
|
||||
reverse=True,
|
||||
key=lambda x: (int)(re.search("job_(\d+)_ckpt", x)[1]),
|
||||
)
|
||||
|
||||
if len(ckpts) > 0:
|
||||
bmt.print_rank(f"resuming with last checkpoint: {ckpts[0]}")
|
||||
args.load = ckpts[0]
|
||||
# by default, do not load grad file
|
||||
args.load_grad = False
|
||||
args.start_step = 0
|
||||
else:
|
||||
# no ckpts, nothing we can do
|
||||
os._exit(1)
|
||||
|
||||
|
||||
def initialize():
|
||||
args = get_args(pretrain=True)
|
||||
bmt.init_distributed(seed=args.seed, tp_size=args.tp_size)
|
||||
|
||||
if args.save is not None:
|
||||
os.makedirs(args.save, exist_ok=True)
|
||||
if args.load is not None:
|
||||
if args.only_load_model == 0:
|
||||
if args.start_step == 0:
|
||||
log_ckpt = exporter.load_log_ckpt(args)
|
||||
if "iteration" in log_ckpt:
|
||||
args.start_step = log_ckpt["iteration"]
|
||||
else:
|
||||
args.start_step = (int)(re.findall("(\d+)", args.load)[-1])
|
||||
logger.info("Start from step {}".format(args.start_step))
|
||||
elif args.only_load_model == 1:
|
||||
logger.info("You load model ckpt, and choose to completely start the 0 step.")
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
logger.info("You do not load model")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def see_memory(detail=False):
|
||||
if detail:
|
||||
res = torch.cuda.memory_summary()
|
||||
else:
|
||||
res = (
|
||||
round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2),
|
||||
round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024), 2),
|
||||
round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024), 2),
|
||||
)
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
return res
|
||||
|
||||
|
||||
def add_mem_time(info, mem_usage, tim_usage):
|
||||
torch.cuda.synchronize()
|
||||
bmt.synchronize()
|
||||
mem_usage[info] = see_memory()
|
||||
tim_usage[info] = time.time()
|
||||
return mem_usage, tim_usage
|
||||
|
||||
|
||||
def get_task_loss_and_token(loss, task_ids, task_num, targets):
|
||||
# task_ids 可能有-1 来代表无效token
|
||||
_task_num = task_num + 1
|
||||
_task_ids = (task_ids.clone() + 1).to(torch.int64) # [batch_size, seq_len]
|
||||
# gen masks
|
||||
_task_mask = torch.zeros((_task_num, *_task_ids.shape), device=_task_ids.device)
|
||||
_task_mask.scatter_(0, _task_ids.unsqueeze(0), 1) # [task_num, batch_size, seq_len]
|
||||
_loss_mask = torch.ne(targets, -100).to(torch.int32)
|
||||
_mask = _task_mask * _loss_mask.unsqueeze(0) # [task_num, batch_size, seq_len]
|
||||
# calc loss and tokens
|
||||
_task_losses = (loss.unsqueeze(0) * _mask).view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
|
||||
_task_tokens = _mask.view((_task_num, -1)).sum(dim=-1)[1:] # [task_num]
|
||||
# return token-wise avg losses and tokens
|
||||
return torch.nan_to_num(_task_losses / _task_tokens, nan=0.0), _task_tokens
|
||||
|
||||
|
||||
class ChunkAve:
|
||||
def __init__(self, chunk_size=100):
|
||||
self.ave_list = []
|
||||
self.chunk_size = chunk_size
|
||||
|
||||
def record(self, time):
|
||||
self.ave_list.append(time)
|
||||
self.ave_list = self.ave_list[-self.chunk_size :]
|
||||
|
||||
def get(self):
|
||||
return sum(self.ave_list) / len(self.ave_list)
|
||||
|
||||
|
||||
def pretrain(
|
||||
args,
|
||||
tokenizer,
|
||||
model: Dragonfly,
|
||||
optimizer,
|
||||
lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
|
||||
):
|
||||
ave_model_time = ChunkAve(chunk_size=100)
|
||||
ave_iter_time = ChunkAve(chunk_size=100)
|
||||
|
||||
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, reduction="none")
|
||||
optim_manager = bmt.optim.OptimManager(
|
||||
loss_scale=bmt.world_size(),
|
||||
loss_scale_steps=args.loss_scale_steps,
|
||||
loss_scale_factor=2,
|
||||
max_loss_scale=bmt.world_size(),
|
||||
min_loss_scale=bmt.world_size(),
|
||||
)
|
||||
optim_manager.add_optimizer(optimizer, lr_scheduler)
|
||||
|
||||
start_step = args.start_step
|
||||
|
||||
if args.tensorboard is not None and bmt.rank() == 0:
|
||||
import distutils.version # noqa: F401
|
||||
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
if not os.path.exists(args.tensorboard):
|
||||
os.makedirs(args.tensorboard)
|
||||
writer = SummaryWriter(log_dir=args.tensorboard)
|
||||
|
||||
if args.load is not None:
|
||||
log_ckpt = exporter.load_log_ckpt(args)
|
||||
else:
|
||||
log_ckpt = {}
|
||||
|
||||
global_token_pass = log_ckpt.get("global_token_pass", 0.0)
|
||||
global_total_task_token = defaultdict(int, log_ckpt.get("global_total_task_token", {})) # token by task
|
||||
|
||||
global_world_size = bmt.world_size()
|
||||
if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
|
||||
mixed_indexed_dataset = MixedIndexedDataset(
|
||||
cfg_path=args.dataset,
|
||||
cfg_json_str=None,
|
||||
tokenizer=tokenizer,
|
||||
max_length=args.max_length,
|
||||
nthreads=args.dataloader_num_threads,
|
||||
prefetch_slice=args.dataloader_prefetch,
|
||||
weight_by_size=True,
|
||||
)
|
||||
|
||||
        if args.load is not None and args.only_load_model == 0 and args.load_dataloader_ckpt == 1:
            exporter.load_dataloader_ckpt(args, mixed_indexed_dataset)

        batched_dataset = UnpadBatchedMixedDataset(mixed_indexed_dataset, args.batch_size, args.max_length)
        dataloader = torch.utils.data.DataLoader(
            batched_dataset,
            batch_size=None,
            collate_fn=lambda x: x,
            num_workers=args.dataloader_num_workers,
            prefetch_factor=args.dataloader_prefetch_factor,
        )
    else:

        def dummy_generator():
            while True:
                yield None

        mixed_indexed_dataset = dummy_generator()
        dataloader = mixed_indexed_dataset

    DataIterator = CudaPrefetcher(dataloader, tp_size=args.tp_size, tp_rank=bmt.config["tp_rank"])

    bmt.print_rank("Preparing dataset done.")

    # inspect at init
    model_inspect = bmt.inspect.inspect_model(model, "*")
    bmt.print_rank(bmt.inspect.format_summary(model_inspect))

    try:
        mem_usage, tim_usage = {}, {}
        mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)

        for iteration, data in enumerate(DataIterator, start=start_step + 1):
            if args.tp_size == 1 or bmt.config["tp_rank"] == 0:
                mixed_indexed_dataset.update_states(data["task_ids"], data["indexes"])

            mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)

            logits = model(
                input=data["inputs"],
                cu_seqlens=data["cu_seqlens"],
                max_seqlen=data["max_seqlen"],
                position_ids=data["position_ids"],
            )
            # print("logits: ", logits)

            # chunk targets and task_ids
            data["targets"] = (
                data["targets"]
                .view(-1)
                .chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
                .view(data["targets"].shape[0], -1)
            )
            data["task_ids"] = (
                data["task_ids"]
                .view(-1)
                .chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]]
                .view(data["task_ids"].shape[0], -1)
            )

            _target = data["targets"].view(-1)
            non_reduced_loss = loss_func(logits.view(-1, logits.size(-1)), _target)
            _w = (_target != -100).int()
            loss = non_reduced_loss.sum() / _w.sum().float()

            global_loss = bmt.sum_loss(loss).item()
            mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)

            optim_manager.backward(loss)
            mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)

            if iteration % args.grad_accum == 0 or iteration == args.train_iters:
                grad_accum_init_time = tim_usage["init"]

                grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
                optim_manager.step()
                optim_manager.zero_grad()
                mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
                model_time = tim_usage["optim"] - grad_accum_init_time
                ave_model_time.record(model_time)
            else:
                # dummy optim step
                grad_norm = torch.Tensor([0.0]).cuda()
                tim_usage["optim"] = tim_usage["backward"]
                mem_usage["optim"] = mem_usage["backward"]
                model_time = tim_usage["optim"] - tim_usage["init"]

            with torch.no_grad():
                task_num = len(data["task_names"])
                task_loss, task_token = get_task_loss_and_token(
                    non_reduced_loss, data["task_ids"], task_num, data["targets"]
                )
                task_loss_map: Dict[str, float] = {}
                gatherd_task_loss_map = bmt.distributed.all_gather(task_loss)
                gatherd_task_token_map = bmt.distributed.all_gather(task_token)
                gatherd_task_loss_token_map = gatherd_task_loss_map * gatherd_task_token_map
                sum_task_loss = gatherd_task_loss_token_map.sum(dim=0)
                tot_task_token = gatherd_task_token_map.sum(dim=0)
                ave_task_loss = sum_task_loss / tot_task_token
                for i in range(task_num):
                    task_loss_map[data["task_names"][i]] = ave_task_loss[i].item()
                    global_total_task_token[data["task_names"][i]] += tot_task_token[i].item()

            local_total_rate = torch.Tensor(
                [data["lengths"].float().mean() / (args.max_length * args.batch_size)]
            ).cuda()
            local_total_rate = bmt.sum_loss(local_total_rate).item()
            global_token_pass += (
                (global_world_size // args.tp_size) * local_total_rate * args.max_length * args.batch_size
            )

            bmt.print_rank(
                "=========================================" + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            )
            last_before_log_time = tim_usage["before_log"]
            mem_usage, tim_usage = add_mem_time("before_log", mem_usage, tim_usage)

            iter_time = tim_usage["before_log"] - last_before_log_time

            ave_iter_time.record(iter_time)

            train_info = {
                "time": iter_time,
                "iteration": iteration,
                "loss": global_loss,
                "lr": lr_scheduler.current_lr,
                "token_max": local_total_rate,
                "token_pass": global_token_pass,
                "throughout": args.max_length * args.batch_size * local_total_rate / ave_iter_time.get() / args.tp_size,
                "grad_norm": grad_norm.item(),
                "mask_max": ((data["targets"] >= 0).sum(-1).float().mean() / args.max_length).item(),
                "task_loss": task_loss_map,
                "total_task_token": global_total_task_token,
            }
            global_token_pass_str = convert_to_k_and_b(global_token_pass)

            time_report_str = "{model_time:.2f}={forward_time:.2f}+{backward_time:.2f}+{optim_time:.2f}".format(
                model_time=model_time,
                forward_time=tim_usage["forward"] - tim_usage["init"],
                backward_time=tim_usage["backward"] - tim_usage["forward"],
                optim_time=tim_usage["optim"] - tim_usage["backward"],
            )
            bmt.print_rank(
                (
                    "| Iter: {iteration:6d} | loss: {loss:.4f} | lr: {lr:.4e} | model_time: {model_time} | iter_time: {iter_time:.2f}| chunk_ave_time: {chunk_ave_time:.2f}"
                    + " token/max: {tokenrate:.4f} | mask/max: {maskrate:.4f} | grad_norm: {grad_norm:.4f} | global_token_pass (B):"
                    + "{global_token_pass} | mem_usage {mem_usage} | "
                ).format(
                    iteration=iteration,
                    loss=global_loss,
                    lr=lr_scheduler.current_lr,
                    model_time=time_report_str,
                    iter_time=iter_time,
                    chunk_ave_time=ave_iter_time.get(),
                    tokenrate=data["lengths"].float().mean() / args.max_length / args.batch_size,
                    maskrate=(data["targets"] >= 0).sum(-1).float().mean() / args.max_length / args.batch_size,
                    grad_norm=grad_norm.item(),
                    global_token_pass=global_token_pass_str,
                    mem_usage=max([value for key, value in mem_usage.items()]),
                )
            )

            bmt.print_rank(
                "task_loss:\t| "
                + " | ".join(["{}: {:.4f}".format(task_name, loss) for task_name, loss in task_loss_map.items()])
                + " |"
            )

            if iteration % 10 == 0:
                bmt.print_rank(
                    "task_tokens (B):\t| "
                    + " | ".join(
                        [
                            "{}: {:.4f}".format(task_name, task_token / 10**9)
                            for task_name, task_token in global_total_task_token.items()
                        ]
                    )
                    + " |"
                )

            if iteration % args.inspect_iters == 0:
                model_inspect = bmt.inspect.inspect_model(model, "*")
                bmt.print_rank(bmt.inspect.format_summary(model_inspect))

            if args.log_dir is not None and bmt.rank() == 0:
                if args.save is not None:
                    save_every_step_stats(train_info, args.save)

            if args.tensorboard is not None and bmt.rank() == 0:
                writer.add_scalar("Loss/train", global_loss, iteration)
                writer.add_scalar("Optimizer/lr", lr_scheduler.current_lr, iteration)
                writer.add_scalar("Optimizer/scale", optim_manager.loss_scale, iteration)
                writer.add_scalar("Optimizer/grad_norm", grad_norm.item(), iteration)
                for task_name, loss in task_loss_map.items():
                    if not math.isnan(loss):
                        writer.add_scalar("Loss/train/{}".format(task_name), loss, iteration)

            # -------- save files. If they need to be backed up by the Klara platform, use export.xx_save --------
            log_ckpt = {
                "global_total_task_token": global_total_task_token,
                "global_token_pass": global_token_pass,
                "iteration": iteration,
            }

            if args.save is not None and iteration % args.save_iters == 0:
                exporter.export(
                    model,
                    mixed_indexed_dataset,
                    tokenizer,
                    optimizer,
                    iteration,
                    args,
                    log_ckpt=log_ckpt,
                    final_save=False,
                    async_save=args.async_save,
                )

            if iteration == args.train_iters and args.stop_when_end == 1:
                break

    except Exception as e:
        print(f"train loop err: {e}")
        raise e
    finally:
        pass

    exporter.export(model, mixed_indexed_dataset, tokenizer, optimizer, -1, args, final_save=False)


def convert_to_k_and_b(number):
    if number >= 1e9:  # greater than or equal to 1 billion
        b_number = number / 1e9
        return f"{b_number:.2f}B"
    elif number >= 1e6:  # greater than or equal to 1 million
        k_number = number / 1e6
        return f"{k_number:.2f}M"
    elif number >= 1e3:
        k_number = number / 1e3
        return f"{k_number:.2f}K"
    else:
        return str(number)


def main():
    args = initialize()
    bmt.synchronize()
    tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
    bmt.print_rank("finish loading")
    bmt.print_rank(
        "Number of parameters {}, number of non-embedding parameters {}".format(
            num_parameters(model), num_non_embedding_parameters(model)
        )
    )
    bmt.print_rank("args: {}".format(args))

    print("beginning training")
    pretrain(args, tokenizer, model, optimizer, lr_scheduler)


if __name__ == "__main__":
    main()
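A quick, illustrative sanity check for the convert_to_k_and_b helper defined above (not part of the original script; the expected strings follow directly from its thresholds):

# Hypothetical usage sketch for convert_to_k_and_b (illustration only).
assert convert_to_k_and_b(1_500_000_000) == "1.50B"
assert convert_to_k_and_b(2_500_000) == "2.50M"
assert convert_to_k_and_b(4_096) == "4.10K"
assert convert_to_k_and_b(512) == "512"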
@ -1,234 +0,0 @@
#!/bin/bash

#export OMP_NUM_THREADS=16

declare -A args # Declare an associative array to store arguments and values

args["model_unique"]="8b_0702"
args["resume_ckpt"]=""
args["config"]="8b"
args["flash"]="cuda"
args["batch_size"]="1"
args["max_length"]="4096"
args["save_iters"]="500"
args["train_iters"]="10"
args["dataset_config"]="fm9g_sft"
args["local"]="False"
args["dataloader"]="indexed"
args["save"]="True"
args["dataloader_num_threads"]=1
args["dataloader_prefetch"]=2
args["dataloader_prefetch_factor"]=32
args["dataloader_num_workers"]=2
args["lr"]="1e-5"
args["warmup_iters"]="20"
args["drop_iters"]="0.1"
args["tokenizer_path"]="./tokenizer/tokenizer.model" # /user/tc_agi/klara/baichuan2/baichuan2.tokenizer.model
args["load_grad"]="False"
args["grad_ckpt_num"]="160"
args["exp_group"]=""
args["ignore_cuda_oom"]="1"
args["tensorboard_all_tasks"]="0"
args["stop_when_end"]="0"
args["only_run_dataloader"]="0"
args["eps"]="1e-6"
args["inspect_iters"]="100"
args["strict_state_dict"]="1"
args["only_load_model"]="1"
args["lr_scheduler"]="cosine"
args["resume_no_optimze"]="0"
args["tp_size"]="1"
args["parallel_load_datastate"]="16"
args["async_save"]="False"
args["load_dataloader_ckpt"]="0"
args["drop_begin"]="-1"
args["drop_rate"]="0.5"
args["use_checkpoint"]="1"


# Loop through the arguments
for ((i=1; i<=$#; i++)); do
    arg="${!i}"
    # Check if the argument starts with "--"
    if [[ "$arg" == --* ]]; then
        arg_name="${arg:2}" # Remove leading "--"
        valueid=$((i+1))
        # Get the value of the argument if it exists
        if ((i+1 <= $#)); then
            args["$arg_name"]="${!valueid}"
            i=$((i+1)) # Skip the next argument (its value)
        else
            args["$arg_name"]="" # Set empty value if no value provided
        fi
    fi
done

# Use Python to read the JSON config file and update the Bash dictionary
while read -r key value; do
    args["$key"]="$value"
done < <(python -c 'import json, sys; obj = json.load(open("train_configs/'${args['config']}'.json"))["pretrain"]; print("\n".join(["{} {}".format(k, v) for k, v in obj.items()]))')


# Apply the command-line arguments once more so they override the JSON config
# Loop through the arguments
for ((i=1; i<=$#; i++)); do
    arg="${!i}"
    # Check if the argument starts with "--"
    if [[ "$arg" == --* ]]; then
        arg_name="${arg:2}" # Remove leading "--"
        valueid=$((i+1))

        # Get the value of the argument if it exists
        if ((i+1 <= $#)); then
            args["$arg_name"]="${!valueid}"
            i=$((i+1)) # Skip the next argument (its value)
        else
            args["$arg_name"]="" # Set empty value if no value provided
        fi
    fi
done

# Print the values of the arguments
echo "----------- CMD args ----------"
for key in "${!args[@]}"; do
    echo "$key: ${args[$key]}"
done
echo "--------- END CMD args --------"


if [[ ${args["flash"]} == "triton" ]]; then
    sudo cp /usr/local/cuda-11.6/compat/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03
    sudo ln /usr/lib/x86_64-linux-gnu/libcuda.so.510.108.03 /usr/lib/x86_64-linux-gnu/libcuda.so
    echo "triton flash"
fi


GPUS_PER_NODE=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader | wc -l)
# GPUS_PER_NODE=1
echo "Using ${GPUS_PER_NODE} GPUs on each machine"


if [[ ${args["model_unique"]} == "" ]]; then
    MODEL_UNIQUE=${JEEVES_JOB_ID} # output location; constructed automatically if not provided
    # JOBID+CreateTime is the unique identifier of this run. In the whitebox environment, checkpoints are available at /projects/${PROJECTID}-${PROJECTNAME}/checkpoints/${MODEL_UNIQUE}
    # and TensorBoard logs at /projects/${PROJECTID}-${PROJECTNAME}/tensorboard/${MODEL_UNIQUE}
else
    MODEL_UNIQUE=${args["model_unique"]} # output location was provided explicitly
fi
echo "model_unique: "$MODEL_UNIQUE

# --------------- runtime arguments ---------------

OPTS+=" --model-config model_configs/"${args['config']}".json" # [CHANGE]
OPTS+=" --batch-size ${args["batch_size"]}"
OPTS+=" --train-iters ${args["train_iters"]}"
OPTS+=" --save-iters ${args["save_iters"]}"
OPTS+=" --save-name fm9g_live_checkpoint"
OPTS+=" --max-length ${args["max_length"]}"
OPTS+=" --lr ${args["lr"]}"
OPTS+=" --inspect-iters ${args["inspect_iters"]}"
OPTS+=" --warmup-iters ${args["warmup_iters"]}"
OPTS+=" --drop-iters ${args["drop_iters"]}"
OPTS+=" --lr_scheduler ${args["lr_scheduler"]}"
OPTS+=" --offload"
OPTS+=" --vocab ./tokenizer/vocab.txt"
OPTS+=" --flash ${args["flash"]}"
OPTS+=" --tensorboard_all_tasks ${args["tensorboard_all_tasks"]}"
OPTS+=" --ignore_cuda_oom ${args["ignore_cuda_oom"]}"
OPTS+=" --stop_when_end ${args["stop_when_end"]}"
OPTS+=" --only_run_dataloader ${args["only_run_dataloader"]}"
OPTS+=" --eps ${args["eps"]}"
OPTS+=" --strict_state_dict ${args["strict_state_dict"]}"
OPTS+=" --only_load_model ${args["only_load_model"]}"
OPTS+=" --resume_no_optimze ${args["resume_no_optimze"]}"
OPTS+=" --tokenizer_path ${args["tokenizer_path"]}"
OPTS+=" --weight-decay 0.1"
OPTS+=" --tp-size ${args["tp_size"]}"
OPTS+=" --parallel_load_datastate ${args["parallel_load_datastate"]}"
OPTS+=" --load_dataloader_ckpt ${args["load_dataloader_ckpt"]}"
OPTS+=" --drop_begin ${args["drop_begin"]}"
OPTS+=" --drop_rate ${args["drop_rate"]}"
OPTS+=" --use_checkpoint ${args["use_checkpoint"]}"

if [[ ${args["load_grad"]} == "True" ]]; then
    OPTS+=" --load-grad"
    OPTS+=" --grad-ckpt-num ${args["grad_ckpt_num"]}"
fi


if [[ ${args["async_save"]} == "True" ]]; then
    OPTS+=" --async_save"
fi


if [[ ${args["dataloader"]} == "indexed" ]]; then
    OPTS+=" --dataloader_num_threads ${args["dataloader_num_threads"]}"
    OPTS+=" --dataloader_prefetch ${args["dataloader_prefetch"]}"
    OPTS+=" --dataloader_num_workers ${args["dataloader_num_workers"]}"
    OPTS+=" --dataloader_prefetch_factor ${args["dataloader_prefetch_factor"]}"
fi


# --------------- output file paths ---------------
## checkpoint
if [[ ${args["save"]} == "True" ]]; then
    OPTS+=" --save ./data/checkpoints/${MODEL_UNIQUE}/"
    OPTS+=" --save-model ./not_exist/${MODEL_UNIQUE}/"
else
    echo "won't save model"
fi


## logs; /local/logs is equivalent to ./datalogs (symlink)
mkdir -p ./data/checkpoints/logs/${MODEL_UNIQUE}
OPTS+=" --log-dir ./data/checkpoints/logs/${MODEL_UNIQUE}"
OPTS+=" --tensorboard ./data/tensorboard/${args["exp_group"]}${MODEL_UNIQUE}/"


if [[ ${args["local"]} == "True" ]]; then
    current_dir=$(pwd)
    OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
else
    current_dir=$(pwd)
    OPTS+=" --dataset ${current_dir}/dataset_configs/${args["dataset_config"]}.json"
    echo "Platform config:"${PLATFORM_CONFIG_PATH}
fi


## checkpoint; compatible with CHECKPOINT and LATEST_CHECKPOINT. When debugging, it is better not to load a checkpoint so startup is faster
if [ "${args["resume_ckpt"]}" != "" ]; then
    OPTS+=" --load ./data/checkpoints/${MODEL_UNIQUE}/${args["resume_ckpt"]}"
else
    echo "No checkpoint to load"
fi


filename="pretrain_dragonfly"

if [[ ${args["local"]} == "True" ]]; then
    PRETRAIN_ENTRY="$filename.py"
else
    PRETRAIN_ENTRY="$filename.py"
fi


GPUS_PER_NODE=8
NNODES=1
RANK=0
MASTER_ENDPOINT=g3006
MASTER_PORT=12345
#CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --master_addr=${MASTER_ENDPOINT} --master_port=${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"
CMD="torchrun --nnodes=${NNODES} --nproc_per_node=${GPUS_PER_NODE} --node_rank=${RANK} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ENDPOINT}:${MASTER_PORT} ${PRETRAIN_ENTRY} ${OPTS}"

echo "-------final CMD is------"
echo "${CMD}"
echo "-------final CMD end------"

$CMD
@ -1,9 +0,0 @@
{
    "pretrain": {
        "train_iters": 20000,
        "batch_size": 1,
        "max_length": 4096,
        "n_gpus": 8,
        "lr": 1e-5
    }
}
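The launch script above consumes this file through a Python one-liner that flattens the "pretrain" object into "key value" lines before overriding its Bash defaults. A minimal standalone sketch of that step (the train_configs/8b.json path is an assumption based on the script's train_configs/${config}.json convention):

# Hypothetical standalone equivalent of the one-liner in the launch script.
import json

obj = json.load(open("train_configs/8b.json"))["pretrain"]  # path assumed
print("\n".join("{} {}".format(k, v) for k, v in obj.items()))
# Prints lines such as "train_iters 20000" and "lr 1e-05", one per key.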
@ -0,0 +1,126 @@
import gc
from io import BytesIO

import requests
import timm
import torch
from PIL import Image
from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from torchvision.transforms import InterpolationMode, transforms
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

import os, sys
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
from vis_fm9g.generation.vllm_fm9g import VLLMFM9GBeamSearch
from vis_fm9g.model.fm9g import FM9GConfig, FM9GTorch
from vis_fm9g.model.vlu_fm9g import VLU_FM9G
from vis_fm9g.tokenizer.fm9g_tokenizer import FM9GTokenizer
from vis_fm9g.utils.constants import SYSTEM


def chat(model, image, question, context, tokenizer, query_nums=64, vision_hidden_states=None, max_length=1024):
    if not context:
        question = tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end + question
        final_input = f'{SYSTEM}<用户>{question}<AI>'
    else:
        final_input = f'{context}<用户>{question}<AI>'

    data_list = [
        {'input': final_input}
    ]

    res, vision_hidden_states = model.generate(
        data_list=data_list,
        max_inp_length=2048,
        beam_size=3,
        img_list=[[image]],
        max_length=max_length,
        repetition_penalty=1.1,
        temperature=0.7,
        length_penalty=3,
        return_vision_hidden_states=True
    )

    answer = res[0]

    context = final_input + answer
    return answer, context, vision_hidden_states


def load_llm(llm_path):
    config = FM9GConfig.from_json_file(llm_path)
    config.use_flash_attn = False
    cpm_model = FM9GTorch(config)
    return cpm_model


def load_vpm(vision_encoder, drop_vision_last_layer=False):
    model = timm.create_model(
        vision_encoder,
        pretrained=False,
        num_classes=0,
        dynamic_img_size=True,
        dynamic_img_pad=True
    )

    if isinstance(model, timm.models.VisionTransformer):
        if model.attn_pool is not None:
            model.attn_pool = torch.nn.Identity()

    if drop_vision_last_layer:
        model.blocks[-1] = torch.nn.Identity()

    return model


def load_vis_fm9g(llm_path, vision_encoder):
    llm = load_llm(llm_path)
    vpm = load_vpm(vision_encoder, drop_vision_last_layer=False)

    vision_dim = vpm.embed_dim
    model = VLU_FM9G(llm, vpm, vision_dim, query_num=256)
    return model


def load_tokenizer(vocabs_path):
    return FM9GTokenizer(vocabs_path)


def load_transform(img_size):
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD)
    ])

    return transform


if __name__ == '__main__':
    root = "checkpoint/"
    llm_path = root + 'config.json'
    vocabs_path = root + 'vocabs.txt'
    model_checkpoint = root + 'sharded'
    vision_encoder = 'eva02_enormous_patch14_clip_224.laion2b_plus'
    img_size = 448

    with init_empty_weights():
        model = load_vis_fm9g(llm_path, vision_encoder)

    model = load_checkpoint_and_dispatch(model, model_checkpoint, device_map="auto", max_memory={0: "24GiB", 1: "24GiB"}, no_split_module_classes=['EvaBlockPostNorm'])
    model.eval()

    tokenizer = load_tokenizer(vocabs_path)
    transform = load_transform(img_size)
    beam_search = VLLMFM9GBeamSearch(model, tokenizer, transform)
    # image input
    url = 'test.jpg'
    image = Image.open(url).convert('RGB')
    # text input
    prompt = '这幅图描述了什么?'
    answer, context, _ = chat(
        beam_search, image, prompt, context=None, tokenizer=tokenizer, query_nums=256
    )

    print(answer)
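Because chat() returns the accumulated context string, a follow-up question can continue the same dialogue about the image. A minimal sketch of a second turn (the follow-up prompt is illustrative, not from the original demo):

# Hypothetical second turn reusing the context returned by the first chat() call.
follow_up = '图中有哪些物体?'  # "What objects appear in the picture?"
answer2, context, _ = chat(
    beam_search, image, follow_up, context=context, tokenizer=tokenizer, query_nums=256
)
print(answer2)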
@ -1 +0,0 @@
from .arguments import get_args
@ -1,425 +0,0 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The OpenBMB team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
|
||||
|
||||
def add_model_config_args(parser: argparse.ArgumentParser):
|
||||
"""Model arguments"""
|
||||
|
||||
group = parser.add_argument_group("model", "model configuration")
|
||||
group.add_argument("--model-config", type=str, help="model configuration file")
|
||||
group.add_argument("--vocab", type=str, default=None, help="model vocabulary file")
|
||||
group.add_argument("--eps", type=float, default=1e-5, help="eps in layernorm")
|
||||
# group.add_argument("--qk_norm", action="store_true", default=False, help="qk layernorm")
|
||||
return parser
|
||||
|
||||
|
||||
def add_training_args(parser: argparse.ArgumentParser):
|
||||
"""Training arguments."""
|
||||
|
||||
group = parser.add_argument_group("train", "training configurations")
|
||||
group.add_argument("--platform-config", type=str, default="platform_config.json", help="Path to platform config")
|
||||
group.add_argument("--dataset", type=str, default="dataset.json", help="Path to dataset")
|
||||
group.add_argument("--val-dataset", type=str, default="dataset.json", help="Path to val dataset")
|
||||
group.add_argument(
|
||||
"--load",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to a directory containing a model checkpoint.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--load-grad",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Load the gradient states",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--grad-ckpt-num",
|
||||
type=int,
|
||||
default=0,
|
||||
help="grad file num (only work when --load-grad from files less than world-size )",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--load-start-step",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Load the step state from checkpoints",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--save",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Output directory to save checkpoints to.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--save-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Output filename to save checkpoints to.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--save-model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Output directory to save model to.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--tensorboard",
|
||||
type=str,
|
||||
default=None,
|
||||
help="tensorboard directory",
|
||||
)
|
||||
|
||||
group.add_argument("--inspect-iters", type=int, default=1000, help="number of inspecting")
|
||||
group.add_argument("--batch-size", type=int, default=32, help="Data Loader batch size")
|
||||
group.add_argument("--num-micro-batches", type=int, default=16)
|
||||
group.add_argument("--clip-grad", type=float, default=1.0, help="gradient clipping")
|
||||
group.add_argument("--grad-accum", type=int, default=1, help="gradient accum steps")
|
||||
group.add_argument(
|
||||
"--train-iters",
|
||||
type=int,
|
||||
default=1000000,
|
||||
help="total number of iterations to train over all training runs",
|
||||
)
|
||||
group.add_argument("--max-length", type=int, default=512, help="max length of input")
|
||||
group.add_argument("--min-length", type=int, default=None, help="only for speed test")
|
||||
|
||||
group.add_argument("--seed", type=int, default=1234, help="random seed for reproducibility")
|
||||
|
||||
# Learning rate.
|
||||
group.add_argument("--lr", type=float, default=1.0e-4, help="initial learning rate")
|
||||
group.add_argument("--lr_scheduler", type=str, default="cosine", help=" learning rate scheduler")
|
||||
|
||||
group.add_argument("--weight-decay", type=float, default=1.0e-2, help="weight decay rate")
|
||||
group.add_argument("--loss-scale", type=float, default=65536, help="loss scale")
|
||||
group.add_argument("--max-loss-scale", type=float, default=float("inf"), help="loss scale")
|
||||
group.add_argument("--min-loss-scale", type=float, default=1, help="loss scale")
|
||||
group.add_argument("--loss-scale-steps", type=float, default=1024, help="loss scale")
|
||||
|
||||
group.add_argument(
|
||||
"--warmup-iters",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="percentage of data to warmup on (.01 = 1% of all " "training iters). Default 0.01",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-iters",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="percentage of data to warmup on (.01 = 1% of all " "training iters). Default 0.01",
|
||||
)
|
||||
|
||||
group.add_argument("--lr-decay-iters", type=int, default=None, help="lr decay steps")
|
||||
group.add_argument("--start-step", type=int, default=0, help="step to start or continue training")
|
||||
group.add_argument("--concat-data", action="store_true", help="whether we concatenate the dialogues")
|
||||
group.add_argument("--offload", action="store_true", help="whether we use offload_adam")
|
||||
group.add_argument("--new-bmt", action="store_true", help="new bmt without ckpt")
|
||||
group.add_argument("--flash", default="none", choices=["none", "1d", "triton", "cuda"])
|
||||
group.add_argument("--use-jfs-data", action="store_true", help="whether we use juicefs dataset")
|
||||
group.add_argument("--tp-size", default=1, type=int)
|
||||
group.add_argument("--pp-size", default=1, type=int)
|
||||
group.add_argument("--bf16", action="store_true", help="whether we use bf16")
|
||||
group.add_argument("--dataloader_num_threads", default=3, type=int, help="Only useful in indexed dataest.")
|
||||
group.add_argument("--dataloader_prefetch", default=200, type=int, help="Only useful in indexed dataest.")
|
||||
group.add_argument("--dataloader_num_workers", default=4, type=int, help="Only useful in indexed dataest.")
|
||||
group.add_argument("--dataloader_prefetch_factor", default=50, type=int, help="Only useful in indexed dataest.")
|
||||
group.add_argument(
|
||||
"--dataloader",
|
||||
default="indexed",
|
||||
type=str,
|
||||
help="dataloader type, 'indexed' for indexed dataset, 'normal' for normal dataset",
|
||||
)
|
||||
group.add_argument("--stop_when_end", default=0, type=int, help="Whether to stop training when we reach end_iter")
|
||||
group.add_argument(
|
||||
"--data_len_threshold",
|
||||
default=512,
|
||||
type=int,
|
||||
help="If the average length of a sequence is less than this int, mean the sample is biased. ",
|
||||
)
|
||||
group.add_argument(
|
||||
"--only_run_dataloader", default=0, type=int, help="Whether to only run dataloader to check data. "
|
||||
)
|
||||
group.add_argument(
|
||||
"--only_load_model", default=0, type=int, help="Whether to only load a model ckpt, without anything else."
|
||||
)
|
||||
group.add_argument(
|
||||
"--load_dataloader_ckpt", default=1, type=int, help="Whether to only load a model ckpt, without anything else."
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--resume_no_optimze",
|
||||
default=0,
|
||||
type=int,
|
||||
help="The number of steps that does not add optimization after resume",
|
||||
)
|
||||
group.add_argument(
|
||||
"--parallel_load_datastate",
|
||||
default=256,
|
||||
type=int,
|
||||
help="The number of parallel workers to load dataset state",
|
||||
)
|
||||
group.add_argument(
|
||||
"--async_save",
|
||||
action="store_true",
|
||||
help="whether to save artifacts asynchronously",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop_begin",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="The number of steps that starts to drop lr"
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop_rate",
|
||||
default=0.5,
|
||||
type=float,
|
||||
help="The number rate"
|
||||
)
|
||||
group.add_argument(
|
||||
"--use_checkpoint",
|
||||
default=1,
|
||||
type=int,
|
||||
help="Whether to use checkpointing."
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def add_pretrain_args(parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group("pretrain", "pretrain configurations")
|
||||
group.add_argument(
|
||||
"--save-iters",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="number of iterations between saves",
|
||||
)
|
||||
group.add_argument(
|
||||
"--log-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="log directory",
|
||||
)
|
||||
group.add_argument(
|
||||
"--worker-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="worker name",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def add_tokenizer_args(parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group("tokenizer", "tokenizer configurations")
|
||||
group.add_argument(
|
||||
"--tokenizer_path",
|
||||
type=str,
|
||||
default="",
|
||||
help="tokenizer_path",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def add_finetune_args(parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group("finetune", "finetune configurations")
|
||||
group.add_argument("--epoch", type=int, default=1, help="number of training epochs")
|
||||
group.add_argument("--task-name", type=str, default="task", help="name of training task")
|
||||
group.add_argument("--save-epochs", type=int, default=1, help="number of training epochs between saves")
|
||||
group.add_argument("--save-steps", type=int, default=0, help="number of training steps between saves")
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="drop data from each epoch that cannot be formed into a complete batch at the end",
|
||||
)
|
||||
group.add_argument("--delta-tuning", action="store_true", default=False)
|
||||
group.add_argument("--each-epoch-save", default=False)
|
||||
group.add_argument("--train-task-id", type=int, default=-1)
|
||||
return parser
|
||||
|
||||
|
||||
def add_rhlf_args(parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group("rhlf", "rhlf configurations")
|
||||
|
||||
group.add_argument(
|
||||
"--load-reward",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to reward model checkpoint.",
|
||||
)
|
||||
group.add_argument("--actor-lr", type=float, default=1.0e-5, help="actor learning rate")
|
||||
group.add_argument("--critic-lr", type=float, default=1.0e-6, help="critic learning rate")
|
||||
group.add_argument("--actor-loss-scale", type=float, default=65536, help="actor loss scale")
|
||||
group.add_argument("--critic-loss-scale", type=float, default=65536, help="critic loss scale")
|
||||
group.add_argument("--avg-reward-bias", type=float, default=0, help="reward bias")
|
||||
group.add_argument("--actor-delay-step", type=int, default=0, help="actor delay step")
|
||||
group.add_argument("--entropy-coef", type=float, default=-1.0, help="coef of policy entropy")
|
||||
##
|
||||
return parser
|
||||
|
||||
|
||||
def add_simple_rhlf_args(parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group("simple_rhlf", "simple rhlf configurations")
|
||||
group.add_argument("--epoch", type=int, default=1, help="number of training epochs")
|
||||
group.add_argument("--sample-batch-size", type=int, default=32, help="Data Loader sample batch size")
|
||||
group.add_argument("--load-reward", type=str, default=None, help="Path to reward model checkpoint")
|
||||
group.add_argument("--avg-reward-bias", type=float, default=0, help="reward bias")
|
||||
group.add_argument("--sample-min-length", type=int, default=20, help="sample-min-length")
|
||||
group.add_argument("--sample-max-inp-length", type=int, default=1024, help="sample-max-inp-length")
|
||||
group.add_argument("--sample-max-length", type=int, default=64, help="sample-max-length")
|
||||
group.add_argument("--sample-repetition-penalty", type=float, default=1.05, help="sample-repetition-penalty")
|
||||
group.add_argument("--sample-temperature", type=float, default=1.0, help="sample-temperature")
|
||||
group.add_argument("--encode-max-length", type=int, default=1024, help="encode-max-length")
|
||||
group.add_argument("--generate-max-length", type=int, default=64, help="generate-max-length")
|
||||
group.add_argument("--value-loss-weight", type=float, default=0.1, help="value-loss-weight")
|
||||
group.add_argument("--ptx-loss-weight", type=float, default=0.001, help="ptx-loss-weight")
|
||||
group.add_argument("--save-epochs", type=int, default=1, help="number of training epochs between saves")
|
||||
##
|
||||
return parser
|
||||
|
||||
|
||||
def add_feedback_learning_args(parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group("rrhf", "rrhf configurations")
|
||||
group.add_argument("--length-penalty", type=float, default=1.0, help="length_penalty")
|
||||
group.add_argument("--feedback-weight", type=float, default=1.0, help="feedback_weight")
|
||||
group.add_argument("--sample-num", type=int, default=6, help="sample_num")
|
||||
group.add_argument("--dpo-beta", type=float, default=1.0, help="dpo_beta")
|
||||
group.add_argument("--stable-alignment-margin", type=float, default=1.0, help="stable_alignment_margin")
|
||||
group.add_argument("--feedback-learning-type", type=str, default="RRHF", help="feedback_learning_type")
|
||||
group.add_argument("--save-iters", type=int, default=1000, help="number of iterations between saves")
|
||||
##
|
||||
return parser
|
||||
|
||||
|
||||
def add_model_change_args(parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group("model_change", "model change during pretraining")
|
||||
group.add_argument("--strict_state_dict", type=int, default=1, help="strict_state_dict")
|
||||
##
|
||||
return parser
|
||||
|
||||
|
||||
def add_log_args(parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group("log", "log configurations")
|
||||
group.add_argument("--tensorboard_all_tasks", type=int, default=0, help="log")
|
||||
return parser
|
||||
|
||||
|
||||
def add_error_handle_args(parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group("error_handle", "error_handle configurations")
|
||||
group.add_argument(
|
||||
"--ignore_cuda_oom", type=int, default=1, help="continue training by ingore the batch that causes oom"
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def add_runtime_eval_args(parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group("runtime eval args", "runtime evaluation by submitting a job")
|
||||
group.add_argument(
|
||||
"--runtime_eval",
|
||||
action="store_true",
|
||||
help="whether to use runtime_eval. Only if this is set to True, the following variables will be useful",
|
||||
)
|
||||
group.add_argument("--eval_jeeves_auth", type=str, default="", help="auth, press f12 on jeeves platform to get")
|
||||
group.add_argument("--eval_project_id", type=str, default=None, help="project id")
|
||||
group.add_argument("--eval_run_cmd", type=str, default="", help="cmd for eval")
|
||||
group.add_argument(
|
||||
"--eval_git_path",
|
||||
type=str,
|
||||
default="git@git.in.zhihu.com:luca/llm-bench.git",
|
||||
help="git path of evaluation code",
|
||||
)
|
||||
group.add_argument("--eval_git_branch", type=str, default="master", help="git branch of evaluation code")
|
||||
group.add_argument("--eval_node_num", type=int, default=1, help="using 1 node to evaluate")
|
||||
group.add_argument("--eval_gpu_num", type=int, default=1, help="using 1 gpu per node to evaluate")
|
||||
group.add_argument("--eval_tasks_config", type=str, default="", help="evaluate tasks' config")
|
||||
group.add_argument("--eval_model_backend", default="torch", type=str, help="model_backend")
|
||||
|
||||
group.add_argument(
|
||||
"--eval_at_start", action="store_true", help="whether to eval at the first epoch, default to false"
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def add_reward_args(parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group("reward", "reward configurations")
|
||||
group.add_argument("--load-all", type=str, default=None, help="Path to a directory containing a model checkpoint.")
|
||||
##
|
||||
return parser
|
||||
|
||||
|
||||
def add_long_context_extend_args(parser: argparse.ArgumentParser):
|
||||
"""long context extending arguments."""
|
||||
group = parser.add_argument_group("long_context_extend", "long context extend configurations")
|
||||
group.add_argument("--pose-prob", default=0.0, type=float, help="Sample-level PoSE probability")
|
||||
group.add_argument(
|
||||
"--pose-scaling-factor",
|
||||
default=1.0,
|
||||
type=float,
|
||||
help="PoSE scaling factor, simulate input length = max_length * pose_scaling_factor",
|
||||
)
|
||||
group.add_argument(
|
||||
"--rope-scaling-type",
|
||||
default="",
|
||||
type=str,
|
||||
choices=["Linear", "NTK-aware", "Dynamic NTK", "NTK-by-parts", "YaRN", ""],
|
||||
help="Context scaling type",
|
||||
)
|
||||
group.add_argument("--rope-scaling-factor", default=1, type=int, help="Context scaling factor")
|
||||
group.add_argument(
|
||||
"--orig-max-length", default=8192, type=int, help="Original context length before context extending"
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def get_args(
|
||||
pretrain: bool = False,
|
||||
finetune: bool = False,
|
||||
rhlf: bool = False,
|
||||
simple_rlhf: bool = False,
|
||||
feedback_learning: bool = False,
|
||||
reward: bool = False,
|
||||
):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = add_model_config_args(parser) # config file need to be exported with model/ckpt
|
||||
parser = add_training_args(parser)
|
||||
if pretrain:
|
||||
parser = add_pretrain_args(parser)
|
||||
parser = add_runtime_eval_args(parser)
|
||||
parser = add_tokenizer_args(parser)
|
||||
parser = add_log_args(parser)
|
||||
parser = add_error_handle_args(parser)
|
||||
parser = add_model_change_args(parser)
|
||||
|
||||
if finetune:
|
||||
parser = add_finetune_args(parser)
|
||||
if rhlf:
|
||||
parser = add_rhlf_args(parser)
|
||||
if simple_rlhf:
|
||||
parser = add_simple_rhlf_args(parser)
|
||||
if feedback_learning:
|
||||
parser = add_feedback_learning_args(parser)
|
||||
if reward:
|
||||
parser = add_reward_args(parser)
|
||||
parser = add_long_context_extend_args(parser)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
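For reference, a minimal sketch of how a training entry point is expected to consume get_args (an assumption consistent with the pretrain-oriented argument groups above; the package __init__ shown earlier re-exports the function):

# Hypothetical entry-point usage; flags are read from the command line as usual.
args = get_args(pretrain=True)
print(args.train_iters, args.lr, args.flash)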
@ -1,16 +0,0 @@
from .distributed_dataset import build_dataset
from .distributed_dataset import DistributedDataset
from .distributed_dataset import SimpleDataset
from .indexed_dataset import IndexedDataset
from .indexed_dataset import IndexedDatasetBuilder
from .indexed_dataset import PrefetchDecodeDataset

# from .list_dataset import ListDataset
from .utils import compact_dataset
from .utils import CudaPrefetcher
from .utils import mask_dataset
from .utils import merge_dataset
from .utils import random_range
from .utils import Range
from .utils import shuffle_dataset
from .utils import ThreadedPrefetcher
@ -1,814 +0,0 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The OpenBMB team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import bisect
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
import struct
|
||||
import time
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Set
|
||||
|
||||
import bmtrain as bmt
|
||||
import torch
|
||||
|
||||
from .serializer import PickleSerializer
|
||||
from .serializer import Serializer
|
||||
|
||||
|
||||
def _random_string():
|
||||
return "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
|
||||
|
||||
|
||||
_DEFAULT_BLOCK_SIZE = 16 << 20
|
||||
|
||||
|
||||
class FileInfo:
|
||||
def __init__(
|
||||
self,
|
||||
file_name: str = "",
|
||||
block_begin: int = 0,
|
||||
block_end: int = 0,
|
||||
nbytes: int = 0,
|
||||
nlines: int = 0,
|
||||
mask: bool = False,
|
||||
block_size: int = _DEFAULT_BLOCK_SIZE,
|
||||
) -> None:
|
||||
self.file_name = file_name
|
||||
self.block_begin = block_begin
|
||||
self.block_end = block_end
|
||||
self.nbytes = nbytes
|
||||
self.nlines = nlines
|
||||
self.mask = mask
|
||||
self.block_size = block_size
|
||||
|
||||
def state_dict(self):
|
||||
return {
|
||||
"file_name": self.file_name,
|
||||
"block_begin": self.block_begin,
|
||||
"block_end": self.block_end,
|
||||
"nbytes": self.nbytes,
|
||||
"nlines": self.nlines,
|
||||
"mask": self.mask,
|
||||
"block_size": self.block_size,
|
||||
}
|
||||
|
||||
def load_state_dict(self, d):
|
||||
self.file_name = d["file_name"]
|
||||
self.block_begin = d["block_begin"]
|
||||
self.block_end = d["block_end"]
|
||||
self.nbytes = d["nbytes"]
|
||||
self.nlines = d["nlines"]
|
||||
self.mask = d["mask"]
|
||||
self.block_size = d["block_size"]
|
||||
|
||||
def dumps(self) -> str:
|
||||
return json.dumps(self.state_dict())
|
||||
|
||||
def loads(self, data: str) -> "FileInfo":
|
||||
self.load_state_dict(json.loads(data))
|
||||
return self
|
||||
|
||||
def dump(self, fp: io.TextIOWrapper) -> "FileInfo":
|
||||
fp.write(self.dumps())
|
||||
return self
|
||||
|
||||
def load(self, fp: io.TextIOWrapper) -> "FileInfo":
|
||||
self.loads(fp.read())
|
||||
return self
|
||||
|
||||
|
||||
def _read_info_list(meta_path: str) -> List[FileInfo]:
|
||||
info: List[FileInfo] = []
|
||||
while True:
|
||||
try:
|
||||
with open(meta_path, "r", encoding="utf-8") as f:
|
||||
for line in f.readlines():
|
||||
line = line.strip()
|
||||
if len(line) > 0:
|
||||
info.append(FileInfo().loads(line))
|
||||
return info
|
||||
except Exception as e:
|
||||
print(
|
||||
"Error: reading info list in _read_info_list!, meta_path={path}, err={err}".format(
|
||||
path=meta_path, err=str(e)
|
||||
)
|
||||
)
|
||||
time.sleep(10)
|
||||
|
||||
|
||||
def _write_info_list(meta_path: str, info: List[FileInfo]):
|
||||
base_path = os.path.dirname(meta_path)
|
||||
random_fname = os.path.join(base_path, ".meta.bin.%s" % _random_string())
|
||||
while True:
|
||||
try:
|
||||
with open(random_fname, "w", encoding="utf-8") as f:
|
||||
for v in info:
|
||||
f.write(v.dumps() + "\n")
|
||||
os.rename(random_fname, meta_path)
|
||||
return
|
||||
except Exception:
|
||||
print("Error: writing info list!")
|
||||
time.sleep(10)
|
||||
|
||||
|
||||
def _filtered_range(begin: int, end: int, rank: int, world_size: int, filter_set: Optional[Set[int]] = None):
|
||||
begin = begin + (rank + (world_size - (begin % world_size))) % world_size
|
||||
|
||||
if filter_set is not None:
|
||||
return [i for i in range(begin, end, world_size) if i in filter_set]
|
||||
else:
|
||||
return [i for i in range(begin, end, world_size)]
|
||||
|
||||
|
||||
class SafeFile:
|
||||
def __init__(self, fname, mode):
|
||||
self.fname = None
|
||||
self.mode = None
|
||||
self._fp = None
|
||||
self.open_file(fname, mode)
|
||||
|
||||
def read(self, size=-1):
|
||||
if self._fp is None:
|
||||
raise RuntimeError("Dataset is closed")
|
||||
try:
|
||||
res = self._fp.read(size)
|
||||
self.offset = self._fp.tell()
|
||||
return res
|
||||
except Exception as e:
|
||||
print("Error: reading blocks in {}! err {}".format(self.fname, str(e)))
|
||||
self.close()
|
||||
self.open_file(self.fname, self.mode, self.offset)
|
||||
return self.read(size)
|
||||
|
||||
def tell(self):
|
||||
if self._fp is None:
|
||||
raise RuntimeError("Dataset is closed")
|
||||
try:
|
||||
res = self._fp.tell()
|
||||
self.offset = res
|
||||
return res
|
||||
except Exception as e:
|
||||
print("Error: telling blocks in {}! err {}".format(self.fname, str(e)))
|
||||
self.close()
|
||||
self.open_file(self.fname, self.mode, self.offset)
|
||||
return self.tell()
|
||||
|
||||
def seek(self, offset, whence=0):
|
||||
if self._fp is None:
|
||||
raise RuntimeError("Dataset is closed")
|
||||
try:
|
||||
res = self._fp.seek(offset, whence)
|
||||
self.offset = self._fp.tell()
|
||||
return res
|
||||
except Exception as e:
|
||||
print("Error: seeking blocks in {}! err {}".format(self.fname, str(e)))
|
||||
self.close()
|
||||
self.open_file(self.fname, self.mode, self.offset)
|
||||
return self.seek(offset, whence)
|
||||
|
||||
def close(self):
|
||||
if self._fp is not None:
|
||||
try:
|
||||
self._fp.close()
|
||||
except Exception as e:
|
||||
print("Error: closing blocks in {}! err {}".format(self.fname, str(e)))
|
||||
self._fp = None
|
||||
|
||||
def open_file(self, fname, mode, offset=None):
|
||||
if not os.path.exists(fname):
|
||||
print("Dataset {} does not exist".format(fname))
|
||||
self.close()
|
||||
time.sleep(20)
|
||||
self.open_file(fname, mode, offset)
|
||||
try:
|
||||
self.fname = fname
|
||||
self.mode = mode
|
||||
self._fp = open(fname, mode)
|
||||
if offset is not None:
|
||||
self._fp.seek(offset, io.SEEK_SET)
|
||||
self.offset = self._fp.tell()
|
||||
except Exception as e:
|
||||
print("Error: opening blocks in {}! err {}".format(self.fname, str(e)))
|
||||
self.close()
|
||||
time.sleep(20)
|
||||
self.open_file(fname, mode, offset)
|
||||
|
||||
|
||||
class DistributedDataset:
|
||||
"""Open dataset in readonly mode.
|
||||
|
||||
`DistributedDataset` is used to read datasets in a distributed manner.
|
||||
Data in this dataset will be distributed evenly in blocks to each worker in the `distributed communicator`.
|
||||
|
||||
**Note** When all data has been read, reading dataset again will revert back to the first data.
|
||||
|
||||
Args:
|
||||
path (str): Path to dataset.
|
||||
rank (int): Rank in distributed communicator. See: bmtrain.rank()
|
||||
world_size (int): Total workers in distributed communicator. See: bmtrain.world_size()
|
||||
block_size (int): Size of each block in bytes. All files in the same dataset should have the same block size. Default: 16MB
|
||||
|
||||
Example:
|
||||
>>> dataset = DistributedDataset("/path/to/dataset")
|
||||
>>> for i in range(10):
|
||||
>>> dataset.read()
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
rank: int = 0,
|
||||
world_size: int = 1,
|
||||
serializer: Optional[Serializer] = None,
|
||||
max_repeat_times: Optional[int] = None,
|
||||
shuffle: bool = True,
|
||||
) -> None:
|
||||
# config
|
||||
self._path = path
|
||||
self._rank = rank
|
||||
self._world_size = world_size
|
||||
self._max_repeat_times = max_repeat_times
|
||||
self._repeat_times = 0
|
||||
self._shuffle = shuffle
|
||||
|
||||
if serializer is None:
|
||||
serializer = PickleSerializer()
|
||||
self.serializer = serializer
|
||||
|
||||
# dataset meta
|
||||
self._unused_block: List[int] = []
|
||||
self._unused_block_offset: List[int] = []
|
||||
self._file_info: List[FileInfo] = []
|
||||
self._file_ends: List[int] = []
|
||||
self._total_blocks = 0
|
||||
self._nbytes = 0
|
||||
self._nlines = 0
|
||||
|
||||
# states
|
||||
self._curr_block = None
|
||||
self._fp = None
|
||||
|
||||
# cache
|
||||
self._last_mod_time = 0
|
||||
self._curr_fname = None
|
||||
|
||||
self._update_states(fast_skip=False)
|
||||
self._repeat_times += 1
|
||||
|
||||
def _update_states(self, fast_skip: bool = True):
|
||||
meta_path = os.path.join(self._path, "meta.bin")
|
||||
|
||||
while True:
|
||||
try:
|
||||
mod_time = os.stat(meta_path).st_mtime
|
||||
break
|
||||
except Exception as e:
|
||||
print(
|
||||
"Error: reading info list in DistributedDataset._update_states, "
|
||||
"meta_path={path}, err={err}!".format(path=meta_path, err=str(e))
|
||||
)
|
||||
time.sleep(10)
|
||||
|
||||
if self._last_mod_time < mod_time:
|
||||
# file changed
|
||||
self._last_mod_time = mod_time
|
||||
else:
|
||||
if fast_skip:
|
||||
return
|
||||
|
||||
info: List[FileInfo] = []
|
||||
if os.path.exists(meta_path):
|
||||
info = _read_info_list(meta_path)
|
||||
old_len = len(self._file_info)
|
||||
if old_len > len(info):
|
||||
raise RuntimeError("Dataset meta file: changed unexpectly")
|
||||
|
||||
mask_changed = False
|
||||
for i in range(old_len):
|
||||
if self._file_info[i].file_name != info[i].file_name:
|
||||
raise RuntimeError("Dataset meta file: changed unexpectly")
|
||||
if self._file_info[i].block_begin != info[i].block_begin:
|
||||
raise RuntimeError("Dataset meta file: changed unexpectly")
|
||||
if self._file_info[i].block_end != info[i].block_end:
|
||||
raise RuntimeError("Dataset meta file: changed unexpectly")
|
||||
if self._file_info[i].mask != info[i].mask:
|
||||
mask_changed = True
|
||||
|
||||
if info[0].block_begin != 0:
|
||||
raise RuntimeError("Dataset meta file: block error (0)")
|
||||
for i in range(len(info) - 1):
|
||||
if info[i].block_end != info[i + 1].block_begin:
|
||||
raise RuntimeError("Dataset meta file: block error (%d)" % (i + 1))
|
||||
|
||||
if (old_len == len(info) and not mask_changed) and fast_skip:
|
||||
# fast skip
|
||||
return
|
||||
|
||||
if len(info) > 0:
|
||||
total_blocks = info[-1].block_end
|
||||
self._nbytes = 0
|
||||
self._nlines = 0
|
||||
for v in info:
|
||||
self._nbytes += v.nbytes
|
||||
self._nlines += v.nlines
|
||||
else:
|
||||
total_blocks = 0
|
||||
self._nbytes = 0
|
||||
self._nlines = 0
|
||||
|
||||
if total_blocks > 0:
|
||||
unused_block_set = set(self._unused_block)
|
||||
nw_unused_block: List[int] = []
|
||||
for i in range(len(info)):
|
||||
v = info[i]
|
||||
if not v.mask:
|
||||
if i < old_len:
|
||||
nw_unused_block.extend(
|
||||
_filtered_range(
|
||||
v.block_begin,
|
||||
v.block_end,
|
||||
self._rank,
|
||||
self._world_size,
|
||||
unused_block_set,
|
||||
)
|
||||
)
|
||||
else:
|
||||
nw_unused_block.extend(
|
||||
_filtered_range(v.block_begin, v.block_end, self._rank, self._world_size)
|
||||
)
|
||||
|
||||
# re-shuffle unused blocks
|
||||
if self._shuffle:
|
||||
random.shuffle(nw_unused_block)
|
||||
|
||||
offset_dict = {block: offset for block, offset in zip(self._unused_block, self._unused_block_offset)}
|
||||
nw_unused_block_offset = [offset_dict[block] if block in offset_dict else 0 for block in nw_unused_block]
|
||||
self._unused_block = nw_unused_block
|
||||
self._unused_block_offset = nw_unused_block_offset
|
||||
|
||||
self._file_ends = []
|
||||
for v in info:
|
||||
self._file_ends.append(v.block_end)
|
||||
else:
|
||||
self._unused_block = []
|
||||
self._unused_block_offset = []
|
||||
self._file_ends = []
|
||||
|
||||
self._total_blocks = total_blocks
|
||||
self._file_info = info
|
||||
assert len(self._unused_block) == len(self._unused_block_offset)
|
||||
assert len(self._file_ends) == len(self._file_info)
|
||||
|
||||
def _mask_file(self, f: FileInfo):
|
||||
nw_unused_block: List[int] = []
|
||||
nw_unused_block_offset: List[int] = []
|
||||
for block_id, block_offset in zip(self._unused_block, self._unused_block_offset):
|
||||
if block_id < f.block_begin or block_id >= f.block_end:
|
||||
nw_unused_block.append(block_id)
|
||||
nw_unused_block_offset.append(block_offset)
|
||||
self._unused_block = nw_unused_block
|
||||
self._unused_block_offset = nw_unused_block_offset
|
||||
|
||||
def _get_block_file(self, block_id: int):
|
||||
# find block in which file
|
||||
file_idx = bisect.bisect_right(self._file_ends, block_id)
|
||||
return self._file_info[file_idx]
|
||||
|
||||
def _prepare_new_epoch(self):
|
||||
if self._max_repeat_times is not None:
|
||||
if self._repeat_times >= self._max_repeat_times:
|
||||
raise EOFError("End of dataset")
|
||||
nw_unused_block: List[int] = []
|
||||
for v in self._file_info:
|
||||
if not v.mask:
|
||||
nw_unused_block.extend(_filtered_range(v.block_begin, v.block_end, self._rank, self._world_size))
|
||||
if self._shuffle:
|
||||
random.shuffle(nw_unused_block)
|
||||
self._unused_block = nw_unused_block
|
||||
self._unused_block_offset = [0 for _ in nw_unused_block]
|
||||
self._repeat_times += 1
|
||||
|
||||
def _get_next_block(self):
|
||||
self._update_states()
|
||||
if len(self._unused_block) == 0:
|
||||
self._prepare_new_epoch()
|
||||
if len(self._unused_block) == 0:
|
||||
raise RuntimeError("Empty dataset {}".format(self._path))
|
||||
|
||||
mn_block: int = self._unused_block.pop()
|
||||
mn_block_offset: int = self._unused_block_offset.pop()
|
||||
return mn_block, mn_block_offset
|
||||
|
||||
def _state_dict(self):
|
||||
self._update_states()
|
||||
num_unused_block = len(self._unused_block)
|
||||
if (self._fp is not None) and (self._curr_block is not None):
|
||||
curr_block = self._curr_block
|
||||
curr_f = self._get_block_file(curr_block)
|
||||
inblock_offset = self._fp.tell() - (curr_block - curr_f.block_begin) * curr_f.block_size
|
||||
else:
|
||||
curr_block = -1
|
||||
inblock_offset = 0
|
||||
|
||||
return {
|
||||
"states": torch.tensor(self._unused_block, dtype=torch.long, device="cpu"),
|
||||
"offset": torch.tensor(self._unused_block_offset, dtype=torch.long, device="cpu"),
|
||||
"block": torch.tensor(
|
||||
[curr_block, inblock_offset, num_unused_block, self._repeat_times],
|
||||
dtype=torch.long,
|
||||
device="cpu",
|
||||
),
|
||||
}
|
||||
|
||||
def state_dict(self):
|
||||
"""Returns a state dict representing the read states of the dataset.
|
||||
|
||||
Example:
|
||||
>>> state = dataset.state_dict()
|
||||
>>> dataset.load_state_dict(state)
|
||||
"""
|
||||
self._update_states()
|
||||
num_unused_block = len(self._unused_block)
|
||||
if (self._fp is not None) and (self._curr_block is not None):
|
||||
curr_block = self._curr_block
|
||||
curr_f = self._get_block_file(curr_block)
|
||||
inblock_offset = self._fp.tell() - (curr_block - curr_f.block_begin) * curr_f.block_size
|
||||
else:
|
||||
curr_block = -1
|
||||
inblock_offset = 0
|
||||
|
||||
with torch.no_grad():
|
||||
if self._world_size > 1:
|
||||
gpu_num_unused_block = torch.tensor([num_unused_block], dtype=torch.long).cuda()
|
||||
max_unused_blocks = (
|
||||
bmt.distributed.all_reduce(gpu_num_unused_block, op="max", comm=bmt.config["tp_zero_comm"])
|
||||
.cpu()
|
||||
.item()
|
||||
)
|
||||
gpu_states = torch.full((max_unused_blocks,), -1, dtype=torch.long).cuda()
|
||||
gpu_states[:num_unused_block] = torch.tensor(self._unused_block, dtype=torch.long).cuda()
|
||||
gpu_offset = torch.full((max_unused_blocks,), 0, dtype=torch.long).cuda()
|
||||
gpu_offset[:num_unused_block] = torch.tensor(self._unused_block_offset, dtype=torch.long).cuda()
|
||||
gpu_block = torch.tensor(
|
||||
[curr_block, inblock_offset, num_unused_block, self._repeat_times],
|
||||
dtype=torch.long,
|
||||
).cuda()
|
||||
global_states = bmt.distributed.all_gather(
|
||||
gpu_states, comm=bmt.config["tp_zero_comm"]
|
||||
).cpu() # (world_size, max_unused_blocks)
|
||||
global_offset = bmt.distributed.all_gather(
|
||||
gpu_offset, comm=bmt.config["tp_zero_comm"]
|
||||
).cpu() # (world_size, max_unused_blocks)
|
||||
global_block = bmt.distributed.all_gather(
|
||||
gpu_block, comm=bmt.config["tp_zero_comm"]
|
||||
).cpu() # (world_size, 4)
|
||||
return {"states": global_states, "offset": global_offset, "block": global_block}
|
||||
else:
|
||||
return {
|
||||
"states": torch.tensor([self._unused_block], dtype=torch.long, device="cpu"),
|
||||
"offset": torch.tensor([self._unused_block_offset], dtype=torch.long, device="cpu"),
|
||||
"block": torch.tensor(
|
||||
[[curr_block, inblock_offset, num_unused_block, self._repeat_times]],
|
||||
dtype=torch.long,
|
||||
device="cpu",
|
||||
),
|
||||
}
|
||||
|
||||
def load_state_dict(self, state, strict: bool = True):
|
||||
"""Load dataset state.
|
||||
|
||||
Args:
|
||||
state (dict): dataset state dict.
|
||||
strict (bool): If `strict` is True, world size needs to be the same as when exported.
|
||||
|
||||
Example:
|
||||
>>> state = dataset.state_dict()
|
||||
>>> dataset.load_state_dict(state)
|
||||
"""
|
||||
block_states: torch.LongTensor = state["states"]
|
||||
block_info: torch.LongTensor = state["block"]
|
||||
if "offset" not in state:
|
||||
block_offset: torch.LongTensor = torch.zeros_like(block_states).long()
|
||||
else:
|
||||
block_offset: torch.LongTensor = state["offset"]
|
||||
|
||||
if block_states.size(0) != self._world_size:
|
||||
if strict:
|
||||
raise ValueError("world_size changed (%d -> %d)" % (state["block"].size(0), self._world_size))
|
||||
else:
|
||||
self._curr_block = None
|
||||
self._fp = None
|
||||
self._curr_fname = None
|
||||
self._repeat_times = int(block_info[0, 3].item())
|
||||
offset_dict = {}
|
||||
for i in range(block_states.size(0)):
|
||||
for block, offset in zip(block_states[i].tolist(), block_offset[i].tolist()):
|
||||
offset_dict[block] = offset
|
||||
|
||||
# re-shuffle unused blocks
|
||||
nw_unused_block: List[int] = []
|
||||
|
||||
for i in range(block_states.size(0)):
|
||||
_, _, num_unused_blocks, _ = block_info[i].tolist()
|
||||
nw_unused_block.extend(
|
||||
[
|
||||
block_id
|
||||
for block_id in block_states[i, :num_unused_blocks].tolist()
|
||||
if block_id % self._world_size == self._rank
|
||||
]
|
||||
)
|
||||
|
||||
for i in range(block_states.size(0)):
|
||||
curr_block, inblock_offset, num_unused_blocks, _ = block_info[i].tolist()
|
||||
if curr_block < 0:
|
||||
continue
|
||||
if curr_block % self._world_size == self._rank:
|
||||
nw_unused_block.append(curr_block)
|
||||
offset_dict[curr_block] = inblock_offset
|
||||
|
||||
curr_block, inblock_offset, _, self._repeat_times = block_info[self._rank].tolist()
|
||||
# if self._shuffle:
|
||||
# random.shuffle(nw_unused_block)
|
||||
nw_unused_block_offset = [
|
||||
offset_dict[block] if block in offset_dict else 0 for block in nw_unused_block
|
||||
]
|
||||
self._unused_block = nw_unused_block
|
||||
self._unused_block_offset = nw_unused_block_offset
|
||||
|
||||
else:
|
||||
curr_block, inblock_offset, num_unused_blocks, self._repeat_times = block_info[self._rank].tolist()
|
||||
if curr_block == -1:
|
||||
self._curr_block = None
|
||||
self._unused_block = []
|
||||
self._unused_block_offset = []
|
||||
else:
|
||||
while True:
|
||||
try:
|
||||
self._curr_block = curr_block
|
||||
f_info = self._get_block_file(self._curr_block)
|
||||
self._open_file(
|
||||
f_info.file_name,
|
||||
(self._curr_block - f_info.block_begin) * f_info.block_size + inblock_offset,
|
||||
)
|
||||
self._unused_block = block_states[self._rank, :num_unused_blocks].tolist()
|
||||
self._unused_block_offset = block_offset[self._rank, :num_unused_blocks].tolist()
|
||||
break
|
||||
except Exception:
|
||||
print("Error: reading blocks in {}".format(f_info.file_name))
|
||||
time.sleep(10)
|
||||
# end
|
||||
self._update_states()
|
||||
|
||||
def _get_file_path(self, fname):
|
||||
return os.path.join(self._path, fname)
|
||||
|
||||
def _open_file(self, fname, offset):
|
||||
if self._curr_fname != fname:
|
||||
if self._fp is not None:
|
||||
self._fp.close()
|
||||
self._curr_fname = None
|
||||
# self._fp = open(self._get_file_path(fname), "rb")
|
||||
self._fp = SafeFile(self._get_file_path(fname), "rb")
|
||||
self._curr_fname = fname
|
||||
else:
|
||||
assert self._fp is not None, "Unexpected error"
|
||||
self._fp.seek(offset, io.SEEK_SET) # move to block
|
||||
|
||||
def read(self):
|
||||
"""Read a piece of data from dataset.
|
||||
|
||||
Workers in different ranks will read different data.
|
||||
"""
|
||||
if self._curr_block is None:
|
||||
next_block_id, next_block_offset = self._get_next_block()
|
||||
f_info = self._get_block_file(next_block_id)
|
||||
try:
|
||||
self._open_file(
|
||||
f_info.file_name, (next_block_id - f_info.block_begin) * f_info.block_size + next_block_offset
|
||||
)
|
||||
self._curr_block = next_block_id
|
||||
except FileNotFoundError:
|
||||
print("ERR: reading again!")
|
||||
self._mask_file(f_info)
|
||||
return self.read() # read again
|
||||
|
||||
if self._fp is None:
|
||||
raise RuntimeError("Dataset is not initialized")
|
||||
|
||||
MAGIC = self._fp.read(1)
|
||||
if MAGIC == b"\x1F":
|
||||
# correct
|
||||
size = struct.unpack("I", self._fp.read(4))[0]
|
||||
data = self._fp.read(size)
|
||||
return self.serializer.deserialize(data)
|
||||
elif MAGIC == b"\x00":
|
||||
# end of block
|
||||
self._curr_block = None
|
||||
return self.read() # read next block
|
||||
else:
|
||||
raise ValueError("Invalid magic header")
|
||||
|
||||
@property
|
||||
def nbytes(self):
|
||||
return self._nbytes
|
||||
|
||||
|
||||
class SimpleDataset(DistributedDataset):
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
serializer: Optional[Serializer] = None,
|
||||
shuffle: bool = True,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
path,
|
||||
0,
|
||||
1,
|
||||
serializer=serializer,
|
||||
max_repeat_times=1,
|
||||
shuffle=shuffle,
|
||||
)
|
||||
|
||||
def __iter__(self):
|
||||
while True:
|
||||
try:
|
||||
data = self.read()
|
||||
except EOFError:
|
||||
self._repeat_times = 0
|
||||
break
|
||||
yield data
|
||||
|
||||
def __len__(self):
|
||||
return self._nlines
|
||||
|
||||
def get_bytes(self):
|
||||
return self._nbytes
|
||||
|
||||
|
||||
class DatasetWriter:
|
||||
def __init__(self, fname: str, block_size: int, serializer: Optional[Serializer] = None):
|
||||
self._fname = fname
|
||||
self._block_size = block_size
|
||||
self._fp = open(self._fname, "wb")
|
||||
self._inblock_offset = 0
|
||||
|
||||
self._nbytes = 0
|
||||
self._nlines = 0
|
||||
self._nblocks = 1
|
||||
|
||||
if serializer is None:
|
||||
serializer = PickleSerializer()
|
||||
self.serializer = serializer
|
||||
|
||||
def write(self, data):
|
||||
"""Write a piece of data into dataset.
|
||||
|
||||
Args:
|
||||
data (Any): Serialization will be done using pickle.
|
||||
|
||||
Example:
|
||||
>>> writer.write( "anything you want" )
|
||||
|
||||
"""
|
||||
byte_data = self.serializer.serialize(data)
|
||||
byte_data = struct.pack("I", len(byte_data)) + byte_data
|
||||
if self._inblock_offset + 2 + len(byte_data) > self._block_size:
|
||||
self._fp.write(b"\x00" * (self._block_size - self._inblock_offset)) # fill the remaining space with 0
|
||||
self._inblock_offset = 0
|
||||
self._nblocks += 1
|
||||
# we go to the next block
|
||||
if self._inblock_offset + 2 + len(byte_data) > self._block_size:
|
||||
raise ValueError("data is larger than block size")
|
||||
|
||||
self._nbytes += len(byte_data)
|
||||
self._nlines += 1
|
||||
|
||||
self._inblock_offset += 1 + len(byte_data)
|
||||
self._fp.write(b"\x1F")
|
||||
self._fp.write(byte_data)
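# Added note on the on-disk record framing produced above (a sketch for readers,
# not original code). Each record inside a block is stored as a one-byte magic
# 0x1F, a native-endian 4-byte unsigned payload length (struct "I"), and the
# payload itself; the unused tail of a block is padded with 0x00, which is how
# the reader in DistributedDataset.read() detects the end of a block:
#
#   +------+------------------+------------------+ ... +------------------+
#   | 0x1F | length (4 bytes) | payload (length) | ... | 0x00 ... padding |
#   +------+------------------+------------------+ ... +------------------+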
|
||||
|
||||
@property
|
||||
def nbytes(self):
|
||||
return self._nbytes
|
||||
|
||||
@property
|
||||
def nblocks(self):
|
||||
return self._nblocks
|
||||
|
||||
@property
|
||||
def nlines(self):
|
||||
return self._nlines
|
||||
|
||||
def close(self):
|
||||
if not self._fp.closed:
|
||||
self._fp.write(b"\x00" * (self._block_size - self._inblock_offset))
|
||||
self._fp.close()
|
||||
|
||||
|
||||
class DatasetBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
dbname: str,
|
||||
block_size=_DEFAULT_BLOCK_SIZE,
|
||||
serializer: Optional[Serializer] = None,
|
||||
) -> None:
|
||||
self._block_size = block_size
|
||||
self._path = path
|
||||
self._dbname = dbname
|
||||
if serializer is None:
|
||||
serializer = PickleSerializer()
|
||||
self.serializer = serializer
|
||||
|
||||
if not os.path.exists(self._path):
|
||||
os.makedirs(self._path)
|
||||
|
||||
meta_path = os.path.join(self._path, "meta.bin")
|
||||
|
||||
info: List[FileInfo] = []
|
||||
if os.path.exists(meta_path):
|
||||
info = _read_info_list(meta_path)
|
||||
|
||||
for v in info:
|
||||
if v.file_name == dbname:
|
||||
raise ValueError("Dataset name exists")
|
||||
|
||||
self._db_path = os.path.join(self._path, self._dbname)
|
||||
if os.path.exists(self._db_path):
|
||||
raise ValueError("File exists `%s`" % self._db_path)
|
||||
|
||||
def __enter__(self):
|
||||
self._writer = DatasetWriter(self._db_path, self._block_size, self.serializer)
|
||||
return self._writer
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
if self._writer is None:
|
||||
raise RuntimeError("Unexpected call to __exit__")
|
||||
|
||||
self._writer.close()
|
||||
if exc_type is not None:
|
||||
print("Error while writing file")
|
||||
if os.path.exists(self._db_path):
|
||||
os.unlink(self._db_path)
|
||||
else:
|
||||
meta_path = os.path.join(self._path, "meta.bin")
|
||||
info: List[FileInfo] = []
|
||||
if os.path.exists(meta_path):
|
||||
info = _read_info_list(meta_path)
|
||||
last_block = 0
|
||||
if len(info) > 0:
|
||||
last_block = info[-1].block_end
|
||||
info.append(
|
||||
FileInfo(
|
||||
self._dbname,
|
||||
last_block,
|
||||
last_block + self._writer.nblocks,
|
||||
self._writer.nbytes,
|
||||
self._writer.nlines,
|
||||
False,
|
||||
self._block_size,
|
||||
)
|
||||
)
|
||||
|
||||
# atomic write to meta file
|
||||
_write_info_list(meta_path, info)
|
||||
|
||||
self._writer = None
|
||||
|
||||
|
||||
def build_dataset(
|
||||
path: str,
|
||||
dbname: str,
|
||||
block_size: int = _DEFAULT_BLOCK_SIZE,
|
||||
serializer: Optional[Serializer] = None,
|
||||
):
|
||||
"""Open the dataset in write mode and returns a writer.
|
||||
|
||||
Args:
|
||||
path (str): Path to dataset.
|
||||
dbname (str): The name of the file to which the data will be written. The `dbname` needs to be unique in this `dataset`.
|
||||
block_size (int): Size of each block in bytes. All files in the same dataset should have the same block size. Default: 16MB
|
||||
|
||||
Example:
|
||||
>>> with build_dataset("/path/to/dataset", "data_part_1") as writer:
|
||||
>>> for i in range(10):
|
||||
>>> writer.write( { "anything you want" } )
|
||||
""" # noqa: E501
|
||||
return DatasetBuilder(path, dbname, block_size=block_size, serializer=serializer)
|
|
@ -1,460 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
||||
#
|
||||
# @author: ouzebin <ouzebin@zhihu.com>
|
||||
# @date: 2023/09/27
|
||||
|
||||
"""
|
||||
使用 IndexedDataset 前需按指定格式构建或者转换已有数据集
|
||||
数据集文件结构:
|
||||
- <dataset name>
|
||||
- data.jsonl # jsonl 格式的数据,每一行一条样本
|
||||
- index # 记录每一行 json 数据的起始 byte-offset
|
||||
|
||||
从头构建:直接使用 IndexedDatasetBuilder 这个 context manager
|
||||
>>> with IndexedDatasetBuilder("swear", overwrite=True) as builder:
|
||||
>>> for data in [{"input": f"screw it {i}", "output": f"for god's sake {i}"} for i in range(100)]:
|
||||
>>> builder.put(data)
|
||||
转换:
|
||||
从 fm9g distributed_dataset 转换:使用 `fm9g.dataset.tools.distributed_to_indexed`
|
||||
$ python -m fm9g.dataset.tools.distributed_to_indexed -i <原数据集文件夹> -o <新数据集文件夹>
|
||||
已有 jsonl 数据:使用 `fm9g.dataset.tools.jsonl_to_index` 构建 index 文件。需提前先把 jsonl 文件命名为
|
||||
$ python -m fm9g.dataset.tools.jsonl_to_index -p <数据集文件夹路径>
|
||||
"""
|
||||
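# Added sketch of how the plain-text `index` file is consumed (for orientation
# only; the real logic lives in IndexedDataset below). The file stores n+1 byte
# offsets, one per line, so record i is the byte range [bounds[i], bounds[i+1])
# of data.jsonl:
#
#   with open(os.path.join(path, "index")) as f:
#       bounds = [int(line) for line in f]
#   with open(os.path.join(path, "data.jsonl"), "rb") as f:
#       f.seek(bounds[i])
#       raw = f.read(bounds[i + 1] - bounds[i])   # one json line, as bytes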
import itertools
|
||||
import math
|
||||
import os
|
||||
import queue
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
|
||||
import bmtrain as bmt
|
||||
import h5py
|
||||
import numpy
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
import msgspec
|
||||
|
||||
json_decode = msgspec.json.decode
|
||||
json_encode = msgspec.json.encode
|
||||
except ModuleNotFoundError:
|
||||
import json
|
||||
|
||||
json_decode = json.loads
|
||||
json_encode = json.dumps
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from fm9g.utils.bitset import BitSet
|
||||
from fm9g.utils.bitset import bitset_diff
|
||||
|
||||
print_lock = threading.Lock()
|
||||
|
||||
|
||||
def random_range(start, stop=None, step=None):
|
||||
"""
|
||||
Generator of a non-repeating random permutation with the same interface as Python's
`range`. Obtained from https://stackoverflow.com/a/53551417
random.shuffle(list) and random.sample(list, len(list)) require
materializing the list, which results in a long initialization period.
|
||||
"""
|
||||
if stop is None:
|
||||
start, stop = 0, start
|
||||
if step is None:
|
||||
step = 1
|
||||
# Use a mapping to convert a standard range into the desired range.
|
||||
mapping = lambda i: (i * step) + start
|
||||
# Compute the number of numbers in this range.
|
||||
maximum = int(math.ceil((stop - start) / step))
|
||||
if maximum == 0:
|
||||
# early return with empty range
|
||||
yield from ()
|
||||
return
|
||||
# Seed range with a random integer.
|
||||
value = random.randint(0, maximum)
|
||||
# Construct an offset, multiplier, and modulus for a linear
|
||||
# congruential generator. These generators are cyclic and
|
||||
# non-repeating when they maintain the properties:
|
||||
#
|
||||
# 1) "modulus" and "offset" are relatively prime.
|
||||
# 2) ["multiplier" - 1] is divisible by all prime factors of "modulus".
|
||||
# 3) ["multiplier" - 1] is divisible by 4 if "modulus" is divisible by 4.
|
||||
|
||||
# Pick a random odd-valued offset.
|
||||
offset = random.randint(0, maximum) * 2 + 1
|
||||
# Pick a multiplier 1 greater than a multiple of 4.
|
||||
multiplier = 4 * (maximum // 4) + 1
|
||||
# Pick a modulus just big enough to generate all numbers (power of 2).
|
||||
modulus = int(2 ** math.ceil(math.log2(maximum)))
|
||||
# Track how many random numbers have been returned.
|
||||
found = 0
|
||||
while found < maximum:
|
||||
# If this is a valid value, yield it in generator fashion.
|
||||
if value < maximum:
|
||||
found += 1
|
||||
yield mapping(value)
|
||||
# Calculate the next value in the sequence.
|
||||
value = (value * multiplier + offset) % modulus
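# Illustrative check (added, not in the original file): the generator above
# yields every element of the underlying range exactly once, in pseudo-random
# order, without materializing a list.
#
# >>> sorted(random_range(10)) == list(range(10))
# True
# >>> sorted(random_range(0, 20, 3)) == list(range(0, 20, 3))
# True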
|
||||
|
||||
|
||||
class Range(object):
|
||||
def __init__(self, start, stop, step):
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.step = step
|
||||
|
||||
def __repr__(self):
|
||||
return f"Range({self.start}, {self.stop}, {self.step})"
|
||||
|
||||
def iterate(self):
|
||||
yield from range(self.start, self.stop, self.step)
|
||||
|
||||
def list(self):
|
||||
return list(range(self.start, self.stop, self.step))
|
||||
|
||||
def subrange(self, split, nsplits):
|
||||
# strided splitting of the range params
# e.g., [1, 3, 5, 7, 9] can be split into [1, 5, 9] and [3, 7]
|
||||
return Range(self.start + self.step * split, self.stop, self.step * nsplits)
|
||||
|
||||
def random_iterate(self):
|
||||
yield from random_range(self.start, self.stop, self.step)
|
||||
|
||||
|
||||
def safe_print(*args, **kargs):
|
||||
if "flush" in kargs:
|
||||
flush = kargs["flush"]
|
||||
del kargs["flush"]
|
||||
else:
|
||||
flush = True
|
||||
with print_lock:
|
||||
print(*args, **kargs, flush=flush)
|
||||
|
||||
|
||||
def concurrent_info():
|
||||
# world_size, rank = bmt.world_size(), bmt.rank()
|
||||
world_size = bmt.config["world_size"] // bmt.config["tp_size"]
|
||||
rank = bmt.config["topology"].tp_idx
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
if worker_info is None:
|
||||
nworkers, worker_id = 1, 0  # single-process loading: behave as worker 0 of 1
|
||||
else:
|
||||
nworkers, worker_id = worker_info.num_workers, worker_info.id
|
||||
# print("concurrent_info: (world_size, rank, nworkers, worker_id): {}".format((world_size, rank, nworkers, worker_id)))
|
||||
return world_size, rank, nworkers, worker_id
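# Added illustration of how these four values are combined downstream (see
# PrefetchDecodeDataset.iterate below); the concrete numbers are hypothetical.
# With world_size=2, nworkers=2, nthreads=1, the strided subranges of
# Range(0, 8, 1) are disjoint across ranks and workers and together cover all
# indices; for example on rank 0:
#
#   Range(0, 8, 1).subrange(split=0, nsplits=2).list()                  # rank 0 -> [0, 2, 4, 6]
#   Range(0, 8, 1).subrange(split=0, nsplits=2).subrange(0, 2).list()   # rank 0, worker 0 -> [0, 4]
#   Range(0, 8, 1).subrange(split=0, nsplits=2).subrange(1, 2).list()   # rank 0, worker 1 -> [2, 6]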
|
||||
|
||||
|
||||
class IndexedDataset(Dataset):
|
||||
def __init__(self, path, max_retry=1, retry_sleep=5):
|
||||
super().__init__()
|
||||
self.path = path
|
||||
self.max_retry = max_retry
|
||||
self.retry_sleep = retry_sleep
|
||||
self.bounds = None
|
||||
self.h5file = None
|
||||
self.build_index()
|
||||
|
||||
def size(self):
|
||||
return self.bounds[-1]
|
||||
|
||||
def _build_index_h5(self):
|
||||
index_path = os.path.join(self.path, "index.h5")
|
||||
if os.path.getsize(index_path) > 104857600:
|
||||
self.h5file = h5py.File(os.path.join(self.path, "index.h5"), "r")
|
||||
self.bounds = self.h5file["index"]
|
||||
else:
|
||||
# only load index into memory when it is small (< 100 Mb)
|
||||
# to avoid keeping too many file handles open
|
||||
self.h5file = None
|
||||
with h5py.File(index_path, "r") as hf:
|
||||
self.bounds = np.array(hf["index"])
|
||||
|
||||
def __del__(self):
|
||||
if self.h5file is not None:
|
||||
self.h5file.close()
|
||||
|
||||
def build_index(self):
|
||||
s = time.time()
|
||||
|
||||
txt_size = os.path.getsize(os.path.join(self.path, "index"))
|
||||
if txt_size > 0.5 * 1024**3 and os.path.exists(os.path.join(self.path, "index.h5")):
|
||||
source = "h5"
|
||||
self._build_index_h5()
|
||||
else:
|
||||
source = "txt"
|
||||
self._build_index_txt()
|
||||
e = time.time()
|
||||
bmt.print_rank("build_index_{} from {}, using {:.2f}s".format(source, self.path, e - s))
|
||||
|
||||
def _build_index_txt(self):
|
||||
with open(os.path.join(self.path, "index"), "r") as fin:
|
||||
self.bounds = [int(line) for line in fin]
|
||||
self.nlines = len(self.bounds)
|
||||
|
||||
def safe_read(self, i_or_s, offset, size):
|
||||
for retry in itertools.count():
|
||||
try:
|
||||
# close the file handle after each read to avoid pressure on alluxio
|
||||
# buffering=0 to avoid overhead during file.seek() and open()
|
||||
with open(os.path.join(self.path, "data.jsonl"), "rb", buffering=0) as fin:
|
||||
fin.seek(offset)
|
||||
raw = fin.read(size)
|
||||
return raw
|
||||
except OSError as e:
|
||||
if retry >= self.max_retry:
|
||||
raise OSError(f"reach maximum #retry: {retry}, the file system is broken.")
|
||||
safe_print(
|
||||
f"retry loading {self.path}:{i_or_s} in {self.retry_sleep} seconds due to error: '{repr(e)}'"
|
||||
)
|
||||
time.sleep(self.retry_sleep)
|
||||
except ValueError as e:
|
||||
# reading error during python io, skip
|
||||
safe_print(f"skipping {self.path}:{i_or_s} due to error: '{repr(e)}'")
|
||||
return None
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"IndexedDataset(path={self.path}, max_retry={self.max_retry}, retry_sleep={self.retry_sleep}) "
|
||||
f"with {len(self)} entries."
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.bounds) - 1
|
||||
|
||||
def bound_idx(self, key, strict=False):
|
||||
# bound index within the standard range: [0, len(self))
|
||||
# useful for tracing buggy entries
|
||||
if strict and not (-len(self) <= key < len(self)):
|
||||
raise IndexError(f"Index {key} out of range for '{self.path}'")
|
||||
key = min(max(-len(self), key), len(self)) # bound key within [-len(self), len(self)]
|
||||
key = key if key > 0 else key % len(self) # remap negative id to positive ones
|
||||
return key
|
||||
|
||||
def __getitem__(self, key):
|
||||
# supports list-like slicing and indexing. strided slicing is not currently supported.
|
||||
# ok: self[1], self[-1], self[1:3], self[-10:-5], self[-10:-5:1], self[:5]
|
||||
# not ok: self[-10:-5:2], self[:100:3]
|
||||
if isinstance(key, slice):
|
||||
if not (key.step == 1 or key.step is None):
|
||||
raise ValueError(f"slice step should be 1 or None, not {key.step}")
|
||||
start = self.bound_idx(0 if key.start is None else key.start)
|
||||
stop = max(self.bound_idx(len(self) if key.stop is None else key.stop), start)
|
||||
if stop == start:
|
||||
# early returning empty slice
|
||||
return list()
|
||||
offset, size = self.bounds[start], self.bounds[stop] - self.bounds[start]
|
||||
raw = self.safe_read(key, offset, size)
|
||||
if raw is None:
|
||||
return None
|
||||
else:
|
||||
return [
|
||||
raw[s - offset : e - offset]
|
||||
for s, e in zip(self.bounds[start:stop], self.bounds[start + 1 : stop + 1])
|
||||
]
|
||||
|
||||
elif isinstance(key, int):
|
||||
key = self.bound_idx(key, strict=True)
|
||||
offset, size = self.bounds[key], self.bounds[key + 1] - self.bounds[key]
|
||||
raw = self.safe_read(key, offset, size)
|
||||
return raw
|
||||
else:
|
||||
raise TypeError(f"indices must be integers or slices, not {type(key)}")
|
||||
|
||||
|
||||
class PrefetchDecodeDataset(IndexedDataset):
|
||||
# Add prefetched sampled iterator and state_dict tracking upon the simple IndexedDataset
|
||||
# Add safe decoding in iterator
|
||||
def __init__(self, *args, decode=json_decode, allow_repeat=False, **kargs):
|
||||
super().__init__(*args, **kargs)
|
||||
self.decode = decode
|
||||
self.allow_repeat = allow_repeat
|
||||
|
||||
def safe_decode(self, i, raw):
|
||||
if raw is None:
|
||||
return None
|
||||
try:
|
||||
return self.decode(raw)
|
||||
except Exception as e:
|
||||
safe_print(f"Skip decoding {self.path}:{i} due to error '{e}', raw bytes:\n{raw}")
|
||||
return None
|
||||
|
||||
def __getitem__(self, key):
|
||||
raw = super().__getitem__(key)
|
||||
if raw is None:
|
||||
return None
|
||||
# key should be either a slice or an integer as checked in IndexedDataset
|
||||
if isinstance(key, slice):
|
||||
return [self.safe_decode(i, r) for i, r in zip(range(key.start, key.stop), raw)]
|
||||
else:
|
||||
return self.safe_decode(key, raw)
|
||||
|
||||
def loader(self, q, lid, keys, stop, used=None):
|
||||
# concurrent prefetching worker
|
||||
if used is None:
|
||||
used = BitSet()
|
||||
try:
|
||||
for key in keys:
|
||||
if stop.is_set():
|
||||
break
|
||||
# key is either a slice or an integer index
|
||||
index = range(key.start, key.stop) if isinstance(key, slice) else [key]
|
||||
unused = bitset_diff(set(index), used)
|
||||
if not unused:
|
||||
# skip used slice / item
|
||||
continue
|
||||
if not q.empty():
|
||||
# avoid breaking the distributed file system with large io load
|
||||
time.sleep(random.random() * 2)
|
||||
# read raw data with IndexedDataset.__getitem__, postponing decoding until we really need it
|
||||
raw = super().__getitem__(key)
|
||||
if raw is None:
|
||||
continue
|
||||
# filter used data
|
||||
items = [(i, s) for i, s in zip(index, raw if len(index) > 1 else [raw]) if i in unused]
|
||||
random.shuffle(items)
|
||||
for item in items:
|
||||
q.put(item)
|
||||
finally:
|
||||
# signaling the end of iteration to the main thread
|
||||
q.put(StopIteration(lid))
|
||||
|
||||
def _iterate(self, key_groups, nprefetch=1000, used=None):
|
||||
# helper function for concurrent prefetching
|
||||
q = queue.Queue(maxsize=nprefetch)
|
||||
stop = threading.Event()
|
||||
alive = set()
|
||||
try:
|
||||
for lid, keys in enumerate(key_groups):
|
||||
loader = threading.Thread(target=self.loader, args=(q, lid, keys, stop, used), daemon=True)
|
||||
loader.start()
|
||||
alive.add(lid)
|
||||
while True:
|
||||
try:
|
||||
item = q.get(block=False)
|
||||
except queue.Empty:
|
||||
if not alive:
|
||||
# no alive loader, thus no item will be put in the queue
|
||||
break
|
||||
else:
|
||||
# new item will be put later, wait for a while
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
if isinstance(item, StopIteration):
|
||||
alive.remove(item.value)
|
||||
continue
|
||||
i, raw = item
|
||||
data = self.safe_decode(i, raw)
|
||||
if data is None:
|
||||
continue
|
||||
yield i, data
|
||||
finally:
|
||||
# ask daemon loaders to stop
|
||||
stop.set()
|
||||
|
||||
def iterate(self, nthreads=3, prefetch_sample=100, used=None, process_group=None):
|
||||
world_size, rank, nworkers, worker_id = concurrent_info()
|
||||
nloaders = world_size * nworkers * nthreads
|
||||
if len(self) < nloaders:
|
||||
raise ValueError(
|
||||
f"more concurrent loaders ({nloaders}) than data entries ({len(self)}) in '{self.path}', "
|
||||
f"please constrain either "
|
||||
f"world_size={world_size}, num_workers={nworkers} or num_threads={nthreads}."
|
||||
)
|
||||
r = Range(0, len(self), 1)
|
||||
# split index among multi-gpu workers
|
||||
r = r.subrange(split=rank, nsplits=world_size)
|
||||
# split index among multi-process dataloader workers
|
||||
r = r.subrange(split=worker_id, nsplits=nworkers)
|
||||
# split index among multi-threaded loaders
|
||||
id_groups = [r.subrange(split=tid, nsplits=nthreads).random_iterate() for tid in range(nthreads)]
|
||||
return self._iterate(id_groups, nprefetch=prefetch_sample, used=used)
|
||||
|
||||
def sliced_iterate(self, nthreads=1, prefetch_slice=3, slice_size=500, used=None):
|
||||
world_size, rank, nworkers, worker_id = concurrent_info()
|
||||
nloaders = world_size * nworkers * nthreads
|
||||
if len(self) < nloaders:
|
||||
if not self.allow_repeat:
|
||||
raise ValueError(
|
||||
f"more concurrent loaders ({nloaders}) than data entries ({len(self)}) in '{self.path}', "
|
||||
f"please constrain either "
|
||||
f"world_size={world_size}, num_workers={nworkers} or num_threads={nthreads}."
|
||||
)
|
||||
else:
|
||||
duplicated_factor = math.ceil(nloaders / len(self))
|
||||
# In this case, slice size is 1
|
||||
r = Range(0, len(self), 1)
|
||||
# split index among grouped multi-gpu workers
|
||||
r = r.subrange(split=rank // duplicated_factor, nsplits=math.ceil(world_size / duplicated_factor))
|
||||
# split index among multi-process dataloader workers
|
||||
r = r.subrange(split=worker_id, nsplits=nworkers)
|
||||
else:
|
||||
nslices = int(math.ceil(len(self) / slice_size))
|
||||
|
||||
if nslices < nloaders:
|
||||
safe_print(
|
||||
f"fail to distribute {nslices} slices from '{self.path}' to {nloaders} concurrent loaders, "
|
||||
f"reduce slice_size from {slice_size} to {len(self) // nloaders}."
|
||||
)
|
||||
slice_size = len(self) // nloaders
|
||||
|
||||
# we only iterate through start ids as they uniquely mark each slice
|
||||
r = Range(0, len(self), slice_size)
|
||||
# split index among multi-gpu workers
|
||||
r = r.subrange(split=rank, nsplits=world_size)
|
||||
# split index among multi-process dataloader workers
|
||||
r = r.subrange(split=worker_id, nsplits=nworkers)
|
||||
# split index among multi-threaded loaders
|
||||
slice_groups = [
|
||||
(slice(s, s + slice_size) for s in r.subrange(tid, nthreads).random_iterate()) for tid in range(nthreads)
|
||||
]
|
||||
return self._iterate(slice_groups, nprefetch=prefetch_slice * slice_size, used=used)
|
||||
|
||||
|
||||
class IndexedDatasetBuilder:
|
||||
def __init__(self, path, overwrite=False):
|
||||
self.path = path
|
||||
self.index_path = os.path.join(self.path, "index.h5")
|
||||
self.index_path_txt = os.path.join(self.path, "index")
|
||||
self.data_path = os.path.join(self.path, "data.jsonl")
|
||||
if not overwrite:
|
||||
assert not os.path.exists(self.data_path)
|
||||
assert not os.path.exists(self.index_path)
|
||||
assert not os.path.exists(self.index_path_txt)
|
||||
self.fout = None
|
||||
self.bounds = []
|
||||
self.offset = 0
|
||||
|
||||
def __enter__(self):
|
||||
os.makedirs(self.path, exist_ok=True)
|
||||
self.fout = open(self.data_path, "wb")
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.bounds.append(self.offset)
|
||||
with h5py.File(os.path.join(self.index_path), "w") as hf:
|
||||
hf.create_dataset("index", data=self.bounds)
|
||||
with open(self.index_path_txt, "w") as fout_txt:
|
||||
for s in self.bounds:
|
||||
fout_txt.write(f"{s}\n")
|
||||
self.fout.close()
|
||||
|
||||
def put(self, data: dict):
|
||||
s = json_encode(data) + b"\n"
|
||||
self.bounds.append(self.offset)
|
||||
self.offset += len(s)
|
||||
self.fout.write(s)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with IndexedDatasetBuilder("swear", overwrite=True) as builder:
|
||||
for d in [{"input": f"screw it {i}", "output": f"for god's sake {i}"} for i in range(100)]:
|
||||
builder.put(d)
|
||||
dataset = IndexedDataset("swear")
|
||||
for i in range(10):
|
||||
print(dataset[random.randint(0, len(dataset) - 1)])
|
|
@ -1,61 +0,0 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The OpenBMB team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import pickle
|
||||
|
||||
|
||||
class Serializer:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def serialize(self, obj) -> bytes:
|
||||
raise NotImplementedError()
|
||||
|
||||
def deserialize(self, data: bytes):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class PickleSerializer(Serializer):
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def serialize(self, obj) -> bytes:
|
||||
return pickle.dumps(obj)
|
||||
|
||||
def deserialize(self, data: bytes):
|
||||
return pickle.loads(data)
|
||||
|
||||
|
||||
class JsonSerializer(Serializer):
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def serialize(self, obj) -> bytes:
|
||||
return json.dumps(obj, ensure_ascii=False).encode("utf-8")
|
||||
|
||||
def deserialize(self, data: bytes):
|
||||
return json.loads(data.decode("utf-8"))
|
||||
|
||||
|
||||
class RawSerializer(Serializer):
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def serialize(self, obj) -> bytes:
|
||||
return obj
|
||||
|
||||
def deserialize(self, data: bytes):
|
||||
return data
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
||||
#
|
||||
# @author: ouzebin <ouzebin@zhihu.com>
|
||||
# @date: 2023/08/07
|
|
@ -1,31 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
||||
#
|
||||
# @author: ouzebin <ouzebin@zhihu.com>
|
||||
# @date: 2023/07/27
|
||||
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from fm9g.dataset import SimpleDataset
|
||||
from fm9g.dataset.indexed_dataset import IndexedDatasetBuilder
|
||||
|
||||
|
||||
def convert_fm9g_data(fm9g_path, out_path):
|
||||
dataset = SimpleDataset(fm9g_path, shuffle=False)
|
||||
with IndexedDatasetBuilder(out_path, overwrite=True) as builder:
|
||||
for _ in tqdm(range(dataset._nlines), total=dataset._nlines):
|
||||
builder.put(dataset.read())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input", "-i", required=True, help="Data path in fm9g format.")
|
||||
parser.add_argument("--output", "-o", required=True, help="Output data path in indexed jsonline format.")
|
||||
args = parser.parse_args()
|
||||
convert_fm9g_data(args.input, args.output)
|
|
@ -1,31 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
||||
#
|
||||
# @author: ouzebin <ouzebin@zhihu.com>
|
||||
# @date: 2023/08/07
|
||||
import argparse
|
||||
import os
|
||||
|
||||
|
||||
def build_index(path):
|
||||
data_path = os.path.join(path, "data.jsonl")
|
||||
assert os.path.exists(data_path), f"Jsonline dataset '{data_path}' not found."
|
||||
|
||||
offset = 0
|
||||
starts = [offset]
|
||||
with open(data_path, "rb") as fin:
|
||||
for line in fin:
|
||||
offset += len(line)
|
||||
starts.append(offset)
|
||||
with open(os.path.join(path, "index"), "w") as fout:
|
||||
for s in starts:
|
||||
fout.write(f"{s}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--path", "-p", required=True, help="Data path.")
|
||||
args = parser.parse_args()
|
||||
build_index(args.path)
|
|
@ -1,436 +0,0 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The OpenBMB team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import struct
|
||||
from queue import Queue
|
||||
from threading import Thread
|
||||
from typing import Iterable
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils.log import logger
|
||||
from .distributed_dataset import _DEFAULT_BLOCK_SIZE
|
||||
from .distributed_dataset import _random_string
|
||||
from .distributed_dataset import _read_info_list
|
||||
from .distributed_dataset import _write_info_list
|
||||
from .distributed_dataset import build_dataset
|
||||
from .distributed_dataset import FileInfo
|
||||
from .distributed_dataset import SimpleDataset
|
||||
from .serializer import RawSerializer
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
|
||||
support_tqdm = True
|
||||
except ModuleNotFoundError:
|
||||
support_tqdm = False
|
||||
|
||||
_DEFAULT_SHUFFLE_BUCKET_SIZE = 1 << 30
|
||||
|
||||
|
||||
def shuffle_dataset(
|
||||
path_src: str,
|
||||
path_tgt: str,
|
||||
block_size: int = _DEFAULT_BLOCK_SIZE,
|
||||
bucket_size: int = _DEFAULT_SHUFFLE_BUCKET_SIZE,
|
||||
progress_bar: bool = False,
|
||||
output_name: Optional[str] = None,
|
||||
):
|
||||
"""Shuffle one distributed datataset, write results to another dataset.
|
||||
|
||||
Args:
|
||||
path_src (str): path to source dataset
|
||||
path_tgt (str): path to write results
|
||||
block_size (int): dataset block size (default: 16MB)
|
||||
bucket_size (int): shuffle algorithm bucket size (default: 1GB)
|
||||
progress_bar (bool): show progress bar
|
||||
|
||||
Example:
|
||||
>>> shuffle_dataset("/path/to/source", "/path/to/output")
|
||||
"""
|
||||
|
||||
if progress_bar and not support_tqdm:
|
||||
raise RuntimeError("Requires `tqdm` to enable progress bar.")
|
||||
|
||||
ds = SimpleDataset(path_src, serializer=RawSerializer())
|
||||
num_buckets = (ds.nbytes + bucket_size - 1) // bucket_size
|
||||
|
||||
tmp_files = [os.path.join(path_src, ".tmp.%s" % _random_string()) for _ in range(num_buckets)]
|
||||
|
||||
try:
|
||||
# Step 1: write to bucket randomly
|
||||
f_tmp = [open(fname, "wb") for fname in tmp_files]
|
||||
try:
|
||||
iterator = ds
|
||||
if progress_bar:
|
||||
iterator = tqdm(ds, desc="Shuffle step 1/2")
|
||||
for data in iterator:
|
||||
bucket_id = int(random.random() * num_buckets)
|
||||
len_data = len(data)
|
||||
f_tmp[bucket_id].write(struct.pack("I", len_data) + data)
|
||||
finally:
|
||||
# close all files
|
||||
for fp in f_tmp:
|
||||
if not fp.closed:
|
||||
fp.close()
|
||||
f_tmp = []
|
||||
|
||||
# Step 2: shuffle inside bucket
|
||||
if output_name is None:
|
||||
output_name = "%s.shuffle" % _random_string()
|
||||
with build_dataset(
|
||||
path_tgt,
|
||||
output_name,
|
||||
block_size=block_size,
|
||||
serializer=RawSerializer(),
|
||||
) as writer:
|
||||
iterator = tmp_files
|
||||
if progress_bar:
|
||||
iterator = tqdm(tmp_files, desc="Shuffle step 2/2")
|
||||
|
||||
for fname in iterator:
|
||||
fp = open(fname, "rb")
|
||||
data_in_bucket = []
|
||||
while True:
|
||||
try:
|
||||
raw_data = fp.read(4)
|
||||
if len(raw_data) == 0:
|
||||
# EOF
|
||||
break
|
||||
len_data = struct.unpack("I", raw_data)[0]
|
||||
data_in_bucket.append(fp.read(len_data))
|
||||
except EOFError:
|
||||
break
|
||||
random.shuffle(data_in_bucket)
|
||||
for data in data_in_bucket:
|
||||
writer.write(data)
|
||||
fp.close()
|
||||
os.unlink(fname)
|
||||
finally:
|
||||
# cleanup
|
||||
for fname in tmp_files:
|
||||
if os.path.exists(fname):
|
||||
os.unlink(fname)
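# Design note (added commentary): the two passes above bound peak memory by
# roughly `bucket_size` instead of the whole dataset. Pass 1 scatters records
# into random temporary buckets on disk; pass 2 shuffles each bucket in memory
# and appends it to the output dataset, which approximates a global shuffle.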
|
||||
|
||||
|
||||
def compact_dataset(path: str):
|
||||
"""Compact the dataset, removes blocks which the files were deleted.
|
||||
|
||||
**Note** This may affect the existing dataset state dict.
|
||||
|
||||
Args:
|
||||
path (str): path to dataset
|
||||
|
||||
Example:
|
||||
>>> compact_dataset("/path/to/dataset")
|
||||
|
||||
"""
|
||||
|
||||
meta_path = os.path.join(path, "meta.bin")
|
||||
|
||||
info: List[FileInfo] = []
|
||||
if os.path.exists(meta_path):
|
||||
info = _read_info_list(meta_path)
|
||||
else:
|
||||
raise ValueError("Dataset not exists")
|
||||
|
||||
nw_info: List[FileInfo] = []
|
||||
curr_block = 0
|
||||
for v in info:
|
||||
if not os.path.exists(v.file_name):
|
||||
# file is deleted
|
||||
pass
|
||||
else:
|
||||
num_file_block = v.block_end - v.block_begin
|
||||
nw_info.append(
|
||||
FileInfo(
|
||||
v.file_name,
|
||||
curr_block,
|
||||
curr_block + num_file_block,
|
||||
v.nbytes,
|
||||
v.nlines,
|
||||
v.mask,
|
||||
v.block_size,
|
||||
)
|
||||
)
|
||||
curr_block += num_file_block
|
||||
|
||||
_write_info_list(meta_path, nw_info)
|
||||
|
||||
|
||||
def mask_dataset(path: str, dbname: str, mask: bool = True):
|
||||
"""Mask one file in dataset. Blocks in masked datasets won't be read later.
|
||||
|
||||
Args:
|
||||
path (str): path to dataset
|
||||
dbname (str): file name in this dataset which you want to mask
|
||||
mask (bool): True for mask, False for unmask
|
||||
|
||||
Example:
|
||||
>>> mask_dataset("/path/to/dataset", "data_part_1", mask=True)
|
||||
|
||||
"""
|
||||
|
||||
meta_path = os.path.join(path, "meta.bin")
|
||||
|
||||
info: List[FileInfo] = []
|
||||
if os.path.exists(meta_path):
|
||||
info = _read_info_list(meta_path)
|
||||
else:
|
||||
raise ValueError("Dataset not exists")
|
||||
|
||||
for v in info:
|
||||
if v.file_name == dbname:
|
||||
v.mask = mask
|
||||
_write_info_list(meta_path, info)
|
||||
|
||||
|
||||
def merge_dataset(dst: str, src: str):
|
||||
meta_path_src = os.path.join(src, "meta.bin")
|
||||
meta_path_dst = os.path.join(dst, "meta.bin")
|
||||
|
||||
info_src: List[FileInfo] = []
|
||||
if os.path.exists(meta_path_src):
|
||||
info_src = _read_info_list(meta_path_src)
|
||||
else:
|
||||
raise ValueError("Dataset not exists")
|
||||
|
||||
info_dst: List[FileInfo] = []
|
||||
if os.path.exists(meta_path_dst):
|
||||
info_dst = _read_info_list(meta_path_dst)
|
||||
else:
|
||||
raise ValueError("Dataset not exists")
|
||||
|
||||
curr_block = 0
|
||||
nw_info: List[FileInfo] = []
|
||||
for v in info_dst:
|
||||
num_file_block = v.block_end - v.block_begin
|
||||
nw_info.append(
|
||||
FileInfo(
|
||||
v.file_name,
|
||||
curr_block,
|
||||
curr_block + num_file_block,
|
||||
v.nbytes,
|
||||
v.nlines,
|
||||
v.mask,
|
||||
v.block_size,
|
||||
)
|
||||
)
|
||||
curr_block += num_file_block
|
||||
|
||||
for v in info_src:
|
||||
num_file_block = v.block_end - v.block_begin
|
||||
|
||||
dst_db_name = os.path.join(dst, v.file_name)
|
||||
nw_fname = v.file_name
|
||||
if os.path.exists(dst_db_name):
|
||||
idx = 0
|
||||
while os.path.exists(dst_db_name + "_{}".format(idx)):
|
||||
idx += 1
|
||||
dst_db_name = dst_db_name + "_{}".format(idx)
|
||||
nw_fname = nw_fname + "_{}".format(idx)
|
||||
|
||||
shutil.copy(os.path.join(src, v.file_name), dst_db_name)
|
||||
nw_info.append(
|
||||
FileInfo(
|
||||
nw_fname,
|
||||
curr_block,
|
||||
curr_block + num_file_block,
|
||||
v.nbytes,
|
||||
v.nlines,
|
||||
v.mask,
|
||||
v.block_size,
|
||||
)
|
||||
)
|
||||
curr_block += num_file_block
|
||||
|
||||
_write_info_list(meta_path_dst, nw_info)
|
||||
|
||||
|
||||
def to_fm9g(src_data, dst_path, dst_name):
|
||||
if not os.path.exists(dst_path):
|
||||
os.makedirs(dst_path)
|
||||
|
||||
logger.info(f"src_data: {src_data}")
|
||||
logger.info(f"dst_path: {dst_path}")
|
||||
logger.info(f"dst_name: {dst_name}")
|
||||
|
||||
tmp_dst_path = dst_path.rstrip("/") + "_tmp"
|
||||
if not os.path.exists(tmp_dst_path):
|
||||
os.makedirs(tmp_dst_path)
|
||||
|
||||
logger.info(f"write binary into: {tmp_dst_path}")
|
||||
with build_dataset(tmp_dst_path, dst_name) as dataset:
|
||||
if os.path.isdir(src_data):
|
||||
filenames = [os.path.join(src_data, name) for name in os.listdir(src_data)]
|
||||
else:
|
||||
filenames = [src_data]
|
||||
|
||||
n_filenames = len(filenames)
|
||||
for idx, filename in enumerate(filenames):
|
||||
logger.info(f"deal: [{n_filenames} -> {idx}] {filename}")
|
||||
if not os.path.exists(filename):
|
||||
logger.error(f"not exist: {filename}")
|
||||
continue
|
||||
|
||||
with open(filename, "r", encoding="utf-8") as fin:
|
||||
for line in fin:
|
||||
line = line.strip()
|
||||
dataset.write(json.loads(line))
|
||||
|
||||
logger.info(f"shuffle binary data from {tmp_dst_path} to {dst_path}")
|
||||
shuffle_dataset(tmp_dst_path, dst_path, progress_bar=True, output_name=dst_name)
|
||||
|
||||
if os.path.exists(tmp_dst_path):
|
||||
shutil.rmtree(tmp_dst_path)
|
||||
|
||||
|
||||
def random_range(start, stop=None, step=None):
|
||||
"""
|
||||
Generator of a non-repeating random permutation with the same interface as Python's
`range`. Obtained from https://stackoverflow.com/a/53551417
random.shuffle(list) and random.sample(list, len(list)) require
materializing the list, which results in a long initialization period.
|
||||
"""
|
||||
if stop is None:
|
||||
start, stop = 0, start
|
||||
if step is None:
|
||||
step = 1
|
||||
# Use a mapping to convert a standard range into the desired range.
|
||||
mapping = lambda i: (i * step) + start
|
||||
# Compute the number of numbers in this range.
|
||||
maximum = int(math.ceil((stop - start) / step))
|
||||
if maximum == 0:
|
||||
# early return with empty range
|
||||
yield from ()
|
||||
return
|
||||
# Seed range with a random integer.
|
||||
value = random.randint(0, maximum)
|
||||
# Construct an offset, multiplier, and modulus for a linear
|
||||
# congruential generator. These generators are cyclic and
|
||||
# non-repeating when they maintain the properties:
|
||||
#
|
||||
# 1) "modulus" and "offset" are relatively prime.
|
||||
# 2) ["multiplier" - 1] is divisible by all prime factors of "modulus".
|
||||
# 3) ["multiplier" - 1] is divisible by 4 if "modulus" is divisible by 4.
|
||||
|
||||
# Pick a random odd-valued offset.
|
||||
offset = random.randint(0, maximum) * 2 + 1
|
||||
# Pick a multiplier 1 greater than a multiple of 4.
|
||||
multiplier = 4 * (maximum // 4) + 1
|
||||
# Pick a modulus just big enough to generate all numbers (power of 2).
|
||||
modulus = int(2 ** math.ceil(math.log2(maximum)))
|
||||
# Track how many random numbers have been returned.
|
||||
found = 0
|
||||
while found < maximum:
|
||||
# If this is a valid value, yield it in generator fashion.
|
||||
if value < maximum:
|
||||
found += 1
|
||||
yield mapping(value)
|
||||
# Calculate the next value in the sequence.
|
||||
value = (value * multiplier + offset) % modulus
|
||||
|
||||
|
||||
class Range(object):
|
||||
def __init__(self, start, stop, step):
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.step = step
|
||||
|
||||
def __repr__(self):
|
||||
return f"Range({self.start}, {self.stop}, {self.step})"
|
||||
|
||||
def iterate(self):
|
||||
yield from range(self.start, self.stop, self.step)
|
||||
|
||||
def list(self):
|
||||
return list(range(self.start, self.stop, self.step))
|
||||
|
||||
def subrange(self, split, nsplits):
|
||||
# strided spliting range params
|
||||
# e.g., [0, 3, 5, 7, 9] can be split into [0, 5, 9] and [3, 7]
|
||||
return Range(self.start + self.step * split, self.stop, self.step * nsplits)
|
||||
|
||||
def random_iterate(self):
|
||||
yield from random_range(self.start, self.stop, self.step)
|
||||
|
||||
|
||||
class CudaPrefetcher(Iterable):
|
||||
"""
|
||||
Wrap around a batch iterator to asynchronously copy data to the GPU, hiding memcpy latency.
|
||||
"""
|
||||
|
||||
def __init__(self, loader):
|
||||
self.loader = iter(loader)
|
||||
self.stream = torch.cuda.Stream()
|
||||
self.preload()
|
||||
|
||||
def preload(self):
|
||||
try:
|
||||
self.data = next(self.loader)
|
||||
except StopIteration:
|
||||
self.data = None
|
||||
return
|
||||
with torch.cuda.stream(self.stream):
|
||||
for key in self.data.keys():
|
||||
if isinstance(self.data[key], torch.Tensor):
|
||||
self.data[key] = self.data[key].cuda(non_blocking=True)
|
||||
|
||||
def __next__(self):
|
||||
torch.cuda.current_stream().wait_stream(self.stream)
|
||||
data = self.data
|
||||
self.preload()
|
||||
return data
|
||||
|
||||
def __iter__(self):
|
||||
return self
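# Typical usage (illustrative only; `train_loader`, `model` and the training
# step are placeholders, not part of this file):
#
#   prefetcher = CudaPrefetcher(train_loader)
#   for batch in prefetcher:      # tensors in `batch` are already on the GPU
#       if batch is None:         # the wrapped loader is exhausted
#           break
#       loss = model(**batch)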
|
||||
|
||||
|
||||
class ThreadedPrefetcher(Thread):
|
||||
def __init__(self, iterable, prefetch=10):
|
||||
"""
|
||||
Wrap around a data iterator to hide I/O latency with a daemon prefetching thread.
|
||||
"""
|
||||
super(ThreadedPrefetcher, self).__init__()
|
||||
self.queue = Queue(maxsize=prefetch)
|
||||
self.iterable = iterable
|
||||
self.daemon = True
|
||||
self.start()
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
for data in self.iterable:
|
||||
self.queue.put(data)
|
||||
except Exception as exception:
|
||||
self.queue.put(exception)
|
||||
finally:
|
||||
self.queue.put(StopIteration())
|
||||
|
||||
def __next__(self):
|
||||
item = self.queue.get()
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
else:
|
||||
return item
|
||||
|
||||
def __iter__(self):
|
||||
return self
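# Typical usage (illustrative; `lines` and `process` are placeholders). The
# worker thread enqueues items and finally a StopIteration sentinel, so a plain
# for-loop over the prefetcher terminates cleanly:
#
#   for item in ThreadedPrefetcher(lines, prefetch=32):
#       process(item)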
|
|
@ -1,8 +0,0 @@
|
|||
{
|
||||
"folders": [
|
||||
{
|
||||
"path": "../.."
|
||||
}
|
||||
],
|
||||
"settings": {}
|
||||
}
|
|
@ -1,105 +0,0 @@
|
|||
import torch
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class DragonflyConfig(PretrainedConfig):
|
||||
model_type = "fm9g_dragonfly"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
attribute_map = {
|
||||
"num_key_value_heads": "num_kv_heads",
|
||||
"hidden_act": "activate_fn",
|
||||
"hidden_size": "dim_model",
|
||||
"num_attention_heads": "num_heads",
|
||||
"intermediate_size": "dim_ff",
|
||||
"num_hidden_layers": "num_layers",
|
||||
"vocab_size": "vocab_size",
|
||||
"rms_norm_eps": "eps",
|
||||
"scale_emb": "scale_emb",
|
||||
"scale_depth": "scale_depth",
|
||||
"scale": "scale",
|
||||
"attention_scale": "attention_scale",
|
||||
"qk_norm": "qk_norm",
|
||||
"ffn_gated": "ffn_gated",
|
||||
} # model specific to common
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=122753, # TODO: do we need to change to 122880 = 960 * 128?
|
||||
dim_model=4096,
|
||||
num_heads=32,
|
||||
num_kv_heads=32,
|
||||
dim_head=128,
|
||||
dim_ff=11008,
|
||||
num_layers=32,
|
||||
dropout_p=0.0,
|
||||
activate_fn="silu",
|
||||
scale=False,
|
||||
scale_emb: float = 1.0,
|
||||
scale_depth: float = -1,
|
||||
dim_model_base: int = 256,
|
||||
eps=1e-5,
|
||||
init_std=0.02,
|
||||
dtype="bf16",
|
||||
base=10000,
|
||||
qk_norm=False,
|
||||
tie_lm_head=False,
|
||||
max_length=8192,
|
||||
pose_prob=0.0,
|
||||
pose_scaling_factor=1,
|
||||
rope_scaling_type="",
|
||||
rope_scaling_factor=1,
|
||||
orig_max_length=8192,
|
||||
tp=0,
|
||||
use_checkpoint=True,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.dim_model = dim_model
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.dim_head = dim_head
|
||||
self.dim_ff = dim_ff
|
||||
self.num_layers = num_layers
|
||||
self.dropout_p = dropout_p
|
||||
self.activate_fn = activate_fn
|
||||
self.scale = scale
|
||||
self.scale_emb = scale_emb
|
||||
self._dtype = dtype
|
||||
self.dim_model_base = dim_model_base
|
||||
self.scale_depth = scale_depth
|
||||
self.eps = eps
|
||||
self.init_std = init_std
|
||||
self.base = base
|
||||
self.qk_norm = qk_norm
|
||||
self.tie_lm_head = tie_lm_head
|
||||
self.use_bfloat16 = True if self._dtype == "bf16" else False
|
||||
self.pose_prob = pose_prob
|
||||
self.pose_scaling_factor = pose_scaling_factor
|
||||
self.rope_scaling_type = rope_scaling_type
|
||||
self.rope_scaling_factor = rope_scaling_factor
|
||||
self.max_length = max_length
|
||||
self.orig_max_length = orig_max_length
|
||||
self.use_checkpoint = use_checkpoint
|
||||
print("use_checkpoint", self.use_checkpoint)
|
||||
self.tp = tp
|
||||
super().__init__(architectures=["fm9gDragonflyForCausalLM"])
|
||||
|
||||
@property
|
||||
def scale_width(
|
||||
self,
|
||||
):
|
||||
if self.scale:
|
||||
return self.dim_model / self.dim_model_base
|
||||
else:
|
||||
return 1.0
|
||||
|
||||
@property
|
||||
def dtype(
|
||||
self,
|
||||
): # -> Any | None:
|
||||
if self._dtype == "bf16":
|
||||
return torch.bfloat16
|
||||
elif self._dtype == "fp16":
|
||||
return torch.half
|
||||
elif self._dtype == "float32":
|
||||
return torch.float
|
File diff suppressed because it is too large
|
@ -1 +0,0 @@
|
|||
from .pretrain_indexed import MixedIndexedDataset
|
|
@ -1,74 +0,0 @@
|
|||
import logging
|
||||
from multiprocessing import Lock
|
||||
|
||||
from flask import Flask
|
||||
from flask import jsonify
|
||||
from flask import request
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
# Get the Werkzeug logger and set its log level
|
||||
log = logging.getLogger("werkzeug")
|
||||
log.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class GlobalAvgTokensStat(object):
|
||||
def __init__(self, decay_factor: float = 0.98):
|
||||
self._avg_tokens = {}
|
||||
self.decay_factor = decay_factor
|
||||
self.lock = Lock()
|
||||
self.task_locks = {}
|
||||
|
||||
def set_avg_tokens(self, task_name, avg_tokens):
|
||||
self._register_task_lock_helper(task_name)
|
||||
with self.task_locks[task_name]:
|
||||
self._avg_tokens[task_name] = avg_tokens
|
||||
|
||||
def update_avg_tokens_by_ema(self, task_name, length):
|
||||
self._register_task_lock_helper(task_name)
|
||||
with self.task_locks[task_name]:
|
||||
if task_name in self._avg_tokens and self._avg_tokens[task_name] > 0:
|
||||
self._avg_tokens[task_name] = self._avg_tokens[task_name] * self.decay_factor + length * (
|
||||
1 - self.decay_factor
|
||||
)
|
||||
else:
|
||||
self._avg_tokens[task_name] = length
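# Worked example (added note): with decay_factor=0.98 and a current average of
# 100 tokens, a new sample of length 200 moves the estimate to
# 100 * 0.98 + 200 * 0.02 = 102, so single long samples nudge the running
# average instead of dominating it.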
|
||||
|
||||
def get_avg_tokens(self, task_name):
|
||||
self._register_task_lock_helper(task_name)
|
||||
with self.task_locks[task_name]:
|
||||
return self._avg_tokens.get(task_name, -1)
|
||||
|
||||
def _register_task_lock_helper(self, task_name):
|
||||
if task_name not in self.task_locks:
|
||||
with self.lock:
|
||||
if task_name not in self.task_locks:
|
||||
self.task_locks[task_name] = Lock()
|
||||
|
||||
|
||||
global_avg_tokens_stat = GlobalAvgTokensStat()
|
||||
|
||||
|
||||
@app.route("/avg_tokens/<path:task_name>", methods=["GET"])
|
||||
def get_avg_tokens(task_name):
|
||||
global global_avg_tokens_stat
|
||||
avg_tokens = global_avg_tokens_stat.get_avg_tokens(task_name)
|
||||
return jsonify({"avg_tokens": avg_tokens})
|
||||
|
||||
|
||||
@app.route("/avg_tokens/<path:task_name>", methods=["POST"])
|
||||
def set_avg_tokens(task_name):
|
||||
global global_avg_tokens_stat
|
||||
action = request.args.get("action", "update", type=str)
|
||||
length = request.args.get("length", -1, type=int)
|
||||
if action == "set":
|
||||
global_avg_tokens_stat.set_avg_tokens(task_name, length)
|
||||
elif action == "update":
|
||||
global_avg_tokens_stat.update_avg_tokens_by_ema(task_name, length)
|
||||
else:
|
||||
raise ValueError(f"Unknown action: {action}")
|
||||
return jsonify({"status": "ok"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(port=5000, debug=True)
|
|
@ -1,828 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
||||
#
|
||||
# @author: ouzebin <ouzebin@zhihu.com>
|
||||
# @date: 2023/09/27
|
||||
|
||||
|
||||
import copy
|
||||
import ctypes
|
||||
import functools
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from collections import OrderedDict
|
||||
from multiprocessing import Lock
|
||||
from multiprocessing import Process
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import Iterable
|
||||
from typing import Iterator
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Set
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import bmtrain as bmt
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from fm9g.dataset import PrefetchDecodeDataset
|
||||
from fm9g.utils.bitset import BitSet
|
||||
from fm9g.utils.vdc_sampling import van_der_corput
|
||||
from fm9g.utils.vdc_sampling import van_der_corput_sampling_gen
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
IGNORE_TGT = -100
|
||||
|
||||
|
||||
def load_dataset_cfgs(cfg_path, cfg_json_str=None):
|
||||
if cfg_json_str is not None:
|
||||
cfgs = json.loads(cfg_json_str)
|
||||
else:
|
||||
with open(cfg_path, "r", encoding="utf-8") as fin:
|
||||
cfgs = json.load(fin)
|
||||
transform_basedir = os.path.dirname(os.path.abspath(cfg_path))
|
||||
|
||||
path_dict = None
|
||||
platform_config_path = os.getenv("PLATFORM_CONFIG_PATH")
|
||||
try:
|
||||
with open(platform_config_path, "r") as f:
|
||||
platform_cfg = json.load(f)
|
||||
path_dict = platform_cfg["dataset_map"]
|
||||
if bmt.rank() == 0:
|
||||
logger.info(f"Loaded jeeves platform config from '{platform_config_path}', update dataset paths...")
|
||||
except Exception as e:
|
||||
if bmt.rank() == 0:
|
||||
logger.info(f"Failing to load jeeves platform config '{platform_config_path}', error message:\n{str(e)}")
|
||||
|
||||
task_name2dataset_name = dict()
|
||||
for idx, cfg in enumerate(cfgs):
|
||||
assert "dataset_name" in cfg and isinstance(cfg["dataset_name"], str)
|
||||
assert "task_name" in cfg and isinstance(cfg["task_name"], str)
|
||||
# to be deliberately annoying :)
|
||||
if cfg["task_name"] in task_name2dataset_name:
|
||||
raise ValueError(
|
||||
f"task_name '{cfg['task_name']}' in dataset '{cfg['dataset_name']}'"
|
||||
f"has already been used in '{task_name2dataset_name[cfg['task_name']]}'."
|
||||
)
|
||||
task_name2dataset_name[cfg["task_name"]] = cfg["dataset_name"]
|
||||
|
||||
assert "path" in cfg and isinstance(cfg["path"], str)
|
||||
# if path_dict is not None:
|
||||
# cfg["path"] = os.path.join(path_dict[cfg["dataset_name"]], cfg["path"])
|
||||
|
||||
# dealing with optional configs
|
||||
if "weight" in cfg:
|
||||
assert isinstance(cfg["weight"], (float, int))
|
||||
else:
|
||||
cfg["weight"] = 1.0
|
||||
|
||||
if "oversize_rule" in cfg:
|
||||
assert cfg["oversize_rule"] in ("drop", "head", "segment")
|
||||
else:
|
||||
cfg["oversize_rule"] = "segment"
|
||||
|
||||
if "transforms" in cfg:
|
||||
assert isinstance(cfg["transforms"], str)
|
||||
# dealing with relative path
|
||||
if not cfg["transforms"].startswith("/"):
|
||||
cfg["transforms"] = os.path.join(transform_basedir, cfg["transforms"])
|
||||
if not cfg["transforms"]:
|
||||
cfg["transforms"] = None
|
||||
else:
|
||||
cfg["transforms"] = None
|
||||
|
||||
if "incontext_weight" in cfg:
|
||||
assert isinstance(cfg["incontext_weight"], (list, tuple))
|
||||
else:
|
||||
cfg["incontext_weight"] = [1.0]
|
||||
cfg["id"] = idx
|
||||
# dataset and iterator will be built
|
||||
return cfgs
|
||||
|
||||
|
||||
def data2ids(data, tokenizer, max_length):
|
||||
text = "\n".join(
|
||||
[
|
||||
data.get("title", "").strip(),
|
||||
data.get("question", "").strip(),
|
||||
data.get("answer", "").strip(),
|
||||
data.get("abstract", "").strip(),
|
||||
data.get("text", "").strip(),
|
||||
data.get("code", "").strip(),
|
||||
]
|
||||
).strip()
|
||||
|
||||
if not text:
|
||||
logger.warning(f"Warning: skip invalid sample without valid fields: {data}")
|
||||
yield from ()
|
||||
return
|
||||
# suppress the annoying warning from tokenizer
|
||||
ids = (
|
||||
[tokenizer.bos_token_id]
|
||||
+ tokenizer.encode(text, max_length=int(1e12), truncation=True)
|
||||
+ [tokenizer.eos_token_id]
|
||||
)
|
||||
src_ids = ids[0:-1]
|
||||
tgt_ids = ids[0:-1] # do not shift because it'll be shifted during loss calculation.
|
||||
|
||||
if len(src_ids) > max_length:
|
||||
for st in range(0, len(src_ids), max_length):
|
||||
yield src_ids[st : st + max_length], tgt_ids[st : st + max_length]
|
||||
else:
|
||||
yield src_ids, tgt_ids
|
||||
|
||||
|
||||
def cricket_data2ids(data, tokenizer, max_length: int, oversize_rule="segment", do_compact=False):
|
||||
assert oversize_rule in ("drop", "head", "segment")
|
||||
if data is None:
|
||||
yield from ()
|
||||
return
|
||||
if "output" not in data or not data["output"]:
|
||||
yield from ()
|
||||
return
|
||||
if "input" not in data or data["input"] is None:
|
||||
data["input"] = ""
|
||||
|
||||
src_ids = [tokenizer.bos_token_id]
|
||||
tgt_ids = []
|
||||
has_input = False
|
||||
is_segment_reenter = False
|
||||
|
||||
# Use incremental tokenization to avoid waiting for a long document
|
||||
MAX_CHUNK_LENGTH = max_length * 10
|
||||
for part in ("input", "output"):
|
||||
l, r = 0, min(MAX_CHUNK_LENGTH, len(data[part]))
|
||||
while l < len(data[part]):
|
||||
try:
|
||||
current_slice = data[part][l:r]
|
||||
if not current_slice:
|
||||
break
|
||||
token_ids = tokenizer.encode(current_slice, add_special_tokens=False)
|
||||
except Exception:
print("Error tokenizing part '{}' of sample: {}".format(part, data))
|
||||
yield from ()
|
||||
return
|
||||
|
||||
if part == "input":
|
||||
if len(token_ids) > 0:
|
||||
has_input = True
|
||||
if len(token_ids) >= max_length - 2: # input len must < max_length
|
||||
yield from ()
|
||||
return
|
||||
src_ids.extend(token_ids)
|
||||
tgt_ids.extend([IGNORE_TGT] * len(token_ids))
|
||||
l = r
|
||||
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
|
||||
else:
|
||||
if len(token_ids) + len(tgt_ids) >= max_length:
|
||||
if oversize_rule == "drop":
|
||||
yield from ()
|
||||
return
|
||||
elif oversize_rule == "head":
|
||||
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||
src_ids.extend(selected_token_ids[:-1])
|
||||
tgt_ids.extend(selected_token_ids)
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||
return
|
||||
elif oversize_rule == "segment":
|
||||
instruction_rest_space = max_length - 1 - len(token_ids)
|
||||
if has_input: # is instruction data
|
||||
if (
|
||||
do_compact
|
||||
and len(src_ids) >= 128 # only compact sufficiently long inputs to avoid losing instruction info
|
||||
and instruction_rest_space / len(src_ids) > 0.8
|
||||
): # can be squeezed into max length
|
||||
inputs_len = len(src_ids)
|
||||
keep_len = instruction_rest_space // 2
|
||||
src_ids = src_ids[:keep_len] + src_ids[inputs_len - keep_len :]
|
||||
tgt_ids = [IGNORE_TGT] * (len(src_ids) - 1)
|
||||
src_ids.extend(token_ids)
|
||||
tgt_ids.extend(token_ids)
|
||||
tgt_ids.append(tokenizer.eos_token_id)
|
||||
assert len(src_ids) < max_length, f"len src_ids: {len(src_ids)}"
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids, tgt_ids
|
||||
else: # else use head rule
|
||||
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||
src_ids.extend(selected_token_ids[:-1])
|
||||
tgt_ids.extend(selected_token_ids)
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||
return
|
||||
else: # normal segment
|
||||
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||
src_ids.extend(selected_token_ids)
|
||||
tgt_ids.extend(selected_token_ids)
|
||||
assert len(src_ids) == max_length + 1, f"len src_ids: {len(src_ids)}"
|
||||
assert len(tgt_ids) == max_length, f"len tgt_ids: {len(tgt_ids)}"
|
||||
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||
src_ids = src_ids[max_length:]
|
||||
tgt_ids = tgt_ids[max_length:]
|
||||
# sliding input str window
|
||||
consumed_str = tokenizer.decode(selected_token_ids)
|
||||
l += len(consumed_str)
|
||||
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
|
||||
is_segment_reenter = True
|
||||
else:
|
||||
if (is_segment_reenter and len(token_ids) > 8) or (
|
||||
not is_segment_reenter and len(token_ids) > 0
|
||||
): # is segmented LM data
|
||||
src_ids.extend(token_ids)
|
||||
tgt_ids.extend(token_ids)
|
||||
tgt_ids.append(tokenizer.eos_token_id)
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids, tgt_ids
|
||||
else:
|
||||
yield from ()
|
||||
return
|
||||
|
||||
|
||||
class SegmentedDataset(torch.utils.data.IterableDataset):
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
tokenizer,
|
||||
max_length=1024,
|
||||
transform_func=None,
|
||||
nthreads=1,
|
||||
prefetch_slice=3,
|
||||
slice_size=500,
|
||||
do_compact=False,
|
||||
):
|
||||
super(SegmentedDataset, self).__init__()
|
||||
self.segment = functools.partial(
|
||||
cricket_data2ids, tokenizer=tokenizer, max_length=max_length, do_compact=do_compact
|
||||
)
|
||||
self.cfg = cfg
|
||||
self.max_length = max_length
|
||||
self.nthreads = nthreads
|
||||
self.transform_func = transform_func
|
||||
self.prefetch_slice = prefetch_slice
|
||||
self.slice_size = slice_size
|
||||
self.abs_weight = cfg.get("abs_weight", None)
|
||||
self.task_name = cfg["task_name"]
|
||||
self.dataset_name = cfg["dataset_name"]
|
||||
self.oversize_rule = cfg["oversize_rule"]
|
||||
self.dataset = PrefetchDecodeDataset(path=cfg["path"], allow_repeat=cfg.get("allow_repeat", True))
|
||||
self.exhausted = False
|
||||
self.iterator = None
|
||||
|
||||
self.counter = 0
|
||||
self.allow_repeat = cfg.get("allow_repeat", True)
|
||||
self.used = BitSet()
|
||||
self.init_ave_tokens()
|
||||
|
||||
def init_ave_tokens(
|
||||
self,
|
||||
):
|
||||
try:
|
||||
shm = SharedMemory(name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
|
||||
except FileNotFoundError:
|
||||
bmt.print_rank(
|
||||
"Create Shared Memory {}".format(f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
|
||||
)
|
||||
shm = SharedMemory(
|
||||
create=True,
|
||||
size=ctypes.sizeof(ctypes.c_float),
|
||||
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}',
|
||||
)
|
||||
|
||||
# use the shared memory
|
||||
shared_value = ctypes.c_float.from_buffer(shm.buf)
|
||||
_ave_tokens = self.cfg.get(
|
||||
"avg_tokens", self.cfg.get("ave_tokens", self.cfg.get("ave_tokens_per_line", -1))
|
||||
)
|
||||
|
||||
if _ave_tokens > self.max_length:
|
||||
_ave_tokens = self.max_length
|
||||
bmt.print_rank(
|
||||
"Warning: avg_tokens {} is larger than max_length {}, set to max_length".format(
|
||||
_ave_tokens, self.max_length
|
||||
)
|
||||
)
|
||||
|
||||
shared_value.value = _ave_tokens
|
||||
# drop the reference once shared_value is no longer needed
|
||||
del shared_value
|
||||
|
||||
# now the shared memory can be closed safely
|
||||
shm.close()
|
||||
bmt.print_rank("Init ave_tokens for task {}: {}".format(self.task_name, self.ave_tokens))
|
||||
|
||||
@property
|
||||
def ave_tokens(
|
||||
self,
|
||||
):
|
||||
existing_shm = SharedMemory(
|
||||
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}'
|
||||
) # -1 # default length
|
||||
shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
|
||||
tmp = shared_value.value
|
||||
del shared_value
|
||||
existing_shm.close()
|
||||
return tmp
|
||||
|
||||
def ave_tokens_update(self, length):
|
||||
existing_shm = SharedMemory(
|
||||
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}'
|
||||
) # -1 # default length
|
||||
shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
|
||||
if shared_value.value < 0:
|
||||
shared_value.value = float(length)
|
||||
else:
|
||||
shared_value.value = 0.98 * shared_value.value + 0.02 * length
|
||||
del shared_value
|
||||
existing_shm.close()
|
||||
|
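# Added note (illustrative, not part of the original file): ave_tokens_update maintains
# an exponential moving average with decay 0.98, roughly tracking the last ~50 sample
# lengths. A minimal sketch of the update rule:
#
#   ema = -1.0
#   for length in (100, 200, 300):
#       ema = float(length) if ema < 0 else 0.98 * ema + 0.02 * length
#   # ema == 0.98 * (0.98 * 100 + 0.02 * 200) + 0.02 * 300 == 105.96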
||||
def size(self):
|
||||
return self.dataset.size()
|
||||
|
||||
def __iter__(self):
|
||||
self.iterate()
|
||||
return self
|
||||
|
||||
def reset(self):
|
||||
self.exhausted = False
|
||||
if self.iterator is not None:
|
||||
self.iterator.close()
|
||||
self.iterator = None
|
||||
self.used = BitSet()
|
||||
print("Rank {}, Reset dataset:{} done.".format(bmt.rank(), self.dataset_name))
|
||||
|
||||
def transform(self, data: dict) -> dict:
|
||||
weight = np.array(self.cfg["incontext_weight"], dtype=np.float32)
|
||||
weight = weight / weight.sum()
|
||||
num_incontext = np.random.choice(weight.shape[0], p=weight)
|
||||
return self.transform_func(data, num_incontext, random.Random())
|
||||
|
||||
def segment_iterate(self, sample_iter):
|
||||
for index, data in self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used):
|
||||
for src_ids, tgt_ids in self.segment(self.transform(data)):
|
||||
self.ave_tokens_update(len(src_ids)) # 0 for input ids
|
||||
yield src_ids, tgt_ids, index
|
||||
|
||||
def iterate(self):
|
||||
# make the dataset itself an iterator
|
||||
sample_iter = self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used)
|
||||
self.iterator = self.segment_iterate(sample_iter)
|
||||
|
||||
def __next__(self):
|
||||
# advance the task iterator
|
||||
if self.iterator is None:
|
||||
self.iterate()
|
||||
try:
|
||||
return next(self.iterator)
|
||||
except StopIteration:
|
||||
self.exhausted = True
|
||||
return None
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
if state_dict.get("exhausted", False):
|
||||
self.exhausted = True
|
||||
self.used = BitSet()
|
||||
else:
|
||||
used = state_dict.get("used", BitSet())
|
||||
if len(used) == len(self.dataset):
|
||||
self.exhausted = True
|
||||
self.used = BitSet()
|
||||
else:
|
||||
self.exhausted = False
|
||||
self.used = used
|
||||
self.ave_tokens_update(state_dict.get("ave_tokens", -1))
|
||||
|
||||
def state_dict(self):
|
||||
if len(self.used) == len(self.dataset):
|
||||
return dict(exhausted=True, used=BitSet(), ave_tokens=self.ave_tokens)
|
||||
else:
|
||||
return dict(exhausted=False, used=self.used, ave_tokens=self.ave_tokens)
|
||||
|
||||
def update_state(self, indice):
|
||||
self.used.update(indice)
|
||||
|
||||
|
||||
class MixedIndexedDataset(torch.utils.data.IterableDataset):
|
||||
def __init__(
|
||||
self,
|
||||
cfg_path: str,
|
||||
cfg_json_str,
|
||||
tokenizer,
|
||||
max_length,
|
||||
weight_by_size: bool = True,
|
||||
nthreads=5,
|
||||
prefetch_slice=100,
|
||||
parallel_loading=False,
|
||||
vdc_sampling=False,
|
||||
update_weights_frequency=1,
|
||||
seed=42,
|
||||
):
|
||||
super(MixedIndexedDataset, self).__init__()
|
||||
self.set_seed(seed + bmt.rank())
|
||||
self.weight_by_size = weight_by_size
|
||||
self.tokenizer = tokenizer
|
||||
self.eos_token_id = self.tokenizer.eos_token_id
|
||||
self.bos_token_id = self.tokenizer.bos_token_id
|
||||
self.path2transform = dict()
|
||||
self.task_dict = OrderedDict()
|
||||
self.nthreads = nthreads
|
||||
self.prefetch_slice = prefetch_slice
|
||||
# useful for indexing
|
||||
self.tasks = []
|
||||
self.names = []
|
||||
# ending of iteration
|
||||
self.remain = 0
|
||||
self.max_length = max_length
|
||||
self.vdc_sampling = vdc_sampling
|
||||
if self.vdc_sampling:
|
||||
self._vdc_values = [van_der_corput(i) for i in range(10**6)] # precision up to 10^{-6}
|
||||
self.vdc_gen = van_der_corput_sampling_gen(self._vdc_values)
|
||||
|
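# Added note (not part of the original file): the van der Corput sequence is a
# low-discrepancy sequence in [0, 1); drawing task indices from it keeps the realized
# task mixture close to self.weights even over short horizons, which plain i.i.d.
# sampling does not guarantee.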
||||
self.update_weights_frequency = update_weights_frequency
|
||||
|
||||
self.path2transform = dict()
|
||||
|
||||
cfgs = load_dataset_cfgs(cfg_path, cfg_json_str)
|
||||
_sum_weight = sum([cfg["abs_weight"] for cfg in cfgs])
|
||||
_weights = {cfg["task_name"]: cfg["abs_weight"] / _sum_weight for cfg in cfgs}
|
||||
bmt.print_rank("Absolute Weight of DataSet {}".format(_weights))
|
||||
|
||||
if parallel_loading:
|
||||
self.parallel_load(cfgs, max_workers=None)
|
||||
else:
|
||||
self.sequential_load(cfgs)
|
||||
|
||||
self.weights = None
|
||||
self.update_weights()
|
||||
|
||||
def set_seed(self, seed):
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
def load_task(self, cfg):
|
||||
logger.info(f"Loading {cfg['path']}")
|
||||
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
|
||||
task = SegmentedDataset(
|
||||
cfg,
|
||||
self.tokenizer,
|
||||
self.max_length,
|
||||
transform_func=transform_func,
|
||||
nthreads=self.nthreads,
|
||||
prefetch_slice=self.prefetch_slice,
|
||||
do_compact=cfg.get("do_compact", False), # dataset level do_compact
|
||||
)
|
||||
return task
|
||||
|
||||
def sequential_load(self, cfgs):
|
||||
self.cfgs = cfgs
|
||||
for cfg in cfgs:
|
||||
# Python 3.7 and later preserve dictionary insertion order
|
||||
logger.info(f"Loading {cfg['path']}")
|
||||
|
||||
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
|
||||
task = SegmentedDataset(
|
||||
cfg,
|
||||
self.tokenizer,
|
||||
self.max_length,
|
||||
transform_func=transform_func,
|
||||
nthreads=self.nthreads,
|
||||
prefetch_slice=self.prefetch_slice,
|
||||
do_compact=cfg.get("do_compact", False), # dataset level do_compact
|
||||
)
|
||||
self.task_dict[task.task_name] = task
|
||||
self.tasks.append(task)
|
||||
self.names.append(task.task_name)
|
||||
self.remain += 1
|
||||
self.weights = None
|
||||
self.update_weights()
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
missing_keys = []
|
||||
for name, task in self.task_dict.items():
|
||||
if name in state_dict:
|
||||
task.load_state_dict(state_dict[name])
|
||||
else:
|
||||
missing_keys.append(name)
|
||||
self.update_weights()
|
||||
return missing_keys
|
||||
|
||||
def save_state_dict(self, path):
|
||||
state_dict = {}
|
||||
for name, task in self.task_dict.items():
|
||||
_state_dict = task.state_dict()
|
||||
if isinstance(_state_dict["used"], BitSet):
|
||||
bitset = _state_dict["used"]
|
||||
_file_name = bitset.save(path)
|
||||
_state_dict["used"] = _file_name
|
||||
state_dict[name] = _state_dict
|
||||
else:
|
||||
state_dict[name] = task.state_dict()
|
||||
torch.save(state_dict, path)
|
||||
logger.info("Dataset state saved")
|
||||
|
||||
def update_states(self, task_ids, indice):
|
||||
is_dict = isinstance(indice, dict)
|
||||
uniq = torch.unique(task_ids)
|
||||
for idx in uniq:
|
||||
idx = idx.item()
|
||||
indexes = indice[idx] if is_dict else indice[task_ids == idx].tolist()
|
||||
self.tasks[idx].update_state(indexes)
|
||||
|
||||
def get_transform_func(self, module_name: str, transform_script_path):
|
||||
if transform_script_path is None:
|
||||
# allow null transform
|
||||
return lambda data, num_incontext, rand: data
|
||||
module_name = "fm9g_live.transforms.{}".format(module_name)
|
||||
if transform_script_path not in self.path2transform:
|
||||
loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
|
||||
spec = importlib.util.spec_from_loader(loader.name, loader)
|
||||
if spec is None:
|
||||
raise RuntimeError("Spec is none! {}".format(module_name))
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
self.path2transform[transform_script_path] = {
|
||||
"loader": loader,
|
||||
"module": mod,
|
||||
"last_mtime": 0,
|
||||
}
|
||||
transform_script_info = self.path2transform[transform_script_path]
|
||||
curr_mtime = float(transform_script_info["loader"].path_stats(transform_script_path)["mtime"])
|
||||
if curr_mtime > transform_script_info["last_mtime"]:
|
||||
transform_script_info["last_mtime"] = curr_mtime
|
||||
transform_script_info["loader"].exec_module(transform_script_info["module"])
|
||||
transform_func = getattr(transform_script_info["module"], "transform", None)
|
||||
if transform_func is None:
|
||||
raise NotImplementedError("Find no transform funcion in script '{}'".format(transform_script_path))
|
||||
return transform_func
|
||||
|
||||
def update_weights(self):
|
||||
task0 = self.tasks[0]
|
||||
if task0.abs_weight is not None: # this config specifies absolute sampling ratios
|
||||
weights = []
|
||||
for task in self.tasks:
|
||||
if task.exhausted:
|
||||
weights.append(0)
|
||||
else:
|
||||
if task.ave_tokens == -1:
|
||||
weights.append(task.abs_weight / self.max_length)
|
||||
else:
|
||||
weights.append(task.abs_weight / task.ave_tokens)
|
||||
weights = np.array(weights)
|
||||
else:
|
||||
weights = np.array([0 if task.exhausted else task.weight for task in self.tasks])
|
||||
if self.weight_by_size:
|
||||
sizes = np.array([task.size() for task in self.tasks], dtype=np.float32)
|
||||
weights *= sizes
|
||||
self.weights = weights / weights.sum()
|
||||
|
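# Worked example (illustrative, not part of the original file): when abs_weight is
# given, each task's sampling probability is abs_weight divided by its average sample
# length, so the configured ratios are matched at the token level rather than the
# sample level. With two hypothetical tasks:
#
#   abs_weights = [0.2, 0.1]
#   ave_tokens = [100.0, 400.0]
#   raw = [w / t for w, t in zip(abs_weights, ave_tokens)]  # [0.002, 0.00025]
#   probs = [x / sum(raw) for x in raw]  # ~[0.889, 0.111]
#   # task A is drawn ~8x more often, but delivers ~2x the tokens of task B (0.2 : 0.1).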
||||
def __iter__(self):
|
||||
for task in self.tasks:
|
||||
task.iterate()
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
step = 1
|
||||
while True:
|
||||
if self.remain == 0:
|
||||
print("Rank {}, All task exhaust !!!!".format(bmt.rank()))
|
||||
raise StopIteration
|
||||
if self.vdc_sampling:
|
||||
idx = next(self.vdc_gen)(self.weights)
|
||||
else:
|
||||
idx = np.random.choice(len(self.weights), p=self.weights)
|
||||
|
||||
data = next(self.tasks[idx])
|
||||
|
||||
if data is None:
|
||||
if self.tasks[idx].allow_repeat:
|
||||
# _runtime_ave = self.tasks[idx].ave_tokens
|
||||
print("Rank {}, dataset {} exhaust, repeat...".format(bmt.rank(), self.tasks[idx].dataset_name))
|
||||
# self.tasks[idx] = SegmentedDataset(
|
||||
# self.tasks[idx].cfg, self.tokenizer, self.max_length, transform_func=self.tasks[idx].transform_func, nthreads=self.nthreads, prefetch_slice=self.prefetch_slice
|
||||
# )
|
||||
# self.tasks[idx].ave_tokens_update(_runtime_ave)
|
||||
self.tasks[idx].reset()
|
||||
else:
|
||||
print("Rank {}, dataset {} exhaust, not repeat.".format(bmt.rank(), self.tasks[idx].dataset_name))
|
||||
self.tasks[idx].exhausted = True
|
||||
self.remain -= 1
|
||||
continue
|
||||
|
||||
if step % self.update_weights_frequency == 0:
|
||||
self.update_weights()
|
||||
step += 1
|
||||
return dict(
|
||||
task_id=idx,
|
||||
input=data[0],
|
||||
target=data[1],
|
||||
index=data[2],
|
||||
is_long=self.tasks[idx].cfg.get("is_long", False),
|
||||
)
|
||||
|
||||
|
||||
class UnpadBatchedMixedDataset(torch.utils.data.IterableDataset):
|
||||
def __init__(self, mixed_dataset, batch_size, max_length, pose_prob=0.0, pose_scaling_factor=1.0, compact=False):
|
||||
self.max_total_length = batch_size * max_length
|
||||
self.batch_size = 1
|
||||
# setting compact=True concatenates segments originating from the same input
|
||||
# into a long sequence. The relative order of segments must be preserved
|
||||
# in mixed_dataset, e.g.,
|
||||
# - ok: task1_seg1, task2_seg1, task1_seg2, task1_seg3
|
||||
# - not_ok: task1_seg1, task1_seg3, task2_seg1, task1_seg2
|
||||
self.compact = compact
|
||||
|
||||
self.total_length = 0
|
||||
self.task2seqs = defaultdict(list)
|
||||
self.mixed_dataset = mixed_dataset
|
||||
self._max_length = max_length
|
||||
self._pose_prob = pose_prob
|
||||
self._pose_scaling_factor = pose_scaling_factor
|
||||
if self._pose_prob > 0.0:
|
||||
self._scaled_max_length = int(self.max_total_length * self._pose_scaling_factor)
|
||||
else:
|
||||
self._scaled_max_length = max_length
|
||||
|
||||
def put(self, sample):
|
||||
self.total_length += len(sample["target"])
|
||||
task_id = sample["task_id"]
|
||||
if self.compact and self.task2seqs[task_id]:
|
||||
last = self.task2seqs[task_id][-1]
|
||||
if last["target"][-1] != self.mixed_dataset.eos_token_id:
|
||||
# concatenate sequential segments for longer-context modeling
|
||||
last["input"].extend(sample["input"])
|
||||
last["target"].extend(sample["target"])
|
||||
return
|
||||
self.task2seqs[task_id].append(sample)
|
||||
|
||||
def _pose_preprocess(
|
||||
self,
|
||||
input_ids: NDArray[np.int32],
|
||||
) -> NDArray[np.int32]:
|
||||
"""[PoSE](https://arxiv.org/abs/2309.10400v2)
|
||||
GitHub implementation: https://github.com/dwzhu-pku/PoSE/blob/master/src/train_pose.py#L156
|
||||
"""
|
||||
len_chunk = min(len(input_ids), self._max_length)
|
||||
len_input = len(input_ids)
|
||||
# Chunk input randomly to fit max_length if needed
|
||||
lt1 = 0
|
||||
rt1 = random.randint(0, (len_chunk + 1) // 2) # First chunk contains at most half of the tokens
|
||||
rt2 = random.randint(lt1 + len_chunk, len_input) # Second chunk can randomly shift when not filled max_length
|
||||
lt2 = rt2 - (len_chunk - (rt1 - lt1)) # ensure all tokens are used
|
||||
chunked_input_ids = np.concatenate([input_ids[lt1:rt1], input_ids[lt2:rt2]], axis=-1)
|
||||
# Generate PoSE position ids
|
||||
position_ids = np.arange(len(chunked_input_ids), dtype=np.int32)
|
||||
len_position_ids = len(position_ids)
|
||||
lt = 0
|
||||
rt = random.randint(lt, self._scaled_max_length - len_position_ids)
|
||||
position_ids[: rt1 - lt1] += lt
|
||||
position_ids[rt1 - lt1 :] += rt
|
||||
return position_ids
|
||||
|
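# Worked example (illustrative, not part of the original file) of the PoSE chunking
# above, assuming max_length=8, a 12-token input, and random draws rt1=3, rt2=10, rt=6:
#
#   chunked_input_ids = input_ids[0:3] + input_ids[5:10]  # 8 tokens kept
#   position_ids = [0, 1, 2, 3 + 6, 4 + 6, 5 + 6, 6 + 6, 7 + 6]
#   # -> [0, 1, 2, 9, 10, 11, 12, 13]
#
# i.e. the 8 retained tokens are assigned positions spanning a longer context window,
# which is what lets PoSE train for lengths beyond the physical chunk size.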
||||
def pop(self):
|
||||
indexes = defaultdict(list)
|
||||
lengths = []
|
||||
|
||||
inputs = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
|
||||
targets = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=IGNORE_TGT)
|
||||
task_ids = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=-1)
|
||||
position_ids = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
|
||||
|
||||
span_begin = 0
|
||||
for samples in self.task2seqs.values():
|
||||
while samples:
|
||||
sample = samples.pop()
|
||||
span_end = span_begin + len(sample["input"])
|
||||
inputs[0, span_begin:span_end] = torch.tensor(sample["input"], dtype=torch.int32)
|
||||
targets[0, span_begin:span_end] = torch.tensor(sample["target"], dtype=torch.int32)
|
||||
task_ids[0, span_begin:span_end] = torch.tensor(sample["task_id"], dtype=torch.int32)
|
||||
if not sample["is_long"] and self._pose_prob > 0.0 and random.uniform(0, 1) < self._pose_prob:
|
||||
_span_position_ids = self._pose_preprocess(sample["input"])
|
||||
else:
|
||||
_span_position_ids = np.arange(len(sample["input"]), dtype=np.int32)
|
||||
position_ids[0, span_begin:span_end] = torch.from_numpy(_span_position_ids)
|
||||
# position_ids[0, span_begin:span_end] = torch.arange(len(sample["input"]), dtype=torch.int32)
|
||||
lengths.append(len(sample["target"]))
|
||||
indexes[int(sample["task_id"])].append(sample["index"])
|
||||
self.total_length -= len(sample["target"])
|
||||
span_begin = span_end
|
||||
|
||||
cu_seqlens = torch.cat(
|
||||
[torch.tensor([0] + lengths).cumsum(dim=-1), torch.tensor([self.max_total_length], dtype=torch.int32)],
|
||||
dim=0,
|
||||
).int()
|
||||
batch = {
|
||||
"inputs": inputs,
|
||||
"targets": targets,
|
||||
"task_ids": task_ids,
|
||||
"indexes": indexes,
|
||||
# adhere to flash attention interface
|
||||
"cu_seqlens": cu_seqlens,
|
||||
"max_seqlen": int(torch.max(cu_seqlens[1:] - cu_seqlens[:-1])),
|
||||
"lengths": torch.tensor(sum(lengths)).int(),
|
||||
"task_names": self.mixed_dataset.names,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
return batch
|
||||
|
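# Added note (illustrative, not part of the original file): cu_seqlens follows the
# flash-attention varlen convention of cumulative sequence lengths over the packed
# batch. For three packed samples of lengths [5, 3, 4] and max_total_length=16:
#
#   cu_seqlens = [0, 5, 8, 12, 16]  # the final entry covers the padded remainder
#   max_seqlen = 5
#
# so attention is computed independently within each [cu_seqlens[i], cu_seqlens[i+1]) span.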
||||
def will_be_full(self, sample):
|
||||
return self.total_length + len(sample["target"]) > self.max_total_length
|
||||
|
||||
def __iter__(self):
|
||||
for sample in self.mixed_dataset:
|
||||
if self.will_be_full(sample):
|
||||
yield self.pop()
|
||||
self.put(sample)
|
||||
|
||||
|
||||
class CudaPrefetcher(Iterable):
|
||||
"""
|
||||
Wraps a batch iterator and asynchronously copies data to the GPU to hide memcpy latency.
|
||||
"""
|
||||
|
||||
def __init__(self, loader, tp_size=1, tp_rank=0):
|
||||
self.loader = iter(loader)
|
||||
self.tp_size = tp_size
|
||||
self.tp_rank = tp_rank
|
||||
self.stream = torch.cuda.Stream()
|
||||
self.preload()
|
||||
|
||||
def preload(self):
|
||||
try:
|
||||
if self.tp_size > 1:
|
||||
if self.tp_rank == 0:
|
||||
data = next(self.loader)
|
||||
print("Rank {}, Preload data done.".format(bmt.rank()))
|
||||
d = {}
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "wb") as fb:
|
||||
for key in data.keys():
|
||||
if isinstance(data[key], torch.Tensor):
|
||||
np_cur_data = data[key].cpu().numpy()
|
||||
bs = np_cur_data.tobytes()
|
||||
fb.write(bs)
|
||||
d[key] = ["TORCH", str(np_cur_data.dtype), len(bs)] + list(np_cur_data.shape)
|
||||
elif isinstance(data[key], np.ndarray):
|
||||
bs = data[key].tobytes()
|
||||
fb.write(bs)
|
||||
d[key] = ["NUMPY", str(data[key].dtype), len(bs)] + list(data[key].shape)
|
||||
else:
|
||||
d[key] = data[key]
|
||||
try:
|
||||
_ = json.dumps(d)
|
||||
except TypeError:
|
||||
print(d)
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "w") as f:
|
||||
json.dump(d, f)
|
||||
bmt.synchronize()
|
||||
if self.tp_rank != 0:
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "r") as f:
|
||||
data = json.load(f)
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "rb") as fb:
|
||||
bs = fb.read()
|
||||
offset = 0
|
||||
for key in data.keys():
|
||||
if isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "NUMPY":
|
||||
nw_offset = offset + data[key][2]
|
||||
data[key] = np.frombuffer(bs[offset:nw_offset], dtype=data[key][1]).reshape(
|
||||
data[key][3:]
|
||||
)
|
||||
offset = nw_offset
|
||||
elif isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "TORCH":
|
||||
nw_offset = offset + data[key][2]
|
||||
data[key] = torch.from_numpy(
|
||||
np.frombuffer(bs[offset:nw_offset], dtype=data[key][1])
|
||||
.reshape(data[key][3:])
|
||||
.copy()
|
||||
)
|
||||
offset = nw_offset
|
||||
self.data = data
|
||||
else:
|
||||
self.data = next(self.loader)
|
||||
except StopIteration:
|
||||
self.data = None
|
||||
return
|
||||
with torch.cuda.stream(self.stream):
|
||||
for key in self.data.keys():
|
||||
if isinstance(self.data[key], torch.Tensor):
|
||||
self.data[key] = self.data[key].cuda(non_blocking=True)
|
||||
|
||||
def __next__(self):
|
||||
torch.cuda.current_stream().wait_stream(self.stream)
|
||||
for key in self.data.keys():
|
||||
if isinstance(self.data[key], torch.Tensor):
|
||||
self.data[key].record_stream(torch.cuda.current_stream())
|
||||
data = copy.deepcopy(self.data)
|
||||
self.preload()
|
||||
return data
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
|
@ -1,829 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
||||
#
|
||||
# @author: ouzebin <ouzebin@zhihu.com>
|
||||
# @date: 2023/09/27
|
||||
|
||||
|
||||
import copy
|
||||
import ctypes
|
||||
import functools
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from collections import OrderedDict
|
||||
from multiprocessing import Lock
|
||||
from multiprocessing import Process
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import Iterable
|
||||
from typing import Iterator
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Set
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import bmtrain as bmt
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from fm9g.dataset import PrefetchDecodeDataset
|
||||
from fm9g.utils.bitset import BitSet
|
||||
from fm9g.utils.vdc_sampling import van_der_corput
|
||||
from fm9g.utils.vdc_sampling import van_der_corput_sampling_gen
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
IGNORE_TGT = -100
|
||||
|
||||
|
||||
def load_dataset_cfgs(cfg_path, cfg_json_str=None):
|
||||
if cfg_json_str is not None:
|
||||
cfgs = json.loads(cfg_json_str)
|
||||
else:
|
||||
with open(cfg_path, "r", encoding="utf-8") as fin:
|
||||
cfgs = json.load(fin)
|
||||
transform_basedir = os.path.dirname(os.path.abspath(cfg_path))
|
||||
|
||||
path_dict = None
|
||||
platform_config_path = os.getenv("PLATFORM_CONFIG_PATH")
|
||||
try:
|
||||
with open(platform_config_path, "r") as f:
|
||||
platform_cfg = json.load(f)
|
||||
path_dict = platform_cfg["dataset_map"]
|
||||
if bmt.rank() == 0:
|
||||
logger.info(f"Loaded jeeves platform config from '{platform_config_path}', update dataset paths...")
|
||||
except Exception as e:
|
||||
if bmt.rank() == 0:
|
||||
logger.info(f"Failing to load jeeves platform config '{platform_config_path}', error message:\n{str(e)}")
|
||||
|
||||
task_name2dataset_name = dict()
|
||||
for idx, cfg in enumerate(cfgs):
|
||||
assert "dataset_name" in cfg and isinstance(cfg["dataset_name"], str)
|
||||
assert "task_name" in cfg and isinstance(cfg["task_name"], str)
|
||||
# to be deliberately annoying :)
|
||||
if cfg["task_name"] in task_name2dataset_name:
|
||||
raise ValueError(
|
||||
f"task_name '{cfg['task_name']}' in dataset '{cfg['dataset_name']}'"
|
||||
f"has already been used in '{task_name2dataset_name[cfg['task_name']]}'."
|
||||
)
|
||||
task_name2dataset_name[cfg["task_name"]] = cfg["dataset_name"]
|
||||
|
||||
assert "path" in cfg and isinstance(cfg["path"], str)
|
||||
# if path_dict is not None:
|
||||
# cfg["path"] = os.path.join(path_dict[cfg["dataset_name"]], cfg["path"])
|
||||
|
||||
# dealing with optional configs
|
||||
if "weight" in cfg:
|
||||
assert isinstance(cfg["weight"], (float, int))
|
||||
else:
|
||||
cfg["weight"] = 1.0
|
||||
|
||||
if "oversize_rule" in cfg:
|
||||
assert cfg["oversize_rule"] in ("drop", "head", "segment")
|
||||
else:
|
||||
cfg["oversize_rule"] = "segment"
|
||||
|
||||
if "transforms" in cfg:
|
||||
assert isinstance(cfg["transforms"], str)
|
||||
# dealing with relative path
|
||||
if not cfg["transforms"].startswith("/"):
|
||||
cfg["transforms"] = os.path.join(transform_basedir, cfg["transforms"])
|
||||
if not cfg["transforms"]:
|
||||
cfg["transforms"] = None
|
||||
else:
|
||||
cfg["transforms"] = None
|
||||
|
||||
if "incontext_weight" in cfg:
|
||||
assert isinstance(cfg["incontext_weight"], (list, tuple))
|
||||
else:
|
||||
cfg["incontext_weight"] = [1.0]
|
||||
cfg["id"] = idx
|
||||
# dataset and iterator will be built
|
||||
return cfgs
|
||||
|
||||
|
||||
def data2ids(data, tokenizer, max_length):
|
||||
text = "\n".join(
|
||||
[
|
||||
data.get("title", "").strip(),
|
||||
data.get("question", "").strip(),
|
||||
data.get("answer", "").strip(),
|
||||
data.get("abstract", "").strip(),
|
||||
data.get("text", "").strip(),
|
||||
data.get("code", "").strip(),
|
||||
]
|
||||
).strip()
|
||||
|
||||
if not text:
|
||||
logger.warning(f"Warning: skip invalid sample without valid fields: {data}")
|
||||
yield from ()
|
||||
return
|
||||
# suppress the annoying warning from tokenizer
|
||||
ids = (
|
||||
[tokenizer.bos_id]
|
||||
+ tokenizer.encode(text, max_length=int(1e12), truncation=True)
|
||||
+ [tokenizer.eos_id]
|
||||
)
|
||||
src_ids = ids[0:-1]
|
||||
tgt_ids = ids[0:-1] # do not shift because it'll be shifted during loss calculation.
|
||||
|
||||
if len(src_ids) > max_length:
|
||||
for st in range(0, len(src_ids), max_length):
|
||||
yield src_ids[st : st + max_length], tgt_ids[st : st + max_length]
|
||||
else:
|
||||
yield src_ids, tgt_ids
|
||||
|
||||
|
||||
def cricket_data2ids(data, tokenizer, max_length: int, oversize_rule="segment", do_compact=False):
|
||||
assert oversize_rule in ("drop", "head", "segment")
|
||||
if data is None:
|
||||
yield from ()
|
||||
return
|
||||
if "output" not in data or not data["output"]:
|
||||
yield from ()
|
||||
return
|
||||
if "input" not in data or data["input"] is None:
|
||||
data["input"] = ""
|
||||
|
||||
src_ids = [tokenizer.bos_id]
|
||||
tgt_ids = []
|
||||
has_input = False
|
||||
is_segment_reenter = False
|
||||
|
||||
# Use incremental tokenization to avoid waiting for a long document
|
||||
MAX_CHUNK_LENGTH = max_length * 10
|
||||
for part in ("input", "output"):
|
||||
l, r = 0, min(MAX_CHUNK_LENGTH, len(data[part]))
|
||||
while l < len(data[part]):
|
||||
try:
|
||||
current_slice = data[part][l:r]
|
||||
if not current_slice:
|
||||
break
|
||||
#token_ids = tokenizer.encode(current_slice, add_special_tokens=False)
|
||||
token_ids = tokenizer.encode(current_slice)
|
||||
except Exception:
|
||||
#print("Error in data[part][l:r] {}".format(data))
|
||||
yield from ()
|
||||
return
|
||||
|
||||
if part == "input":
|
||||
if len(token_ids) > 0:
|
||||
has_input = True
|
||||
if len(token_ids) >= max_length - 2: # input len must < max_length
|
||||
yield from ()
|
||||
return
|
||||
src_ids.extend(token_ids)
|
||||
tgt_ids.extend([IGNORE_TGT] * len(token_ids))
|
||||
l = r
|
||||
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
|
||||
else:
|
||||
if len(token_ids) + len(tgt_ids) >= max_length:
|
||||
if oversize_rule == "drop":
|
||||
yield from ()
|
||||
return
|
||||
elif oversize_rule == "head":
|
||||
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||
src_ids.extend(selected_token_ids[:-1])
|
||||
tgt_ids.extend(selected_token_ids)
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||
return
|
||||
elif oversize_rule == "segment":
|
||||
instruction_rest_space = max_length - 1 - len(token_ids)
|
||||
if has_input: # is instruction data
|
||||
if (
|
||||
do_compact
|
||||
and len(src_ids) >= 128 # only compact sufficiently long inputs to avoid losing instruction info
|
||||
and instruction_rest_space / len(src_ids) > 0.8
|
||||
): # can be squeezed into max length
|
||||
inputs_len = len(src_ids)
|
||||
keep_len = instruction_rest_space // 2
|
||||
src_ids = src_ids[:keep_len] + src_ids[inputs_len - keep_len :]
|
||||
tgt_ids = [IGNORE_TGT] * (len(src_ids) - 1)
|
||||
src_ids.extend(token_ids)
|
||||
tgt_ids.extend(token_ids)
|
||||
tgt_ids.append(tokenizer.eos_id)
|
||||
assert len(src_ids) < max_length, f"len src_ids: {len(src_ids)}"
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids, tgt_ids
|
||||
else: # else use head rule
|
||||
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||
src_ids.extend(selected_token_ids[:-1])
|
||||
tgt_ids.extend(selected_token_ids)
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||
return
|
||||
else: # normal segment
|
||||
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||
src_ids.extend(selected_token_ids)
|
||||
tgt_ids.extend(selected_token_ids)
|
||||
assert len(src_ids) == max_length + 1, f"len src_ids: {len(src_ids)}"
|
||||
assert len(tgt_ids) == max_length, f"len tgt_ids: {len(tgt_ids)}"
|
||||
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||
src_ids = src_ids[max_length:]
|
||||
tgt_ids = tgt_ids[max_length:]
|
||||
# sliding input str window
|
||||
consumed_str = tokenizer.decode(selected_token_ids)
|
||||
l += len(consumed_str)
|
||||
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
|
||||
is_segment_reenter = True
|
||||
else:
|
||||
if (is_segment_reenter and len(token_ids) > 8) or (
|
||||
not is_segment_reenter and len(token_ids) > 0
|
||||
): # is segmented LM data
|
||||
src_ids.extend(token_ids)
|
||||
tgt_ids.extend(token_ids)
|
||||
tgt_ids.append(tokenizer.eos_id)
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids, tgt_ids
|
||||
else:
|
||||
yield from ()
|
||||
return
|
||||
|
||||
|
||||
class SegmentedDataset(torch.utils.data.IterableDataset):
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
tokenizer,
|
||||
max_length=1024,
|
||||
transform_func=None,
|
||||
nthreads=1,
|
||||
prefetch_slice=3,
|
||||
slice_size=500,
|
||||
do_compact=False,
|
||||
):
|
||||
super(SegmentedDataset, self).__init__()
|
||||
self.segment = functools.partial(
|
||||
cricket_data2ids, tokenizer=tokenizer, max_length=max_length, do_compact=do_compact
|
||||
)
|
||||
self.cfg = cfg
|
||||
self.max_length = max_length
|
||||
self.nthreads = nthreads
|
||||
self.transform_func = transform_func
|
||||
self.prefetch_slice = prefetch_slice
|
||||
self.slice_size = slice_size
|
||||
self.abs_weight = cfg.get("abs_weight", None)
|
||||
self.task_name = cfg["task_name"]
|
||||
self.dataset_name = cfg["dataset_name"]
|
||||
self.oversize_rule = cfg["oversize_rule"]
|
||||
self.dataset = PrefetchDecodeDataset(path=cfg["path"], allow_repeat=cfg.get("allow_repeat", True))
|
||||
self.exhausted = False
|
||||
self.iterator = None
|
||||
|
||||
self.counter = 0
|
||||
self.allow_repeat = cfg.get("allow_repeat", True)
|
||||
self.used = BitSet()
|
||||
self.init_ave_tokens()
|
||||
|
||||
def init_ave_tokens(
|
||||
self,
|
||||
):
|
||||
try:
|
||||
shm = SharedMemory(name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
|
||||
except FileNotFoundError:
|
||||
bmt.print_rank(
|
||||
"Create Shared Memory {}".format(f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}')
|
||||
)
|
||||
shm = SharedMemory(
|
||||
create=True,
|
||||
size=ctypes.sizeof(ctypes.c_float),
|
||||
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}',
|
||||
)
|
||||
|
||||
# use the shared memory
|
||||
shared_value = ctypes.c_float.from_buffer(shm.buf)
|
||||
_ave_tokens = self.cfg.get(
|
||||
"avg_tokens", self.cfg.get("ave_tokens", self.cfg.get("ave_tokens_per_line", -1))
|
||||
)
|
||||
|
||||
if _ave_tokens > self.max_length:
|
||||
_ave_tokens = self.max_length
|
||||
bmt.print_rank(
|
||||
"Warning: avg_tokens {} is larger than max_length {}, set to max_length".format(
|
||||
_ave_tokens, self.max_length
|
||||
)
|
||||
)
|
||||
|
||||
shared_value.value = _ave_tokens
|
||||
# drop the reference once shared_value is no longer needed
|
||||
del shared_value
|
||||
|
||||
# now the shared memory can be closed safely
|
||||
shm.close()
|
||||
bmt.print_rank("Init ave_tokens for task {}: {}".format(self.task_name, self.ave_tokens))
|
||||
|
||||
@property
|
||||
def ave_tokens(
|
||||
self,
|
||||
):
|
||||
existing_shm = SharedMemory(
|
||||
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}'
|
||||
) # -1 # default length
|
||||
shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
|
||||
tmp = shared_value.value
|
||||
del shared_value
|
||||
existing_shm.close()
|
||||
return tmp
|
||||
|
||||
def ave_tokens_update(self, length):
|
||||
existing_shm = SharedMemory(
|
||||
name=f'ave_tokens_{self.task_name.replace("/", "_")}_{bmt.rank()}'
|
||||
) # -1 # default length
|
||||
shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
|
||||
if shared_value.value < 0:
|
||||
shared_value.value = float(length)
|
||||
else:
|
||||
shared_value.value = 0.98 * shared_value.value + 0.02 * length
|
||||
del shared_value
|
||||
existing_shm.close()
|
||||
|
||||
def size(self):
|
||||
return self.dataset.size()
|
||||
|
||||
def __iter__(self):
|
||||
self.iterate()
|
||||
return self
|
||||
|
||||
def reset(self):
|
||||
self.exhausted = False
|
||||
if self.iterator is not None:
|
||||
self.iterator.close()
|
||||
self.iterator = None
|
||||
self.used = BitSet()
|
||||
print("Rank {}, Reset dataset:{} done.".format(bmt.rank(), self.dataset_name))
|
||||
|
||||
def transform(self, data: dict) -> dict:
|
||||
weight = np.array(self.cfg["incontext_weight"], dtype=np.float32)
|
||||
weight = weight / weight.sum()
|
||||
num_incontext = np.random.choice(weight.shape[0], p=weight)
|
||||
return self.transform_func(data, num_incontext, random.Random())
|
||||
|
||||
def segment_iterate(self, sample_iter):
|
||||
for index, data in self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used):
|
||||
for src_ids, tgt_ids in self.segment(self.transform(data)):
|
||||
self.ave_tokens_update(len(src_ids)) # 0 for input ids
|
||||
yield src_ids, tgt_ids, index
|
||||
|
||||
def iterate(self):
|
||||
# make the dataset itself an iterator
|
||||
sample_iter = self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used)
|
||||
self.iterator = self.segment_iterate(sample_iter)
|
||||
|
||||
def __next__(self):
|
||||
# advance the task iterator
|
||||
if self.iterator is None:
|
||||
self.iterate()
|
||||
try:
|
||||
return next(self.iterator)
|
||||
except StopIteration:
|
||||
self.exhausted = True
|
||||
return None
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
if state_dict.get("exhausted", False):
|
||||
self.exhausted = True
|
||||
self.used = BitSet()
|
||||
else:
|
||||
used = state_dict.get("used", BitSet())
|
||||
if len(used) == len(self.dataset):
|
||||
self.exhausted = True
|
||||
self.used = BitSet()
|
||||
else:
|
||||
self.exhausted = False
|
||||
self.used = used
|
||||
self.ave_tokens_update(state_dict.get("ave_tokens", -1))
|
||||
|
||||
def state_dict(self):
|
||||
if len(self.used) == len(self.dataset):
|
||||
return dict(exhausted=True, used=BitSet(), ave_tokens=self.ave_tokens)
|
||||
else:
|
||||
return dict(exhausted=False, used=self.used, ave_tokens=self.ave_tokens)
|
||||
|
||||
def update_state(self, indice):
|
||||
self.used.update(indice)
|
||||
|
||||
|
||||
class MixedIndexedDataset(torch.utils.data.IterableDataset):
|
||||
def __init__(
|
||||
self,
|
||||
cfg_path: str,
|
||||
cfg_json_str,
|
||||
tokenizer,
|
||||
max_length,
|
||||
weight_by_size: bool = True,
|
||||
nthreads=5,
|
||||
prefetch_slice=100,
|
||||
parallel_loading=False,
|
||||
vdc_sampling=False,
|
||||
update_weights_frequency=1,
|
||||
seed=42,
|
||||
):
|
||||
super(MixedIndexedDataset, self).__init__()
|
||||
self.set_seed(seed + bmt.rank())
|
||||
self.weight_by_size = weight_by_size
|
||||
self.tokenizer = tokenizer
|
||||
self.eos_token_id = self.tokenizer.eos_id
|
||||
self.bos_token_id = self.tokenizer.bos_id
|
||||
self.path2transform = dict()
|
||||
self.task_dict = OrderedDict()
|
||||
self.nthreads = nthreads
|
||||
self.prefetch_slice = prefetch_slice
|
||||
# useful for indexing
|
||||
self.tasks = []
|
||||
self.names = []
|
||||
# ending of iteration
|
||||
self.remain = 0
|
||||
self.max_length = max_length
|
||||
self.vdc_sampling = vdc_sampling
|
||||
if self.vdc_sampling:
|
||||
self._vdc_values = [van_der_corput(i) for i in range(10**6)] # precision up to 10^{-6}
|
||||
self.vdc_gen = van_der_corput_sampling_gen(self._vdc_values)
|
||||
|
||||
self.update_weights_frequency = update_weights_frequency
|
||||
|
||||
self.path2transform = dict()
|
||||
|
||||
cfgs = load_dataset_cfgs(cfg_path, cfg_json_str)
|
||||
_sum_weight = sum([cfg["abs_weight"] for cfg in cfgs])
|
||||
_weights = {cfg["task_name"]: cfg["abs_weight"] / _sum_weight for cfg in cfgs}
|
||||
bmt.print_rank("Absolute Weight of DataSet {}".format(_weights))
|
||||
|
||||
if parallel_loading:
|
||||
self.parallel_load(cfgs, max_workers=None)
|
||||
else:
|
||||
self.sequential_load(cfgs)
|
||||
|
||||
self.weights = None
|
||||
self.update_weights()
|
||||
|
||||
def set_seed(self, seed):
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
def load_task(self, cfg):
|
||||
logger.info(f"Loading {cfg['path']}")
|
||||
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
|
||||
task = SegmentedDataset(
|
||||
cfg,
|
||||
self.tokenizer,
|
||||
self.max_length,
|
||||
transform_func=transform_func,
|
||||
nthreads=self.nthreads,
|
||||
prefetch_slice=self.prefetch_slice,
|
||||
do_compact=cfg.get("do_compact", False), # dataset level do_compact
|
||||
)
|
||||
return task
|
||||
|
||||
def sequential_load(self, cfgs):
|
||||
self.cfgs = cfgs
|
||||
for cfg in cfgs:
|
||||
# Python 3.7 and later preserve dictionary insertion order
|
||||
logger.info(f"Loading {cfg['path']}")
|
||||
|
||||
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
|
||||
task = SegmentedDataset(
|
||||
cfg,
|
||||
self.tokenizer,
|
||||
self.max_length,
|
||||
transform_func=transform_func,
|
||||
nthreads=self.nthreads,
|
||||
prefetch_slice=self.prefetch_slice,
|
||||
do_compact=cfg.get("do_compact", False), # dataset level do_compact
|
||||
)
|
||||
self.task_dict[task.task_name] = task
|
||||
self.tasks.append(task)
|
||||
self.names.append(task.task_name)
|
||||
self.remain += 1
|
||||
self.weights = None
|
||||
self.update_weights()
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
missing_keys = []
|
||||
for name, task in self.task_dict.items():
|
||||
if name in state_dict:
|
||||
task.load_state_dict(state_dict[name])
|
||||
else:
|
||||
missing_keys.append(name)
|
||||
self.update_weights()
|
||||
return missing_keys
|
||||
|
||||
def save_state_dict(self, path):
|
||||
state_dict = {}
|
||||
for name, task in self.task_dict.items():
|
||||
_state_dict = task.state_dict()
|
||||
if isinstance(_state_dict["used"], BitSet):
|
||||
bitset = _state_dict["used"]
|
||||
_file_name = bitset.save(path)
|
||||
_state_dict["used"] = _file_name
|
||||
state_dict[name] = _state_dict
|
||||
else:
|
||||
state_dict[name] = task.state_dict()
|
||||
torch.save(state_dict, path)
|
||||
logger.info("Dataset state saved")
|
||||
|
||||
def update_states(self, task_ids, indice):
|
||||
is_dict = isinstance(indice, dict)
|
||||
uniq = torch.unique(task_ids)
|
||||
for idx in uniq:
|
||||
idx = idx.item()
|
||||
indexes = indice[idx] if is_dict else indice[task_ids == idx].tolist()
|
||||
self.tasks[idx].update_state(indexes)
|
||||
|
||||
def get_transform_func(self, module_name: str, transform_script_path):
|
||||
if transform_script_path is None:
|
||||
# allow null transform
|
||||
return lambda data, num_incontext, rand: data
|
||||
module_name = "fm9g_live.transforms.{}".format(module_name)
|
||||
if transform_script_path not in self.path2transform:
|
||||
loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
|
||||
spec = importlib.util.spec_from_loader(loader.name, loader)
|
||||
if spec is None:
|
||||
raise RuntimeError("Spec is none! {}".format(module_name))
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
self.path2transform[transform_script_path] = {
|
||||
"loader": loader,
|
||||
"module": mod,
|
||||
"last_mtime": 0,
|
||||
}
|
||||
transform_script_info = self.path2transform[transform_script_path]
|
||||
curr_mtime = float(transform_script_info["loader"].path_stats(transform_script_path)["mtime"])
|
||||
if curr_mtime > transform_script_info["last_mtime"]:
|
||||
transform_script_info["last_mtime"] = curr_mtime
|
||||
transform_script_info["loader"].exec_module(transform_script_info["module"])
|
||||
transform_func = getattr(transform_script_info["module"], "transform", None)
|
||||
if transform_func is None:
|
||||
raise NotImplementedError("Find no transform funcion in script '{}'".format(transform_script_path))
|
||||
return transform_func
|
||||
|
||||
def update_weights(self):
|
||||
task0 = self.tasks[0]
|
||||
if task0.abs_weight is not None: # this config specifies absolute sampling ratios
|
||||
weights = []
|
||||
for task in self.tasks:
|
||||
if task.exhausted:
|
||||
weights.append(0)
|
||||
else:
|
||||
if task.ave_tokens == -1:
|
||||
weights.append(task.abs_weight / self.max_length)
|
||||
else:
|
||||
weights.append(task.abs_weight / task.ave_tokens)
|
||||
weights = np.array(weights)
|
||||
else:
|
||||
weights = np.array([0 if task.exhausted else task.weight for task in self.tasks])
|
||||
if self.weight_by_size:
|
||||
sizes = np.array([task.size() for task in self.tasks], dtype=np.float32)
|
||||
weights *= sizes
|
||||
self.weights = weights / weights.sum()
|
||||
|
||||
def __iter__(self):
|
||||
for task in self.tasks:
|
||||
task.iterate()
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
step = 1
|
||||
while True:
|
||||
if self.remain == 0:
|
||||
print("Rank {}, All task exhaust !!!!".format(bmt.rank()))
|
||||
raise StopIteration
|
||||
if self.vdc_sampling:
|
||||
idx = next(self.vdc_gen)(self.weights)
|
||||
else:
|
||||
idx = np.random.choice(len(self.weights), p=self.weights)
|
||||
|
||||
data = next(self.tasks[idx])
|
||||
if data is None:
|
||||
if self.tasks[idx].allow_repeat:
|
||||
# _runtime_ave = self.tasks[idx].ave_tokens
|
||||
print("Rank {}, dataset {} exhaust, repeat...".format(bmt.rank(), self.tasks[idx].dataset_name))
|
||||
# self.tasks[idx] = SegmentedDataset(
|
||||
# self.tasks[idx].cfg, self.tokenizer, self.max_length, transform_func=self.tasks[idx].transform_func, nthreads=self.nthreads, prefetch_slice=self.prefetch_slice
|
||||
# )
|
||||
# self.tasks[idx].ave_tokens_update(_runtime_ave)
|
||||
self.tasks[idx].reset()
|
||||
else:
|
||||
print("Rank {}, dataset {} exhaust, not repeat.".format(bmt.rank(), self.tasks[idx].dataset_name))
|
||||
self.tasks[idx].exhausted = True
|
||||
self.remain -= 1
|
||||
continue
|
||||
|
||||
if step % self.update_weights_frequency == 0:
|
||||
self.update_weights()
|
||||
step += 1
|
||||
|
||||
return dict(
|
||||
task_id=idx,
|
||||
input=data[0],
|
||||
target=data[1],
|
||||
index=data[2],
|
||||
is_long=self.tasks[idx].cfg.get("is_long", False),
|
||||
)
|
||||
|
||||
|
||||
class UnpadBatchedMixedDataset(torch.utils.data.IterableDataset):
|
||||
def __init__(self, mixed_dataset, batch_size, max_length, pose_prob=0.0, pose_scaling_factor=1.0, compact=False):
|
||||
self.max_total_length = batch_size * max_length
|
||||
self.batch_size = 1
|
||||
# setting compact=True concatenates segments originating from the same input
|
||||
# into a long sequence. The relative order of segments must be preserved
|
||||
# in mixed_dataset, e.g.,
|
||||
# - ok: task1_seg1, task2_seg1, task1_seg2, task1_seg3
|
||||
# - not_ok: task1_seg1, task1_seg3, task2_seg1, task1_seg2
|
||||
self.compact = compact
|
||||
|
||||
self.total_length = 0
|
||||
self.task2seqs = defaultdict(list)
|
||||
self.mixed_dataset = mixed_dataset
|
||||
self._max_length = max_length
|
||||
self._pose_prob = pose_prob
|
||||
self._pose_scaling_factor = pose_scaling_factor
|
||||
if self._pose_prob > 0.0:
|
||||
self._scaled_max_length = int(self.max_total_length * self._pose_scaling_factor)
|
||||
else:
|
||||
self._scaled_max_length = max_length
|
||||
|
||||
def put(self, sample):
|
||||
self.total_length += len(sample["target"])
|
||||
task_id = sample["task_id"]
|
||||
if self.compact and self.task2seqs[task_id]:
|
||||
last = self.task2seqs[task_id][-1]
|
||||
if last["target"][-1] != self.mixed_dataset.eos_token_id:
|
||||
# concatenate sequential segments for longer-context modeling
|
||||
last["input"].extend(sample["input"])
|
||||
last["target"].extend(sample["target"])
|
||||
return
|
||||
self.task2seqs[task_id].append(sample)
|
||||
|
||||
def _pose_preprocess(
|
||||
self,
|
||||
input_ids: NDArray[np.int32],
|
||||
) -> NDArray[np.int32]:
|
||||
"""[PoSE](https://arxiv.org/abs/2309.10400v2)
|
||||
GitHub implementation: https://github.com/dwzhu-pku/PoSE/blob/master/src/train_pose.py#L156
|
||||
"""
|
||||
len_chunk = min(len(input_ids), self._max_length)
|
||||
len_input = len(input_ids)
|
||||
# Chunk input randomly to fit max_length if needed
|
||||
lt1 = 0
|
||||
rt1 = random.randint(0, (len_chunk + 1) // 2) # First chunk contains at most half of the tokens
|
||||
rt2 = random.randint(lt1 + len_chunk, len_input) # Second chunk can randomly shift when not filled max_length
|
||||
lt2 = rt2 - (len_chunk - (rt1 - lt1)) # ensure all tokens are used
|
||||
chunked_input_ids = np.concatenate([input_ids[lt1:rt1], input_ids[lt2:rt2]], axis=-1)
|
||||
# Generate PoSE position ids
|
||||
position_ids = np.arange(len(chunked_input_ids), dtype=np.int32)
|
||||
len_position_ids = len(position_ids)
|
||||
lt = 0
|
||||
rt = random.randint(lt, self._scaled_max_length - len_position_ids)
|
||||
position_ids[: rt1 - lt1] += lt
|
||||
position_ids[rt1 - lt1 :] += rt
|
||||
return position_ids
|
||||
|
||||
def pop(self):
|
||||
indexes = defaultdict(list)
|
||||
lengths = []
|
||||
|
||||
inputs = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
|
||||
targets = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=IGNORE_TGT)
|
||||
task_ids = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=-1)
|
||||
position_ids = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
|
||||
|
||||
span_begin = 0
|
||||
for samples in self.task2seqs.values():
|
||||
while samples:
|
||||
sample = samples.pop()
|
||||
span_end = span_begin + len(sample["input"])
|
||||
inputs[0, span_begin:span_end] = torch.tensor(sample["input"], dtype=torch.int32)
|
||||
targets[0, span_begin:span_end] = torch.tensor(sample["target"], dtype=torch.int32)
|
||||
task_ids[0, span_begin:span_end] = torch.tensor(sample["task_id"], dtype=torch.int32)
|
||||
if not sample["is_long"] and self._pose_prob > 0.0 and random.uniform(0, 1) < self._pose_prob:
|
||||
_span_position_ids = self._pose_preprocess(sample["input"])
|
||||
else:
|
||||
_span_position_ids = np.arange(len(sample["input"]), dtype=np.int32)
|
||||
position_ids[0, span_begin:span_end] = torch.from_numpy(_span_position_ids)
|
||||
# position_ids[0, span_begin:span_end] = torch.arange(len(sample["input"]), dtype=torch.int32)
|
||||
lengths.append(len(sample["target"]))
|
||||
indexes[int(sample["task_id"])].append(sample["index"])
|
||||
self.total_length -= len(sample["target"])
|
||||
span_begin = span_end
|
||||
|
||||
cu_seqlens = torch.cat(
|
||||
[torch.tensor([0] + lengths).cumsum(dim=-1), torch.tensor([self.max_total_length], dtype=torch.int32)],
|
||||
dim=0,
|
||||
).int()
|
||||
batch = {
|
||||
"inputs": inputs,
|
||||
"targets": targets,
|
||||
"task_ids": task_ids,
|
||||
"indexes": indexes,
|
||||
# adhere to flash attention interface
|
||||
"cu_seqlens": cu_seqlens,
|
||||
"max_seqlen": int(torch.max(cu_seqlens[1:] - cu_seqlens[:-1])),
|
||||
"lengths": torch.tensor(sum(lengths)).int(),
|
||||
"task_names": self.mixed_dataset.names,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
return batch
|
||||
|
||||
def will_be_full(self, sample):
|
||||
return self.total_length + len(sample["target"]) > self.max_total_length
|
||||
|
||||
def __iter__(self):
|
||||
for sample in self.mixed_dataset:
|
||||
if self.will_be_full(sample):
|
||||
yield self.pop()
|
||||
self.put(sample)
|
||||
|
||||
|
||||
class CudaPrefetcher(Iterable):
|
||||
"""
|
||||
Wraps a batch iterator and asynchronously copies data to the GPU to hide memcpy latency.
|
||||
"""
|
||||
|
||||
def __init__(self, loader, tp_size=1, tp_rank=0):
|
||||
self.loader = iter(loader)
|
||||
self.tp_size = tp_size
|
||||
self.tp_rank = tp_rank
|
||||
self.stream = torch.cuda.Stream()
|
||||
self.preload()
|
||||
|
||||
def preload(self):
|
||||
try:
|
||||
if self.tp_size > 1:
|
||||
if self.tp_rank == 0:
|
||||
data = next(self.loader)
|
||||
print("Rank {}, Preload data done.".format(bmt.rank()))
|
||||
d = {}
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "wb") as fb:
|
||||
for key in data.keys():
|
||||
if isinstance(data[key], torch.Tensor):
|
||||
np_cur_data = data[key].cpu().numpy()
|
||||
bs = np_cur_data.tobytes()
|
||||
fb.write(bs)
|
||||
d[key] = ["TORCH", str(np_cur_data.dtype), len(bs)] + list(np_cur_data.shape)
|
||||
elif isinstance(data[key], np.ndarray):
|
||||
bs = data[key].tobytes()
|
||||
fb.write(bs)
|
||||
d[key] = ["NUMPY", str(data[key].dtype), len(bs)] + list(data[key].shape)
|
||||
else:
|
||||
d[key] = data[key]
|
||||
try:
|
||||
_ = json.dumps(d)
|
||||
except TypeError:
|
||||
print(d)
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "w") as f:
|
||||
json.dump(d, f)
|
||||
bmt.synchronize()
|
||||
if self.tp_rank != 0:
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "r") as f:
|
||||
data = json.load(f)
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "rb") as fb:
|
||||
bs = fb.read()
|
||||
offset = 0
|
||||
for key in data.keys():
|
||||
if isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "NUMPY":
|
||||
nw_offset = offset + data[key][2]
|
||||
data[key] = np.frombuffer(bs[offset:nw_offset], dtype=data[key][1]).reshape(
|
||||
data[key][3:]
|
||||
)
|
||||
offset = nw_offset
|
||||
elif isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "TORCH":
|
||||
nw_offset = offset + data[key][2]
|
||||
data[key] = torch.from_numpy(
|
||||
np.frombuffer(bs[offset:nw_offset], dtype=data[key][1])
|
||||
.reshape(data[key][3:])
|
||||
.copy()
|
||||
)
|
||||
offset = nw_offset
|
||||
self.data = data
|
||||
else:
|
||||
self.data = next(self.loader)
|
||||
except StopIteration:
|
||||
self.data = None
|
||||
return
|
||||
with torch.cuda.stream(self.stream):
|
||||
for key in self.data.keys():
|
||||
if isinstance(self.data[key], torch.Tensor):
|
||||
self.data[key] = self.data[key].cuda(non_blocking=True)
|
||||
|
||||
def __next__(self):
|
||||
torch.cuda.current_stream().wait_stream(self.stream)
|
||||
for key in self.data.keys():
|
||||
if isinstance(self.data[key], torch.Tensor):
|
||||
self.data[key].record_stream(torch.cuda.current_stream())
|
||||
data = copy.deepcopy(self.data)
|
||||
self.preload()
|
||||
return data
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
|
@ -1,827 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
||||
#
|
||||
# @author: ouzebin <ouzebin@zhihu.com>
|
||||
# @date: 2023/09/27
|
||||
|
||||
|
||||
import copy
|
||||
import ctypes
|
||||
import functools
|
||||
import importlib
|
||||
import importlib.util
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from collections import OrderedDict
|
||||
from multiprocessing import Lock
|
||||
from multiprocessing import Process
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import Iterable
|
||||
from typing import Iterator
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Set
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import bmtrain as bmt
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
from tenacity import retry
|
||||
from tenacity import stop_after_attempt
|
||||
from tenacity import wait_fixed
|
||||
from tenacity import wait_random
|
||||
|
||||
from fm9g.dataset import PrefetchDecodeDataset
|
||||
from fm9g.utils.bitset import BitSet
|
||||
from fm9g.utils.vdc_sampling import van_der_corput
|
||||
from fm9g.utils.vdc_sampling import van_der_corput_sampling_gen
|
||||
|
||||
from .flask_ps import app as flask_ps
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
IGNORE_TGT = -100
|
||||
|
||||
|
||||
def load_dataset_cfgs(cfg_path, cfg_json_str=None):
|
||||
if cfg_json_str is not None:
|
||||
cfgs = json.loads(cfg_json_str)
|
||||
else:
|
||||
with open(cfg_path, "r", encoding="utf-8") as fin:
|
||||
cfgs = json.load(fin)
|
||||
transform_basedir = os.path.dirname(os.path.abspath(cfg_path))
|
||||
|
||||
path_dict = None
|
||||
platform_config_path = os.getenv("PLATFORM_CONFIG_PATH")
|
||||
try:
|
||||
with open(platform_config_path, "r") as f:
|
||||
platform_cfg = json.load(f)
|
||||
path_dict = platform_cfg["dataset_map"]
|
||||
if bmt.rank() == 0:
|
||||
logger.info(f"Loaded jeeves platform config from '{platform_config_path}', update dataset paths...")
|
||||
except Exception as e:
|
||||
if bmt.rank() == 0:
|
||||
logger.info(f"Failing to load jeeves platform config '{platform_config_path}', error message:\n{str(e)}")
|
||||
|
||||
task_name2dataset_name = dict()
|
||||
for idx, cfg in enumerate(cfgs):
|
||||
assert "dataset_name" in cfg and isinstance(cfg["dataset_name"], str)
|
||||
assert "task_name" in cfg and isinstance(cfg["task_name"], str)
|
||||
# to be deliberately annoying :)
|
||||
if cfg["task_name"] in task_name2dataset_name:
|
||||
raise ValueError(
|
||||
f"task_name '{cfg['task_name']}' in dataset '{cfg['dataset_name']}'"
|
||||
f"has already been used in '{task_name2dataset_name[cfg['task_name']]}'."
|
||||
)
|
||||
task_name2dataset_name[cfg["task_name"]] = cfg["dataset_name"]
|
||||
|
||||
assert "path" in cfg and isinstance(cfg["path"], str)
|
||||
# if path_dict is not None:
|
||||
# cfg["path"] = os.path.join(path_dict[cfg["dataset_name"]], cfg["path"])
|
||||
|
||||
# dealing with optional configs
|
||||
if "weight" in cfg:
|
||||
assert isinstance(cfg["weight"], (float, int))
|
||||
else:
|
||||
cfg["weight"] = 1.0
|
||||
|
||||
if "oversize_rule" in cfg:
|
||||
assert cfg["oversize_rule"] in ("drop", "head", "segment")
|
||||
else:
|
||||
cfg["oversize_rule"] = "segment"
|
||||
|
||||
if "transforms" in cfg:
|
||||
assert isinstance(cfg["transforms"], str)
|
||||
# dealing with relative path
|
||||
if not cfg["transforms"].startswith("/"):
|
||||
cfg["transforms"] = os.path.join(transform_basedir, cfg["transforms"])
|
||||
if not cfg["transforms"]:
|
||||
cfg["transforms"] = None
|
||||
else:
|
||||
cfg["transforms"] = None
|
||||
|
||||
if "incontext_weight" in cfg:
|
||||
assert isinstance(cfg["incontext_weight"], (list, tuple))
|
||||
else:
|
||||
cfg["incontext_weight"] = [1.0]
|
||||
cfg["id"] = idx
|
||||
# dataset and iterator will be built
|
||||
return cfgs
|
||||
|
||||
|
||||
def data2ids(data, tokenizer, max_length):
|
||||
text = "\n".join(
|
||||
[
|
||||
data.get("title", "").strip(),
|
||||
data.get("question", "").strip(),
|
||||
data.get("answer", "").strip(),
|
||||
data.get("abstract", "").strip(),
|
||||
data.get("text", "").strip(),
|
||||
data.get("code", "").strip(),
|
||||
]
|
||||
).strip()
|
||||
|
||||
if not text:
|
||||
logger.warning(f"Warning: skip invalid sample without valid fields: {data}")
|
||||
yield from ()
|
||||
return
|
||||
# suppress the annoying warning from tokenizer
|
||||
ids = (
|
||||
[tokenizer.bos_token_id]
|
||||
+ tokenizer.encode(text, max_length=int(1e12), truncation=True)
|
||||
+ [tokenizer.eos_token_id]
|
||||
)
|
||||
src_ids = ids[0:-1]
|
||||
tgt_ids = ids[0:-1] # do not shift because it'll be shifted during loss calculation.
|
||||
|
||||
if len(src_ids) > max_length:
|
||||
for st in range(0, len(src_ids), max_length):
|
||||
yield src_ids[st : st + max_length], tgt_ids[st : st + max_length]
|
||||
else:
|
||||
yield src_ids, tgt_ids
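# A small sketch of the fixed-window slicing performed by data2ids above,
# with plain integer ids standing in for real tokenizer output.
def split_into_windows(ids, max_length):
    src = ids[:-1]
    tgt = ids[:-1]  # targets are assumed to be shifted later, during loss computation
    if len(src) > max_length:
        for st in range(0, len(src), max_length):
            yield src[st:st + max_length], tgt[st:st + max_length]
    else:
        yield src, tgt

windows = list(split_into_windows(list(range(10)), max_length=4))
# -> ([0, 1, 2, 3], [0, 1, 2, 3]), ([4, 5, 6, 7], [4, 5, 6, 7]), ([8], [8])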
|
||||
|
||||
|
||||
def cricket_data2ids(data, tokenizer, max_length: int, oversize_rule="segment", do_compact=False):
|
||||
assert oversize_rule in ("drop", "head", "segment")
|
||||
if data is None:
|
||||
yield from ()
|
||||
return
|
||||
if "output" not in data or not data["output"]:
|
||||
yield from ()
|
||||
return
|
||||
if "input" not in data:
|
||||
data["input"] = ""
|
||||
|
||||
src_ids = [tokenizer.bos_token_id]
|
||||
tgt_ids = []
|
||||
has_input = False
|
||||
is_segment_reenter = False
|
||||
|
||||
# Use incremental tokenization to avoid waiting for a long document
|
||||
MAX_CHUNK_LENGTH = max_length * 10
|
||||
for part in ("input", "output"):
|
||||
l, r = 0, min(MAX_CHUNK_LENGTH, len(data[part]))
|
||||
while l < len(data[part]):
|
||||
current_slice = data[part][l:r]
|
||||
if not current_slice:
|
||||
break
|
||||
token_ids = tokenizer.encode(current_slice, add_special_tokens=False)
|
||||
if part == "input":
|
||||
if len(token_ids) > 0:
|
||||
has_input = True
|
||||
if len(token_ids) >= max_length - 2: # input len must < max_length
|
||||
yield from ()
|
||||
return
|
||||
src_ids.extend(token_ids)
|
||||
tgt_ids.extend([IGNORE_TGT] * len(token_ids))
|
||||
l = r
|
||||
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
|
||||
else:
|
||||
if len(token_ids) + len(tgt_ids) >= max_length:
|
||||
if oversize_rule == "drop":
|
||||
yield from ()
|
||||
return
|
||||
elif oversize_rule == "head":
|
||||
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||
src_ids.extend(selected_token_ids[:-1])
|
||||
tgt_ids.extend(selected_token_ids)
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||
return
|
||||
elif oversize_rule == "segment":
|
||||
instruction_rest_space = max_length - 1 - len(token_ids)
|
||||
if has_input: # is instruction data
|
||||
if (
|
||||
do_compact
|
||||
and len(src_ids) >= 128  # avoid losing instruction info when it is too short
|
||||
and instruction_rest_space / len(src_ids) > 0.8
|
||||
): # can be squeezed into max length
|
||||
inputs_len = len(src_ids)
|
||||
keep_len = instruction_rest_space // 2
|
||||
src_ids = src_ids[:keep_len] + src_ids[inputs_len - keep_len :]
|
||||
tgt_ids = [IGNORE_TGT] * (len(src_ids) - 1)
|
||||
src_ids.extend(token_ids)
|
||||
tgt_ids.extend(token_ids)
|
||||
tgt_ids.append(tokenizer.eos_token_id)
|
||||
assert len(src_ids) < max_length, f"len src_ids: {len(src_ids)}"
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids, tgt_ids
|
||||
else: # else use head rule
|
||||
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||
src_ids.extend(selected_token_ids[:-1])
|
||||
tgt_ids.extend(selected_token_ids)
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||
return
|
||||
else: # normal segment
|
||||
selected_token_ids = token_ids[: max_length - len(src_ids) + 1]
|
||||
src_ids.extend(selected_token_ids)
|
||||
tgt_ids.extend(selected_token_ids)
|
||||
assert len(src_ids) == max_length + 1, f"len src_ids: {len(src_ids)}"
|
||||
assert len(tgt_ids) == max_length, f"len tgt_ids: {len(tgt_ids)}"
|
||||
yield src_ids[:max_length], tgt_ids[:max_length]
|
||||
src_ids = src_ids[max_length:]
|
||||
tgt_ids = tgt_ids[max_length:]
|
||||
# sliding input str window
|
||||
consumed_str = tokenizer.decode(selected_token_ids)
|
||||
l += len(consumed_str)
|
||||
r = min(len(data[part]), l + MAX_CHUNK_LENGTH)
|
||||
is_segment_reenter = True
|
||||
else:
|
||||
if (is_segment_reenter and len(token_ids) > 8) or (
|
||||
not is_segment_reenter and len(token_ids) > 0
|
||||
): # is segmented LM data
|
||||
src_ids.extend(token_ids)
|
||||
tgt_ids.extend(token_ids)
|
||||
tgt_ids.append(tokenizer.eos_token_id)
|
||||
assert len(src_ids) == len(tgt_ids), f"len (src, tgt): ({len(src_ids)}, {len(tgt_ids)})"
|
||||
yield src_ids, tgt_ids
|
||||
else:
|
||||
yield from ()
|
||||
return
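# A compact sketch of the target masking that cricket_data2ids applies above:
# prompt ("input") tokens get IGNORE_TGT targets so the loss only covers the
# response, while response ("output") tokens appear in both src and tgt,
# followed by the eos id. The token ids below are made up for illustration.
IGNORE_TGT = -100

def build_sft_pair(prompt_ids, response_ids, bos_id=1, eos_id=2):
    src = [bos_id] + prompt_ids + response_ids
    tgt = [IGNORE_TGT] * len(prompt_ids) + response_ids + [eos_id]
    assert len(src) == len(tgt)
    return src, tgt

src, tgt = build_sft_pair([11, 12, 13], [21, 22])
# src -> [1, 11, 12, 13, 21, 22]
# tgt -> [-100, -100, -100, 21, 22, 2]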
|
||||
|
||||
|
||||
class SegmentedDataset(torch.utils.data.IterableDataset):
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
tokenizer,
|
||||
max_length=1024,
|
||||
transform_func=None,
|
||||
nthreads=1,
|
||||
prefetch_slice=3,
|
||||
slice_size=500,
|
||||
do_compact=False,
|
||||
is_local=False,
|
||||
):
|
||||
def get_full_qualified_name(func):
|
||||
module_name = func.__module__
|
||||
qual_name = func.__qualname__
|
||||
return f"{module_name}.{qual_name}"
|
||||
|
||||
super(SegmentedDataset, self).__init__()
|
||||
self.segment = functools.partial(
|
||||
cricket_data2ids, tokenizer=tokenizer, max_length=max_length, do_compact=do_compact
|
||||
)
|
||||
self.cfg = cfg
|
||||
self.max_length = max_length
|
||||
self.nthreads = nthreads
|
||||
self.transform_func = transform_func
|
||||
self.prefetch_slice = prefetch_slice
|
||||
self.slice_size = slice_size
|
||||
self.abs_weight = cfg.get("abs_weight", None)
|
||||
self.task_name = cfg["task_name"]
|
||||
self.dataset_name = cfg["dataset_name"]
|
||||
self.oversize_rule = cfg["oversize_rule"]
|
||||
self.dataset = PrefetchDecodeDataset(path=cfg["path"], allow_repeat=cfg.get("allow_repeat", False))
|
||||
self.exhausted = False
|
||||
self.iterator = None
|
||||
|
||||
self.counter = 0
|
||||
self.allow_repeat = cfg.get("allow_repeat", True)
|
||||
self.used = set()
|
||||
self.is_local = is_local
|
||||
self.port_offset = random.randint(1, 8000) + 1000
|
||||
if is_local or bmt.rank() == 0:
|
||||
addr = os.environ["MASTER_ADDR"]
|
||||
port = int(os.environ["MASTER_PORT"]) + self.port_offset
|
||||
_avg_tokens = cfg.get("ave_tokens_per_line", -1)
|
||||
_avg_tokens = cfg.get("avg_tokens", _avg_tokens)
|
||||
requests.post(f"http://{addr}:{port}/avg_tokens/{self.task_name}?action=set&length={_avg_tokens}")
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_random(5, 20))
|
||||
def set_avg_tokens(self, avg_tokens):
|
||||
addr = os.environ["MASTER_ADDR"]
|
||||
port = int(os.environ["MASTER_PORT"]) + self.port_offset
|
||||
url = f"http://{addr}:{port}/avg_tokens/{self.task_name}?action=set&length={avg_tokens}"
|
||||
response = requests.post(url)
|
||||
if response.status_code != 200:
|
||||
self.reset_port_offset()
|
||||
print(f"Failed to set avg_tokens for task {self.task_name}, request url: {url}")
|
||||
raise RuntimeError(f"Failed to set avg_tokens for task {self.task_name}, request url: {url}")
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_random(5, 20))
|
||||
def update_avg_tokens_by_ema(self, length):
|
||||
addr = os.environ["MASTER_ADDR"]
|
||||
port = int(os.environ["MASTER_PORT"]) + self.port_offset
|
||||
url = f"http://{addr}:{port}/avg_tokens/{self.task_name}?action=update&length={length}"
|
||||
response = requests.post(url)
|
||||
if response.status_code != 200:
|
||||
self.reset_port_offset()
|
||||
print(f"Failed to update avg_tokens for task {self.task_name}, request url: {url}")
|
||||
raise RuntimeError(f"Failed to update avg_tokens for task {self.task_name}, request url: {url}")
|
||||
|
||||
@property
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_random(5, 20))
|
||||
def avg_tokens(self):
|
||||
addr = os.environ["MASTER_ADDR"]
|
||||
port = int(os.environ["MASTER_PORT"]) + self.port_offset
|
||||
url = f"http://{addr}:{port}/avg_tokens/{self.task_name}"
|
||||
response = requests.get(url)
|
||||
if response.status_code != 200:
|
||||
self.reset_port_offset()
|
||||
print(f"Failed to get avg_tokens for task {self.task_name}, request url: {url}")
|
||||
raise RuntimeError(f"Failed to get avg_tokens for task {self.task_name}, request url: {url}")
|
||||
data = response.json()
|
||||
return data["avg_tokens"]
|
||||
|
||||
def reset_port_offset(self):
|
||||
self.port_offset = random.randint(1, 8000) + 1000
|
||||
|
||||
def size(self):
|
||||
return self.dataset.size()
|
||||
|
||||
def __iter__(self):
|
||||
self.iterate()
|
||||
return self
|
||||
|
||||
def reset(self):
|
||||
self.exhausted = False
|
||||
if self.iterator is not None:
|
||||
self.iterator.close()
|
||||
self.iterator = None
|
||||
self.used = BitSet()
|
||||
print("Rank {}, Reset dataset:{} done.".format(bmt.rank(), self.dataset_name))
|
||||
|
||||
def transform(self, data: dict) -> dict:
|
||||
weight = np.array(self.cfg["incontext_weight"], dtype=np.float32)
|
||||
weight = weight / weight.sum()
|
||||
num_incontext = np.random.choice(weight.shape[0], p=weight)
|
||||
return self.transform_func(data, num_incontext, random.Random())
|
||||
|
||||
def segment_iterate(self, sample_iter):
|
||||
for index, data in self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used):
|
||||
for src_ids, tgt_ids in self.segment(self.transform(data)):
|
||||
self.update_avg_tokens_by_ema(len(src_ids)) # 0 for input ids
|
||||
yield src_ids, tgt_ids, index
|
||||
|
||||
def iterate(self):
|
||||
# make the dataset itself an iterator
|
||||
sample_iter = self.dataset.sliced_iterate(self.nthreads, self.prefetch_slice, self.slice_size, self.used)
|
||||
self.iterator = self.segment_iterate(sample_iter)
|
||||
|
||||
def __next__(self):
|
||||
# advance the task iterator
|
||||
if self.iterator is None:
|
||||
self.iterate()
|
||||
try:
|
||||
return next(self.iterator)
|
||||
except StopIteration:
|
||||
self.exhausted = True
|
||||
return None
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
if state_dict.get("exhausted", False):
|
||||
self.exhausted = True
|
||||
self.used = BitSet()
|
||||
else:
|
||||
used = state_dict.get("used", BitSet())
|
||||
if len(used) == len(self.dataset):
|
||||
self.exhausted = True
|
||||
self.used = BitSet()
|
||||
else:
|
||||
self.exhausted = False
|
||||
self.used = used
|
||||
_avg_tokens = state_dict.get("ave_tokens", -1)
|
||||
_avg_tokens = state_dict.get("avg_tokens", _avg_tokens)
|
||||
if self.avg_tokens == -1 or self.avg_tokens < _avg_tokens:
|
||||
self.set_avg_tokens(_avg_tokens)
|
||||
|
||||
def state_dict(self):
|
||||
if len(self.used) == len(self.dataset):
|
||||
return dict(exhausted=True, used=BitSet(), avg_tokens=self.avg_tokens)
|
||||
else:
|
||||
return dict(exhausted=False, used=self.used, avg_tokens=self.avg_tokens)
|
||||
|
||||
def update_state(self, indice):
|
||||
self.used.update(indice)
|
||||
|
||||
|
||||
class MixedIndexedDataset(torch.utils.data.IterableDataset):
|
||||
def __init__(
|
||||
self,
|
||||
cfg_path: str,
|
||||
cfg_json_str,
|
||||
tokenizer,
|
||||
max_length,
|
||||
weight_by_size: bool = True,
|
||||
nthreads=5,
|
||||
prefetch_slice=100,
|
||||
parallel_loading=False,
|
||||
vdc_sampling=False,
|
||||
update_weights_frequency=1,
|
||||
seed=42,
|
||||
):
|
||||
if bmt.rank() == 0:
|
||||
port = int(os.environ["MASTER_PORT"]) + 2188
|
||||
self.flask_ps_proc = mp.Process(target=flask_ps.run, kwargs={"host": "0.0.0.0", "port": port})
|
||||
self.flask_ps_proc.start()
|
||||
super(MixedIndexedDataset, self).__init__()
|
||||
self.set_seed(seed + bmt.rank())
|
||||
self.weight_by_size = weight_by_size
|
||||
self.tokenizer = tokenizer
|
||||
self.eos_token_id = self.tokenizer.eos_token_id
|
||||
self.bos_token_id = self.tokenizer.bos_token_id
|
||||
self.path2transform = dict()
|
||||
self.task_dict = OrderedDict()
|
||||
self.nthreads = nthreads
|
||||
self.prefetch_slice = prefetch_slice
|
||||
# useful for indexing
|
||||
self.tasks = []
|
||||
self.names = []
|
||||
# ending of iteration
|
||||
self.remain = 0
|
||||
self.max_length = max_length
|
||||
self.vdc_sampling = vdc_sampling
|
||||
if self.vdc_sampling:
|
||||
self._vdc_values = [van_der_corput(i) for i in range(100000)]
|
||||
self.vdc_gen = van_der_corput_sampling_gen(self._vdc_values)
|
||||
|
||||
self.update_weights_frequency = update_weights_frequency
|
||||
|
||||
self.path2transform = dict()
|
||||
|
||||
cfgs = load_dataset_cfgs(cfg_path, cfg_json_str)
|
||||
_sum_weight = sum([cfg["abs_weight"] for cfg in cfgs])
|
||||
_weights = {cfg["task_name"]: cfg["abs_weight"] / _sum_weight for cfg in cfgs}
|
||||
bmt.print_rank("Absolute Weight of DataSet {}".format(_weights))
|
||||
|
||||
if parallel_loading:
|
||||
self.parallel_load(cfgs, max_workers=None)
|
||||
else:
|
||||
self.sequential_load(cfgs)
|
||||
|
||||
self.weights = None
|
||||
self.update_weights()
|
||||
|
||||
def set_seed(self, seed):
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
def load_task(self, cfg):
|
||||
logger.info(f"Loading {cfg['path']}")
|
||||
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
|
||||
task = SegmentedDataset(
|
||||
cfg,
|
||||
self.tokenizer,
|
||||
self.max_length,
|
||||
transform_func=transform_func,
|
||||
nthreads=self.nthreads,
|
||||
prefetch_slice=self.prefetch_slice,
|
||||
do_compact=cfg.get("do_compact", False), # dataset level do_compact
|
||||
)
|
||||
return task
|
||||
|
||||
def sequential_load(self, cfgs):
|
||||
self.cfgs = cfgs
|
||||
for cfg in cfgs:
|
||||
# Python 3.7 and later preserve dictionary insertion order
|
||||
logger.info(f"Loading {cfg['path']}")
|
||||
|
||||
transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
|
||||
task = SegmentedDataset(
|
||||
cfg,
|
||||
self.tokenizer,
|
||||
self.max_length,
|
||||
transform_func=transform_func,
|
||||
nthreads=self.nthreads,
|
||||
prefetch_slice=self.prefetch_slice,
|
||||
do_compact=cfg.get("do_compact", False), # dataset level do_compact
|
||||
)
|
||||
self.task_dict[task.task_name] = task
|
||||
self.tasks.append(task)
|
||||
self.names.append(task.task_name)
|
||||
self.remain += 1
|
||||
self.weights = None
|
||||
self.update_weights()
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
missing_keys = []
|
||||
for name, task in self.task_dict.items():
|
||||
if name in state_dict:
|
||||
task.load_state_dict(state_dict[name])
|
||||
else:
|
||||
missing_keys.append(name)
|
||||
self.update_weights()
|
||||
return missing_keys
|
||||
|
||||
def save_state_dict(self, path):
|
||||
state_dict = {}
|
||||
for name, task in self.task_dict.items():
|
||||
_state_dict = task.state_dict()
|
||||
if isinstance(_state_dict["used"], BitSet):
|
||||
bitset = _state_dict["used"]
|
||||
_file_name = bitset.save(path)
|
||||
_state_dict["used"] = _file_name
|
||||
state_dict[name] = _state_dict
|
||||
else:
|
||||
state_dict[name] = task.state_dict()
|
||||
torch.save(state_dict, path)
|
||||
logger.info("Dataset state saved")
|
||||
|
||||
def update_states(self, task_ids, indice):
|
||||
is_dict = isinstance(indice, dict)
|
||||
uniq = torch.unique(task_ids)
|
||||
for idx in uniq:
|
||||
idx = idx.item()
|
||||
indexes = indice[idx] if is_dict else indice[task_ids == idx].tolist()
|
||||
self.tasks[idx].update_state(indexes)
|
||||
|
||||
def get_transform_func(self, module_name: str, transform_script_path):
|
||||
if transform_script_path is None:
|
||||
# allow null transform
|
||||
return lambda data, num_incontext, rand: data
|
||||
if "/" in module_name:
|
||||
module_name = "fm9g_live.transforms.{}".format(module_name.split("/")[-1])
|
||||
else:
|
||||
module_name = "fm9g_live.transforms.{}".format(module_name)
|
||||
if transform_script_path not in self.path2transform:
|
||||
# loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
|
||||
# spec = importlib.util.spec_from_loader(loader.name, loader)
|
||||
spec = importlib.util.spec_from_file_location(module_name, transform_script_path)
|
||||
if spec is None:
|
||||
raise RuntimeError("Spec is none! {}".format(module_name))
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
self.path2transform[transform_script_path] = {
|
||||
"module": mod,
|
||||
"last_mtime": 0,
|
||||
}
|
||||
transform_script_info = self.path2transform[transform_script_path]
|
||||
curr_mtime = float(os.path.getmtime(transform_script_path))
|
||||
if curr_mtime > transform_script_info["last_mtime"]:
|
||||
transform_script_info["last_mtime"] = curr_mtime
|
||||
# load module
|
||||
spec.loader.exec_module(transform_script_info["module"])
|
||||
transform_func = getattr(transform_script_info["module"], "transform", None)
|
||||
if transform_func is None:
|
||||
raise NotImplementedError("Find no transform funcion in script '{}'".format(transform_script_path))
|
||||
return transform_func
|
||||
|
||||
def update_weights(self):
|
||||
task0 = self.tasks[0]
|
||||
if task0.abs_weight is not None:  # this config specifies absolute sampling ratios
|
||||
weights = []
|
||||
for task in self.tasks:
|
||||
if task.exhausted:
|
||||
weights.append(0)
|
||||
else:
|
||||
if task.avg_tokens == -1:
|
||||
weights.append(task.abs_weight / self.max_length)
|
||||
else:
|
||||
weights.append(task.abs_weight / task.avg_tokens)
|
||||
weights = np.array(weights)
|
||||
else:
|
||||
weights = np.array([0 if task.exhausted else task.weight for task in self.tasks])
|
||||
if self.weight_by_size:
|
||||
sizes = np.array([task.size() for task in self.tasks], dtype=np.float32)
|
||||
weights *= sizes
|
||||
self.weights = weights / weights.sum()
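# A worked example of the weighting rule above, with made-up numbers: each
# task's abs_weight is divided by its running average tokens per sample (so
# token-hungry tasks are not over-sampled), exhausted tasks get weight 0, and
# the vector is renormalized before being used as sampling probabilities.
import numpy as np

abs_weight = np.array([2.0, 1.0, 0.5])
avg_tokens = np.array([100.0, 400.0, 250.0])
exhausted = np.array([False, False, True])

weights = np.where(exhausted, 0.0, abs_weight / avg_tokens)
weights = weights / weights.sum()          # -> approx [0.889, 0.111, 0.0]
task_idx = np.random.choice(len(weights), p=weights)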
|
||||
|
||||
def __iter__(self):
|
||||
for task in self.tasks:
|
||||
task.iterate()
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
step = 1
|
||||
while True:
|
||||
if self.remain == 0:
|
||||
print("Rank {}, All task exhaust !!!!".format(bmt.rank()))
|
||||
raise StopIteration
|
||||
if self.vdc_sampling:
|
||||
idx = next(self.vdc_gen)(self.weights)
|
||||
else:
|
||||
idx = np.random.choice(len(self.weights), p=self.weights)
|
||||
|
||||
data = next(self.tasks[idx])
|
||||
if step % self.update_weights_frequency == 0:
|
||||
self.update_weights()
|
||||
if data is None:
|
||||
if self.tasks[idx].allow_repeat:
|
||||
# _runtime_ave = self.tasks[idx].avg_tokens
|
||||
print("Rank {}, dataset {} exhaust, repeat...".format(bmt.rank(), self.tasks[idx].dataset_name))
|
||||
# self.tasks[idx] = SegmentedDataset(
|
||||
# self.tasks[idx].cfg, self.tokenizer, self.max_length, transform_func=self.tasks[idx].transform_func, nthreads=self.nthreads, prefetch_slice=self.prefetch_slice
|
||||
# )
|
||||
# self.tasks[idx].avg_tokens_update(_runtime_ave)
|
||||
self.tasks[idx].reset()
|
||||
else:
|
||||
print("Rank {}, dataset {} exhaust, not repeat.".format(bmt.rank(), self.tasks[idx].dataset_name))
|
||||
self.tasks[idx].exhausted = True
|
||||
self.remain -= 1
|
||||
continue
|
||||
step += 1
|
||||
return dict(
|
||||
task_id=idx,
|
||||
input=data[0],
|
||||
target=data[1],
|
||||
index=data[2],
|
||||
is_long=self.tasks[idx].cfg.get("is_long", False),
|
||||
)
|
||||
|
||||
|
||||
class UnpadBatchedMixedDataset(torch.utils.data.IterableDataset):
|
||||
def __init__(self, mixed_dataset, batch_size, max_length, pose_prob=0.0, pose_scaling_factor=1.0, compact=False):
|
||||
self.max_total_length = batch_size * max_length
|
||||
self.batch_size = 1
|
||||
# setting compact=True concatenates segments originating from the same input
|
||||
# into a long sequence. the relative order of segments should be preserved
|
||||
# in mixed_dataset, e.g.,
|
||||
# - ok: task1_seg1, task2_seg1, task1_seg2, task1_seg3
|
||||
# - not_ok: task1_seg1, task1_seg3, task2_seg1, task1_seg2
|
||||
self.compact = compact
|
||||
|
||||
self.total_length = 0
|
||||
self.task2seqs = defaultdict(list)
|
||||
self.mixed_dataset = mixed_dataset
|
||||
self._max_length = max_length
|
||||
self._pose_prob = pose_prob
|
||||
self._pose_scaling_factor = pose_scaling_factor
|
||||
if self._pose_prob > 0.0:
|
||||
self._scaled_max_length = int(self.max_total_length * self._pose_scaling_factor)
|
||||
else:
|
||||
self._scaled_max_length = max_length
|
||||
|
||||
def put(self, sample):
|
||||
self.total_length += len(sample["target"])
|
||||
task_id = sample["task_id"]
|
||||
if self.compact and self.task2seqs[task_id]:
|
||||
last = self.task2seqs[task_id][-1]
|
||||
if last["target"][-1] != self.mixed_dataset.eos_token_id:
|
||||
# concatenate sequential segments for longer-context modeling: why not?
|
||||
last["input"].extend(sample["input"])
|
||||
last["target"].extend(sample["target"])
|
||||
return
|
||||
self.task2seqs[task_id].append(sample)
|
||||
|
||||
def _pose_preprocess(
|
||||
self,
|
||||
input_ids: NDArray[np.int32],
|
||||
) -> NDArray[np.int32]:
|
||||
"""[PoSE](https://arxiv.org/abs/2309.10400v2)
|
||||
GitHub implementation: https://github.com/dwzhu-pku/PoSE/blob/master/src/train_pose.py#L156
|
||||
"""
|
||||
len_chunk = min(len(input_ids), self._max_length)
|
||||
len_input = len(input_ids)
|
||||
# Chunk input randomly to fit max_length if needed
|
||||
lt1 = 0
|
||||
rt1 = random.randint(0, (len_chunk + 1) // 2)  # First chunk contains at most half of the tokens
|
||||
rt2 = random.randint(lt1 + len_chunk, len_input)  # Second chunk can shift randomly when the input is longer than max_length
|
||||
lt2 = rt2 - (len_chunk - (rt1 - lt1)) # assure all tokens are used
|
||||
chunked_input_ids = np.concatenate([input_ids[lt1:rt1], input_ids[lt2:rt2]], axis=-1)
|
||||
# Generate PoSE position ids
|
||||
position_ids = np.arange(len(chunked_input_ids), dtype=np.int32)
|
||||
len_position_ids = len(position_ids)
|
||||
lt = 0
|
||||
rt = random.randint(lt, self._scaled_max_length - len_position_ids)
|
||||
position_ids[: rt1 - lt1] += lt
|
||||
position_ids[rt1 - lt1 :] += rt
|
||||
return position_ids
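# A self-contained sketch of the PoSE position-id construction above: two
# chunks of the sequence are kept, and the second chunk's position ids are
# shifted by a random gap so the model sees ids from a longer (scaled) window
# even though only max_length tokens are attended to.
import random
import numpy as np

def pose_position_ids(seq_len, max_length, scaled_max_length):
    len_chunk = min(seq_len, max_length)
    rt1 = random.randint(0, (len_chunk + 1) // 2)            # end of the first chunk
    rt2 = random.randint(len_chunk, seq_len)                 # end of the second chunk
    lt2 = rt2 - (len_chunk - rt1)                            # start of the second chunk
    position_ids = np.arange(len_chunk, dtype=np.int32)
    gap = random.randint(0, scaled_max_length - len_chunk)   # random shift for the second chunk
    position_ids[rt1:] += gap
    return position_ids, (0, rt1), (lt2, rt2)

ids, first_span, second_span = pose_position_ids(seq_len=32, max_length=16, scaled_max_length=64)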
|
||||
|
||||
def pop(self):
|
||||
indexes = defaultdict(list)
|
||||
lengths = []
|
||||
|
||||
inputs = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
|
||||
targets = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=IGNORE_TGT)
|
||||
task_ids = torch.full((self.batch_size, self.max_total_length), dtype=torch.int32, fill_value=-1)
|
||||
position_ids = torch.zeros((self.batch_size, self.max_total_length), dtype=torch.int32)
|
||||
|
||||
span_begin = 0
|
||||
for samples in self.task2seqs.values():
|
||||
while samples:
|
||||
sample = samples.pop()
|
||||
span_end = span_begin + len(sample["input"])
|
||||
inputs[0, span_begin:span_end] = torch.tensor(sample["input"], dtype=torch.int32)
|
||||
targets[0, span_begin:span_end] = torch.tensor(sample["target"], dtype=torch.int32)
|
||||
task_ids[0, span_begin:span_end] = torch.tensor(sample["task_id"], dtype=torch.int32)
|
||||
if not sample["is_long"] and self._pose_prob > 0.0 and random.uniform(0, 1) < self._pose_prob:
|
||||
_span_position_ids = self._pose_preprocess(sample["input"])
|
||||
else:
|
||||
_span_position_ids = np.arange(len(sample["input"]), dtype=np.int32)
|
||||
position_ids[0, span_begin:span_end] = torch.from_numpy(_span_position_ids)
|
||||
# position_ids[0, span_begin:span_end] = torch.arange(len(sample["input"]), dtype=torch.int32)
|
||||
lengths.append(len(sample["target"]))
|
||||
indexes[int(sample["task_id"])].append(sample["index"])
|
||||
self.total_length -= len(sample["target"])
|
||||
span_begin = span_end
|
||||
|
||||
cu_seqlens = torch.cat(
|
||||
[torch.tensor([0] + lengths).cumsum(dim=-1), torch.tensor([self.max_total_length], dtype=torch.int32)],
|
||||
dim=0,
|
||||
).int()
|
||||
batch = {
|
||||
"inputs": inputs,
|
||||
"targets": targets,
|
||||
"task_ids": task_ids,
|
||||
"indexes": indexes,
|
||||
# adhere to flash attention interface
|
||||
"cu_seqlens": cu_seqlens,
|
||||
"max_seqlen": int(torch.max(cu_seqlens[1:] - cu_seqlens[:-1])),
|
||||
"lengths": torch.tensor(sum(lengths)).int(),
|
||||
"task_names": self.mixed_dataset.names,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
return batch
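# A small sketch of the packed-batch layout that pop() produces above: several
# variable-length sequences are laid end to end in one row, and cu_seqlens
# holds the cumulative boundaries that flash-attention's varlen kernels expect
# (the trailing entry marks the padded tail of the row).
import torch

lengths = [3, 5, 2]
max_total_length = 12
boundaries = torch.tensor([0] + lengths).cumsum(dim=-1)
cu_seqlens = torch.cat([boundaries, torch.tensor([max_total_length])], dim=0).int()
# cu_seqlens -> tensor([ 0,  3,  8, 10, 12], dtype=torch.int32)
max_seqlen = int(torch.max(cu_seqlens[1:] - cu_seqlens[:-1]))  # -> 5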
|
||||
|
||||
def will_be_full(self, sample):
|
||||
return self.total_length + len(sample["target"]) > self.max_total_length
|
||||
|
||||
def __iter__(self):
|
||||
for sample in self.mixed_dataset:
|
||||
if self.will_be_full(sample):
|
||||
yield self.pop()
|
||||
self.put(sample)
|
||||
|
||||
|
||||
class CudaPrefetcher(Iterable):
|
||||
"""
|
||||
Wrap a batch iterator and asynchronously copy data to the GPU to hide memcpy latency.
|
||||
"""
|
||||
|
||||
def __init__(self, loader, tp_size=1, tp_rank=0):
|
||||
self.loader = iter(loader)
|
||||
self.tp_size = tp_size
|
||||
self.tp_rank = tp_rank
|
||||
self.stream = torch.cuda.Stream()
|
||||
self.preload()
|
||||
|
||||
def preload(self):
|
||||
try:
|
||||
if self.tp_size > 1:
|
||||
if self.tp_rank == 0:
|
||||
data = next(self.loader)
|
||||
print("Rank {}, Preload data done.".format(bmt.rank()))
|
||||
d = {}
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "wb") as fb:
|
||||
for key in data.keys():
|
||||
if isinstance(data[key], torch.Tensor):
|
||||
np_cur_data = data[key].cpu().numpy()
|
||||
bs = np_cur_data.tobytes()
|
||||
fb.write(bs)
|
||||
d[key] = ["TORCH", str(np_cur_data.dtype), len(bs)] + list(np_cur_data.shape)
|
||||
elif isinstance(data[key], np.ndarray):
|
||||
bs = data[key].tobytes()
|
||||
fb.write(bs)
|
||||
d[key] = ["NUMPY", str(data[key].dtype), len(bs)] + list(data[key].shape)
|
||||
else:
|
||||
d[key] = data[key]
|
||||
try:
|
||||
_ = json.dumps(d)
|
||||
except TypeError:
|
||||
print(d)
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "w") as f:
|
||||
json.dump(d, f)
|
||||
bmt.synchronize()
|
||||
if self.tp_rank != 0:
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.json", "r") as f:
|
||||
data = json.load(f)
|
||||
with open(f"/dev/shm/BMT_TP_{bmt.config['topology'].tp_idx}.bin", "rb") as fb:
|
||||
bs = fb.read()
|
||||
offset = 0
|
||||
for key in data.keys():
|
||||
if isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "NUMPY":
|
||||
nw_offset = offset + data[key][2]
|
||||
data[key] = np.frombuffer(bs[offset:nw_offset], dtype=data[key][1]).reshape(
|
||||
data[key][3:]
|
||||
)
|
||||
offset = nw_offset
|
||||
elif isinstance(data[key], list) and len(data[key]) > 1 and data[key][0] == "TORCH":
|
||||
nw_offset = offset + data[key][2]
|
||||
data[key] = torch.from_numpy(
|
||||
np.frombuffer(bs[offset:nw_offset], dtype=data[key][1])
|
||||
.reshape(data[key][3:])
|
||||
.copy()
|
||||
)
|
||||
offset = nw_offset
|
||||
self.data = data
|
||||
else:
|
||||
self.data = next(self.loader)
|
||||
except StopIteration:
|
||||
self.data = None
|
||||
return
|
||||
with torch.cuda.stream(self.stream):
|
||||
for key in self.data.keys():
|
||||
if isinstance(self.data[key], torch.Tensor):
|
||||
self.data[key] = self.data[key].cuda(non_blocking=True)
|
||||
|
||||
def __next__(self):
|
||||
torch.cuda.current_stream().wait_stream(self.stream)
|
||||
data = copy.deepcopy(self.data)
|
||||
self.preload()
|
||||
return data
|
||||
|
||||
def __iter__(self):
|
||||
return self
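# A usage sketch for the CudaPrefetcher above (tensor parallelism disabled),
# assuming a CUDA device and any iterable yielding dicts of CPU tensors: the
# next batch is copied to the GPU on a side stream while the current step runs.
import torch

if torch.cuda.is_available():
    loader = ({"inputs": torch.randint(0, 100, (4, 16))} for _ in range(8))
    prefetcher = CudaPrefetcher(loader, tp_size=1, tp_rank=0)
    for _ in range(8):
        batch = next(prefetcher)
        if batch is None:
            break
        # ... run forward/backward on batch["inputs"] here ...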
|
|
@ -1,79 +0,0 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
import bmtrain as bmt
|
||||
|
||||
def _linear_backward(grad_output, x, weight, bias):
|
||||
grad_x = grad_weight = grad_bias = None
|
||||
if x.requires_grad:
|
||||
grad_x = grad_output.matmul(weight)
|
||||
if weight.requires_grad:
|
||||
grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(x.reshape(-1, x.shape[-1]))
|
||||
if bias is not None and bias.requires_grad:
|
||||
grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)
|
||||
return grad_x, grad_weight, grad_bias
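# A quick check of the manual linear backward above against autograd, on small
# random tensors (pure torch, no distributed setup required).
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 4, requires_grad=True)
w = torch.randn(5, 4, requires_grad=True)
b = torch.randn(5, requires_grad=True)

y = F.linear(x, w, b)
grad_out = torch.randn_like(y)
y.backward(grad_out)

gx, gw, gb = _linear_backward(grad_out, x, w, b)
assert torch.allclose(gx, x.grad, atol=1e-5)
assert torch.allclose(gw, w.grad, atol=1e-5)
assert torch.allclose(gb, b.grad, atol=1e-5)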
|
||||
|
||||
class OpAttnPipeSP(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, q_w, k_w, v_w, q_b, k_b, v_b, x, cache_kv, cache_kv_inp, cu_seqlens_q, cu_seqlens_k, max_seqlen):
|
||||
ctx.save_for_backward(x, q_w, k_w, v_w, q_b, k_b, v_b)
|
||||
if cache_kv.numel() == 0:
|
||||
q = F.linear(x, q_w, q_b)
|
||||
k = F.linear(x, k_w, k_b)
|
||||
v = F.linear(x, v_w, v_b)
|
||||
else:
|
||||
q = F.linear(x, q_w, q_b)
|
||||
k = F.linear(x, k_w, k_b)
|
||||
v = F.linear(x, v_w, v_b)
|
||||
k = torch.cat([cache_kv[0], k], dim=1)
|
||||
v = torch.cat([cache_kv[1], v], dim=1)
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
|
||||
q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen, max_seqlen, 0, causal=True, window_size=(-1,-1), alibi_slopes=None, deterministic=False, return_attn_probs=False
|
||||
)
|
||||
ctx.save_for_backward(
|
||||
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
|
||||
)
|
||||
ctx.max_seqlen_q = max_seqlen
|
||||
ctx.max_seqlen_k = max_seqlen
|
||||
|
||||
|
||||
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
|
||||
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
|
||||
_flash_attn_varlen_backward(
|
||||
dout,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
softmax_lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
ctx.max_seqlen_q,
|
||||
ctx.max_seqlen_k,
|
||||
ctx.dropout_p,
|
||||
ctx.softmax_scale,
|
||||
False,
|
||||
(-1,-1),
|
||||
None,
|
||||
False,
|
||||
rng_state=rng_state,
|
||||
)
|
||||
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
|
||||
dk = dk[..., : dout.shape[-1]]
|
||||
dv = dv[..., : dout.shape[-1]]
|
||||
|
||||
d_xq, d_wq, d_bq = _linear_backward(dq, x, q_w, q_b)
|
||||
d_xk, d_wk, d_bk = _linear_backward(dk, x, k_w, k_b)
|
||||
|
||||
|
||||
|
||||
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None
|
|
@ -1,332 +0,0 @@
|
|||
import math
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
try:
|
||||
from .flash_triton import FlashAttnFunc
|
||||
except:
|
||||
FlashAttnFunc = None
|
||||
import bmtrain as bmt
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from .linear import ColumnParallelLinear
|
||||
from .linear import Linear
|
||||
from .position_embedding import apply_chatglm_rotary_pos_emb
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||
#try:
|
||||
# from flash_attn.flash_attn_interface import _flash_attn_varlen_backward
|
||||
# from flash_attn.flash_attn_interface import _flash_attn_varlen_forward
|
||||
# from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||
#except:
|
||||
# flash_attn_varlen_func = None
|
||||
|
||||
try:
|
||||
from flash_attn.bert_padding import pad_input
|
||||
from flash_attn.bert_padding import unpad_input
|
||||
except:
|
||||
pad_input = None
|
||||
unpad_input = None
|
||||
|
||||
|
||||
class OpFlash(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, self, record, q, k, v, cu_seqlens, max_seqlen, dropout_p, causal):
|
||||
ctx.self = self
|
||||
ctx.cu_seqlens = cu_seqlens
|
||||
ctx.max_length = max_seqlen
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.causal = causal
|
||||
ctx.softmax_scale = q.shape[-1] ** (-0.5)
|
||||
if not record and "out" in self._layer_dict:
|
||||
out = self._layer_dict.pop("out")
|
||||
softmax_lse = self._layer_dict.pop("softmax_lse")
|
||||
rng_state = self._layer_dict.pop("rng_state")
|
||||
else:
|
||||
out, _, _, _, _, softmax_lse, _, rng_state = _flash_attn_varlen_forward(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
max_seqlen,
|
||||
dropout_p,
|
||||
ctx.softmax_scale,
|
||||
causal=causal,
|
||||
window_size=(-1, -1),
|
||||
alibi_slopes=None,
|
||||
return_softmax=False,
|
||||
)
|
||||
if record:
|
||||
self._layer_dict["out"] = out
|
||||
self._layer_dict["softmax_lse"] = softmax_lse
|
||||
self._layer_dict["rng_state"] = rng_state
|
||||
|
||||
ctx.save_for_backward(q, k, v, out, softmax_lse, rng_state)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
|
||||
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
|
||||
_flash_attn_varlen_backward(
|
||||
dout,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
softmax_lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
ctx.cu_seqlens,
|
||||
ctx.cu_seqlens,
|
||||
ctx.max_length,
|
||||
ctx.max_length,
|
||||
ctx.dropout_p,
|
||||
ctx.softmax_scale,
|
||||
ctx.causal,
|
||||
(-1,-1),
|
||||
None,
|
||||
False,
|
||||
rng_state=rng_state,
|
||||
)
|
||||
return None, None, dq, dk, dv, None, None, None, None
|
||||
|
||||
|
||||
class Attention(bmt.DistributedModule):
|
||||
def __init__(
|
||||
self,
|
||||
dim_model: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
dim_head: int,
|
||||
dtype: torch.dtype = torch.half,
|
||||
dropout_p: Optional[float] = None,
|
||||
scale: bool = True,
|
||||
add_qkv_bias: bool = False,
|
||||
use_flash_attn: bool = False,
|
||||
tp: int = 0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dim_model = dim_model
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.head_groups = num_heads // num_kv_heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.project_q = Linear(
|
||||
self.dim_model,
|
||||
self.num_heads * self.dim_head,
|
||||
bias=add_qkv_bias,
|
||||
dtype=dtype,
|
||||
scale=scale,
|
||||
tp=tp,
|
||||
)
|
||||
self.project_k = Linear(
|
||||
self.dim_model,
|
||||
self.num_kv_heads * self.dim_head,
|
||||
bias=add_qkv_bias,
|
||||
dtype=dtype,
|
||||
scale=scale,
|
||||
tp=tp,
|
||||
)
|
||||
self.project_v = Linear(
|
||||
self.dim_model,
|
||||
self.num_kv_heads * self.dim_head,
|
||||
bias=add_qkv_bias,
|
||||
dtype=dtype,
|
||||
scale=scale,
|
||||
tp=tp,
|
||||
)
|
||||
|
||||
self.attention_out = Linear(
|
||||
self.num_heads * self.dim_head,
|
||||
self.dim_model,
|
||||
dtype=dtype,
|
||||
scale=scale,
|
||||
tp=tp * 2,
|
||||
)
|
||||
|
||||
self.softmax = torch.nn.Softmax(dim=-1)
|
||||
|
||||
if dropout_p is not None:
|
||||
self.dropout = torch.nn.Dropout(p=dropout_p)
|
||||
self.dropout_p = dropout_p
|
||||
else:
|
||||
self.dropout = None
|
||||
|
||||
self.use_flash_attn = use_flash_attn
|
||||
self._layer_dict = {}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_q: torch.Tensor,
|
||||
hidden_kv: torch.Tensor,
|
||||
attention_mask: torch.BoolTensor,
|
||||
position_bias: torch.Tensor,
|
||||
use_cache: bool = False,
|
||||
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
pos_bias_type: Optional[str] = "relative",
|
||||
length_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask_bias: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: int = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""This model inherits from bmt.DistributedModule.
|
||||
Args:
|
||||
hidden_q (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): Indices of input sequence tokens. It will be embedded by model's internal embedding lookup matrix.
|
||||
hidden_kv (:obj:`torch.Tensor` of shape ``(batch, len_k, dim_model)``): Hidden states used to compute the keys and values.
|
||||
attention_mask (:obj:`torch.Tensor` of shape ``(batch, len_q, len_k)``): Used to avoid performing attention on padding token indices.
|
||||
position_bias(:obj:`torch.Tensor` of shape ``(num_heads, len_q, len_k)`` or ``(1, num_heads, len_k, len_q)``): Provide positional information about tensor `key_value` and `query`.
|
||||
Return:
|
||||
out (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): The attention output.
|
||||
""" # noqa: E501
|
||||
|
||||
len_q = hidden_q.size(1)
|
||||
len_k = hidden_kv.size(1)
|
||||
|
||||
if isinstance(self.project_q, ColumnParallelLinear):
|
||||
assert hidden_q.data_ptr() == hidden_kv.data_ptr()
|
||||
if self.project_q.scale and self.project_q.scale_before:
|
||||
hidden_q = hidden_q / math.sqrt(self.project_q.dim_in)
|
||||
hidden_q = bmt.nn.OpParallelLinear.apply(
|
||||
hidden_q,
|
||||
torch.cat([self.project_q.weight, self.project_k.weight, self.project_v.weight], dim=0),
|
||||
torch.cat([self.project_q.bias, self.project_k.bias, self.project_v.bias], dim=0)
|
||||
if self.project_q.bias is not None
|
||||
else None,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
if self.project_q.scale and not self.project_q.scale_before:
|
||||
hidden_q = hidden_q / math.sqrt(self.project_q.dim_in)
|
||||
|
||||
block_size = hidden_q.shape[-1] // (self.head_groups + 1 + 1)
|
||||
h_q = hidden_q[..., : block_size * self.head_groups]
|
||||
h_k = hidden_q[..., block_size * self.head_groups : block_size * (self.head_groups + 1)]
|
||||
h_v = hidden_q[..., block_size * (self.head_groups + 1) :]
|
||||
else:
|
||||
h_q = self.project_q(hidden_q)
|
||||
h_k = self.project_k(hidden_kv)
|
||||
h_v = self.project_v(hidden_kv)
|
||||
|
||||
batch_size = h_q.size(0)
|
||||
|
||||
if not self.use_flash_attn:
|
||||
h_q = h_q / math.sqrt(math.sqrt(self.dim_head))
|
||||
h_k = h_k / math.sqrt(math.sqrt(self.dim_head))
|
||||
|
||||
h_q = h_q.view(batch_size, len_q, -1, self.dim_head).permute(0, 2, 1, 3)
|
||||
h_k = h_k.view(batch_size, len_k, -1, self.dim_head).permute(0, 2, 1, 3)
|
||||
h_v = h_v.view(batch_size, len_k, -1, self.dim_head).permute(0, 2, 1, 3)
|
||||
|
||||
if pos_bias_type == "rotary":
|
||||
# b h s d
|
||||
h_q, h_k = position_bias(h_q, h_k, -2, offset=past_kv[0].size(-2) if past_kv is not None else 0)
|
||||
elif pos_bias_type == "chatglm_rotary":
|
||||
h_q = apply_chatglm_rotary_pos_emb(h_q, position_bias)
|
||||
h_k = apply_chatglm_rotary_pos_emb(h_k, position_bias)
|
||||
|
||||
if past_kv is not None:
|
||||
h_k = torch.cat([past_kv[0], h_k], dim=-2)
|
||||
h_v = torch.cat([past_kv[1], h_v], dim=-2)
|
||||
len_k = h_k.size(-2)
|
||||
|
||||
# (b, n_h, len_q, d_h) @ (b, n_h, d_h, len_k) -> (b, n_h, len_q, len_k)
|
||||
# (b, n_kv_h, n_h_groups*len_q, d_h) @ (b, n_kv_h, d_h, len_k) -> (b, n_kv_h, n_h_groups*len_q, len_k) -> (b, n_h, len_q, len_k)
|
||||
if self.head_groups == 1:
|
||||
score = torch.matmul(h_q, h_k.transpose(-1, -2)) # / math.sqrt(self.dim_head) moved to line 75~76
|
||||
else:
|
||||
score = torch.matmul(
|
||||
h_q.reshape(batch_size, -1, self.head_groups * len_q, self.dim_head),
|
||||
h_k.transpose(-1, -2),
|
||||
).view(
|
||||
batch_size, -1, len_q, len_k
|
||||
) # / math.sqrt(self.dim_head) moved to line 75~76
|
||||
if pos_bias_type == "relative":
|
||||
if len_q == 1: # inference with cache
|
||||
if len(position_bias.size()) == 4:
|
||||
position_bias = position_bias[:, :, -1:, :]
|
||||
else:
|
||||
position_bias = position_bias[:, -1:, :]
|
||||
score = score + position_bias
|
||||
score = torch.masked_fill(
|
||||
score,
|
||||
attention_mask.view(batch_size, 1, len_q, len_k) == False,
|
||||
torch.scalar_tensor(float("-inf"), device=score.device, dtype=score.dtype),
|
||||
)
|
||||
|
||||
score = self.softmax(score)
|
||||
|
||||
score = torch.masked_fill(
|
||||
score,
|
||||
attention_mask.view(batch_size, 1, len_q, len_k) == False,
|
||||
torch.scalar_tensor(0, device=score.device, dtype=score.dtype),
|
||||
)
|
||||
|
||||
if self.dropout is not None:
|
||||
score = self.dropout(score)
|
||||
|
||||
# (b, n_h, len_q, len_k) @ (b, n_h, len_k, d_h) -> (b, n_h, len_q, d_h)
|
||||
# (b, n_kv_h, n_h_groups*len_q, len_k) @ (b, n_kv_h, len_k, d_h) -> (b, n_kv_h, n_h_groups*len_q, d_h) -> (b, n_h, len_q, d_h)
|
||||
score = torch.matmul(score.view(batch_size, -1, self.head_groups * len_q, len_k), h_v).view(
|
||||
batch_size, -1, len_q, self.dim_head
|
||||
)
|
||||
|
||||
score = score.view(batch_size, -1, len_q, self.dim_head).permute(0, 2, 1, 3)
|
||||
score = score.contiguous().view(batch_size, len_q, -1)
|
||||
|
||||
else:
|
||||
if attention_mask_bias is not None:
|
||||
assert pos_bias_type == "rotary"
|
||||
h_q = h_q.view(batch_size, len_q, -1, self.dim_head) # .permute(0, 2, 1, 3)
|
||||
h_k = h_k.view(batch_size, len_k, -1, self.dim_head) # .permute(0, 2, 1, 3)
|
||||
h_v = h_v.view(batch_size, len_k, -1, self.dim_head) # .permute(0, 2, 1, 3)
|
||||
h_q, h_k = position_bias(h_q, h_k, -3)
|
||||
score = FlashAttnFunc.apply(h_q, h_k, h_v, attention_mask_bias, False, None)
|
||||
else:
|
||||
if pos_bias_type == "chatglm_rotary":
|
||||
raise NotImplementedError("No FlashAttn version for ChatGLM at present!")
|
||||
h_q = h_q.view(batch_size * len_q, -1, self.dim_head) # .permute(0, 2, 1, 3)
|
||||
h_k = h_k.view(batch_size * len_k, -1, self.dim_head) # .permute(0, 2, 1, 3)
|
||||
h_v = h_v.view(batch_size * len_k, -1, self.dim_head) # .permute(0, 2, 1, 3)
|
||||
h_q, h_k = position_bias(
|
||||
h_q, h_k, -3, cu_seqlens=cu_seqlens, max_length=max_seqlen, position_ids=position_ids
|
||||
)
|
||||
print(type(h_q), type(cu_seqlens), type(max_seqlen), type(self.dropout_p))
|
||||
print("h_q: ", h_q)
|
||||
print("cu_seqlens: ", cu_seqlens)
|
||||
print("max_seqlen: ", max_seqlen)
|
||||
score = flash_attn_varlen_func(
|
||||
h_q,
|
||||
h_k,
|
||||
h_v,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
max_seqlen,
|
||||
self.dropout_p,
|
||||
causal=True,
|
||||
deterministic=True,
|
||||
)
|
||||
print(type(h_q), type(cu_seqlens), type(max_seqlen), type(self.dropout_p))
|
||||
# Rongqiao change
|
||||
#score = OpFlash.apply(
|
||||
# self, not torch.is_grad_enabled(), h_q, h_k, h_v, cu_seqlens, max_seqlen, self.dropout_p, True
|
||||
#)
|
||||
|
||||
|
||||
score = score.view(batch_size, len_q, -1)
|
||||
|
||||
score = self.attention_out(score)
|
||||
|
||||
if use_cache:
|
||||
return score, (h_k, h_v)
|
||||
else:
|
||||
return score
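# A self-contained sketch of the grouped-query attention trick in the score
# computation above: rather than repeating the kv heads, the query heads are
# folded into the kv-head axis, one batched matmul is done per kv head, and
# the result is reshaped back to per-query-head scores. Shapes are made up.
import torch

b, n_h, n_kv_h, len_q, len_k, d_h = 2, 8, 2, 5, 7, 16
groups = n_h // n_kv_h
h_q = torch.randn(b, n_h, len_q, d_h)
h_k = torch.randn(b, n_kv_h, len_k, d_h)

score_grouped = torch.matmul(
    h_q.reshape(b, n_kv_h, groups * len_q, d_h),
    h_k.transpose(-1, -2),
).view(b, n_h, len_q, len_k)

# reference: explicitly repeat each kv head for its group of query heads
score_ref = torch.matmul(h_q, h_k.repeat_interleave(groups, dim=1).transpose(-1, -2))
assert torch.allclose(score_grouped, score_ref, atol=1e-4)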
|
|
@ -1,265 +0,0 @@
|
|||
import inspect
|
||||
import math
|
||||
|
||||
import bmtrain as bmt
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def Linear(*args, **kwargs):
|
||||
tp = kwargs.pop("tp", 0)
|
||||
if tp == 0:
|
||||
return NormalLinear(*args, **kwargs)
|
||||
if tp == 1:
|
||||
return ColumnParallelLinear(*args, **kwargs)
|
||||
if tp == 2:
|
||||
return RowParallelLinear(*args, **kwargs)
|
||||
|
||||
|
||||
class OpLastLinear(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, self, record, x, weight, bias=None):
|
||||
ctx.self = self
|
||||
if not record and "r" in self._layer_dict:
|
||||
ctx.save_for_backward(x, weight, bias)
|
||||
self._layer_dict.pop("r")
|
||||
return torch.zeros((*x.shape[:-1], self.out_features), device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
ctx.save_for_backward(x, weight, bias)
|
||||
if record:
|
||||
self._layer_dict["r"] = True
|
||||
return F.linear(x, weight, bias)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
x, weight, bias = ctx.saved_tensors
|
||||
grad_x = grad_weight = grad_bias = None
|
||||
if x.requires_grad:
|
||||
grad_x = grad_output.matmul(weight)
|
||||
if weight.requires_grad:
|
||||
grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(x.reshape(-1, x.shape[-1]))
|
||||
if bias is not None and bias.requires_grad:
|
||||
grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)
|
||||
return None, None, grad_x, grad_weight, grad_bias
|
||||
|
||||
|
||||
class LastLinear(bmt.DistributedModule):
|
||||
def __init__(
|
||||
self,
|
||||
dim_in: int,
|
||||
dim_out: int,
|
||||
bias: bool = False,
|
||||
dtype: torch.dtype = torch.half,
|
||||
init_mean: float = 0.0,
|
||||
init_std: float = 1,
|
||||
scale: bool = True,
|
||||
scale_before: bool = False,
|
||||
tp: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim_in = self.in_features = dim_in
|
||||
self.dim_out = self.out_features = dim_out
|
||||
self.scale = scale
|
||||
self.scale_before = scale_before
|
||||
|
||||
if not scale:
|
||||
init_std = 1 / ((dim_in + dim_out) ** 0.5)
|
||||
|
||||
self.weight = bmt.DistributedParameter(
|
||||
torch.empty((dim_out, dim_in), dtype=dtype),
|
||||
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
|
||||
)
|
||||
self.bias = (
|
||||
bmt.DistributedParameter(
|
||||
torch.empty(dim_out, dtype=dtype),
|
||||
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
|
||||
)
|
||||
if bias
|
||||
else None
|
||||
)
|
||||
self._layer_dict = {}
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
Args:
|
||||
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer
|
||||
Returns:
|
||||
:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y.
|
||||
""" # noqa: E501
|
||||
if self.scale and self.scale_before:
|
||||
x = x / math.sqrt(self.dim_in)
|
||||
x = OpLastLinear.apply(self, not torch.is_grad_enabled(), x, self.weight, self.bias)
|
||||
if self.scale and not self.scale_before:
|
||||
x = x / math.sqrt(self.dim_in)
|
||||
return x
|
||||
|
||||
|
||||
class NormalLinear(bmt.DistributedModule):
|
||||
def __init__(
|
||||
self,
|
||||
dim_in: int,
|
||||
dim_out: int,
|
||||
bias: bool = False,
|
||||
dtype: torch.dtype = torch.half,
|
||||
init_mean: float = 0.0,
|
||||
init_std: float = 1,
|
||||
scale: bool = True,
|
||||
scale_before: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim_in = self.in_features = dim_in
|
||||
self.dim_out = self.out_features = dim_out
|
||||
self.scale = scale
|
||||
self.scale_before = scale_before
|
||||
if not scale:
|
||||
init_std = 1 / ((dim_in + dim_out) ** 0.5)
|
||||
|
||||
self.weight = bmt.DistributedParameter(
|
||||
torch.empty((dim_out, dim_in), dtype=dtype),
|
||||
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
|
||||
)
|
||||
self.bias = (
|
||||
bmt.DistributedParameter(
|
||||
torch.empty(dim_out, dtype=dtype),
|
||||
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
|
||||
)
|
||||
if bias
|
||||
else None
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
Args:
|
||||
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer
|
||||
Returns:
|
||||
:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y.
|
||||
""" # noqa: E501
|
||||
if self.scale and self.scale_before:
|
||||
x = x / math.sqrt(self.dim_in)
|
||||
if "tp_size" in inspect.signature(bmt.init_distributed).parameters:
|
||||
x = bmt.nn.OpLinear.apply(x, self.weight, self.bias)
|
||||
else:
|
||||
x = F.linear(x, self.weight, self.bias)
|
||||
if self.scale and not self.scale_before:
|
||||
x = x / math.sqrt(self.dim_in)
|
||||
return x
|
||||
|
||||
|
||||
class ColumnParallelLinear(bmt.DistributedModule):
|
||||
def __init__(
|
||||
self,
|
||||
dim_in: int,
|
||||
dim_out: int,
|
||||
bias: bool = False,
|
||||
dtype: torch.dtype = torch.half,
|
||||
init_mean: float = 0.0,
|
||||
init_std: float = 1,
|
||||
scale: bool = True,
|
||||
scale_before: bool = False,
|
||||
gather_output=False,
|
||||
gather_input=True,
|
||||
):
|
||||
super().__init__()
|
||||
assert dim_out % bmt.config["tp_size"] == 0
|
||||
if not scale:
|
||||
init_std = 1 / ((dim_in + dim_out) ** 0.5)
|
||||
dim_out = dim_out // bmt.config["tp_size"]
|
||||
self.dim_in = self.in_features = dim_in
|
||||
self.dim_out = self.out_features = dim_out
|
||||
self.scale = scale
|
||||
self.scale_before = scale_before
|
||||
self.gather_input = gather_input
|
||||
self.gather_output = gather_output
|
||||
|
||||
self.weight = bmt.DistributedParameter(
|
||||
torch.empty((dim_out, dim_in), dtype=dtype),
|
||||
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
|
||||
tp_split_dim=0,
|
||||
tp_mode=True,
|
||||
)
|
||||
self.bias = (
|
||||
bmt.DistributedParameter(
|
||||
torch.empty(dim_out, dtype=dtype),
|
||||
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
|
||||
tp_split_dim=0,
|
||||
tp_mode=True,
|
||||
)
|
||||
if bias
|
||||
else None
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
Args:
|
||||
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer
|
||||
Returns:
|
||||
:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y.
|
||||
""" # noqa: E501
|
||||
if self.scale and self.scale_before:
|
||||
x = x / math.sqrt(self.dim_in)
|
||||
x = bmt.nn.OpParallelLinear.apply(x, self.weight, self.bias, self.gather_input, self.gather_output, False, None)
|
||||
if self.scale and not self.scale_before:
|
||||
x = x / math.sqrt(self.dim_in)
|
||||
return x
|
||||
|
||||
|
||||
class RowParallelLinear(bmt.DistributedModule):
|
||||
def __init__(
|
||||
self,
|
||||
dim_in: int,
|
||||
dim_out: int,
|
||||
bias: bool = False,
|
||||
dtype: torch.dtype = torch.half,
|
||||
init_mean: float = 0.0,
|
||||
init_std: float = 1,
|
||||
scale: bool = True,
|
||||
scale_before: bool = False,
|
||||
split_input=False,
|
||||
all_reduce_output=False,
|
||||
):
|
||||
super().__init__()
|
||||
assert dim_in % bmt.config["tp_size"] == 0
|
||||
if not scale:
|
||||
init_std = 1 / ((dim_in + dim_out) ** 0.5)
|
||||
dim_in = dim_in // bmt.config["tp_size"]
|
||||
self.dim_in = self.in_features = dim_in
|
||||
self.dim_out = self.out_features = dim_out
|
||||
self.scale = scale
|
||||
self.scale_before = scale_before
|
||||
self.split_input = split_input
|
||||
self.all_reduce_output = all_reduce_output
|
||||
|
||||
self.weight = bmt.DistributedParameter(
|
||||
torch.empty((dim_out, dim_in), dtype=dtype),
|
||||
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
|
||||
tp_split_dim=1,
|
||||
tp_mode=True,
|
||||
)
|
||||
self.bias = (
|
||||
bmt.DistributedParameter(
|
||||
torch.empty(dim_out, dtype=dtype),
|
||||
init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
|
||||
tp_split_dim=-1,
|
||||
tp_mode=True,
|
||||
)
|
||||
if bias
|
||||
else None
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
Args:
|
||||
x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer
|
||||
Returns:
|
||||
:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y.
|
||||
""" # noqa: E501
|
||||
if self.scale and self.scale_before:
|
||||
x = x / math.sqrt(self.dim_in)
|
||||
x = bmt.nn.OpParallelLinear.apply(
|
||||
x, self.weight, None, self.split_input, False, self.split_input, 1 if self.all_reduce_output else 2
|
||||
)
|
||||
if self.bias is not None:
|
||||
x = x + self.bias
|
||||
if self.scale and not self.scale_before:
|
||||
x = x / math.sqrt(self.dim_in)
|
||||
return x
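# A plain-torch sketch of the two partitionings above: a column-parallel
# linear shards the output features and concatenates the partial outputs,
# while a row-parallel linear shards the input features and sums the partial
# outputs (the sum is what the all-reduce performs in the distributed case).
import torch
import torch.nn.functional as F

tp_size = 2
x = torch.randn(3, 8)
w = torch.randn(6, 8)  # full (dim_out, dim_in) weight

# column parallel: each shard holds a slice of the output features
y_col = torch.cat([F.linear(x, shard) for shard in w.chunk(tp_size, dim=0)], dim=-1)
assert torch.allclose(y_col, F.linear(x, w), atol=1e-5)

# row parallel: each shard holds a slice of the input features; partial outputs are summed
partials = [
    F.linear(xs, ws)
    for xs, ws in zip(x.chunk(tp_size, dim=-1), w.chunk(tp_size, dim=1))
]
y_row = sum(partials)
assert torch.allclose(y_row, F.linear(x, w), atol=1e-5)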
|
|
@ -1,22 +0,0 @@
|
|||
from typing import Optional
|
||||
|
||||
|
||||
class EMAValue(object):
|
||||
def __init__(self, init_value: Optional[float] = None, decay_factor: float = 0.999) -> None:
|
||||
super().__init__()
|
||||
self._value = init_value
|
||||
self._decay_factor = decay_factor
|
||||
|
||||
@property
|
||||
def value(self) -> Optional[float]:
|
||||
return self._value
|
||||
|
||||
def update(self, value: float) -> None:
|
||||
if self._value is None:
|
||||
self._value = value
|
||||
else:
|
||||
self._value = self._decay_factor * self._value + (1 - self._decay_factor) * value
|
||||
|
||||
def update_with_return(self, value: float) -> Optional[float]:
|
||||
self.update(value)
|
||||
return self._value
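# A tiny usage sketch for EMAValue above, e.g. tracking average tokens per
# sample: the first update seeds the value, later updates blend in with
# weight (1 - decay_factor).
ema = EMAValue(decay_factor=0.9)
for length in [100.0, 120.0, 80.0]:
    ema.update(length)
# 100 -> 0.9*100 + 0.1*120 = 102 -> 0.9*102 + 0.1*80 = 99.8
assert abs(ema.value - 99.8) < 1e-6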
|
|
@ -1 +0,0 @@
|
|||
from .fm9g import FM9GTokenizer
|
|
@ -1,247 +0,0 @@
|
|||
import bmtrain as bmt
|
||||
import math
|
||||
|
||||
class Cosine(bmt.lr_scheduler.WarmupLRScheduler):
|
||||
r"""
|
||||
After a warmup period during which the learning rate increases linearly between 0 and start_lr,
the decay period follows :math:`\text{lr}=\max\left(0.1\cdot\text{start_lr},\ \text{start_lr}\times\left(0.1+0.45\left(1+\cos \left( \pi \cdot \dfrac{\text{num_iter}-\text{warmup_iter}}{\text{end_iter}-\text{warmup_iter}}\right)\right)\right)\right)`
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, optimizer, start_lr, warmup_iter, end_iter, num_iter=0, lr_end_restart=0, resume_no_optimze=0
|
||||
) -> None:
|
||||
self.warmup_iter = warmup_iter
|
||||
self.end_iter = end_iter
|
||||
self.optimizer = optimizer
|
||||
self.num_iter = num_iter
|
||||
self._current_lr = None
|
||||
self._start_lr = start_lr
|
||||
self.start_lr = []
|
||||
self.lr_end_restart = lr_end_restart
|
||||
self.resume_step = num_iter
|
||||
self.resume_no_optimze = resume_no_optimze
|
||||
for group in self.optimizer.param_groups:
|
||||
self.start_lr.append(group["lr"])
|
||||
|
||||
self.step(self.num_iter)
|
||||
|
||||
def get_lr_warmup(self, num_iter, base_lr) -> float:
|
||||
return base_lr * num_iter / self.warmup_iter
|
||||
|
||||
def get_lr_decay(self, num_iter, base_lr) -> float:
|
||||
progress = (num_iter - self.warmup_iter) / max(1, (self.end_iter - self.warmup_iter))
|
||||
if progress > 1:
|
||||
if self.lr_end_restart == 0:
|
||||
progress = 1
|
||||
elif self.lr_end_restart == 1:
|
||||
progress = progress
|
||||
elif self.lr_end_restart == 2:
|
||||
progress = int(progress) * 2 + (progress - 1)
|
||||
|
||||
return max(base_lr * 0.1, base_lr * (0.1 + 0.45 * (1.0 + math.cos(progress * math.pi))))
|
||||
|
||||
def get_lr(self, base_lr):
|
||||
assert self.num_iter >= 0
|
||||
if self.resume_step + self.resume_no_optimze > self.num_iter:
|
||||
bmt.print_rank("resume no optimize")
|
||||
return 0
|
||||
|
||||
if self.num_iter < self.warmup_iter:
|
||||
return self.get_lr_warmup(self.num_iter, base_lr)
|
||||
else:
|
||||
return self.get_lr_decay(self.num_iter, base_lr)
|
||||
|
||||
@property
|
||||
def current_lr(self):
|
||||
return self._current_lr
|
||||
|
||||
def step(self, num_iter=None) -> None:
|
||||
if num_iter is None:
|
||||
num_iter = self.num_iter + 1
|
||||
self.num_iter = num_iter
|
||||
|
||||
self._current_lr = self.get_lr(self._start_lr)
|
||||
for group, base_lr in zip(self.optimizer.param_groups, self.start_lr):
|
||||
group["lr"] = self.get_lr(base_lr)
|
||||
|
||||
def state_dict(self):
|
||||
return {
|
||||
"_start_lr": self._start_lr,
|
||||
"start_lr": self.start_lr,
|
||||
"warmup_iter": self.warmup_iter,
|
||||
"end_iter": self.end_iter,
|
||||
"num_iter": self.num_iter,
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self._start_lr = state_dict["_start_lr"]
|
||||
self.start_lr = state_dict["start_lr"]
|
||||
self.warmup_iter = state_dict["warmup_iter"]
|
||||
self.end_iter = state_dict["end_iter"]
|
||||
self.num_iter = state_dict["num_iter"]
|
||||
|
||||
self.step(self.num_iter)
|
||||
|
||||
|
||||
class WarmupStableDrop(bmt.lr_scheduler.WarmupLRScheduler):
|
||||
r"""
|
||||
After a warmup period during which the learning rate increases linearly between 0 and start_lr,
|
||||
the learning rate stays constant at start_lr and then drops linearly to :math:`0.1\times\text{start_lr}` over the final ``drop_iter`` iterations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, optimizer, start_lr, warmup_iter, end_iter, drop_iter=0, num_iter=0, resume_no_optimze=0
|
||||
) -> None:
|
||||
self.warmup_iter = warmup_iter
|
||||
self.end_iter = end_iter
|
||||
self.drop_iter = drop_iter
|
||||
self.optimizer = optimizer
|
||||
self.num_iter = num_iter
|
||||
self._current_lr = None
|
||||
self._start_lr = start_lr
|
||||
self.start_lr = []
|
||||
self.resume_step = num_iter
|
||||
self.resume_no_optimze = resume_no_optimze
|
||||
for group in self.optimizer.param_groups:
|
||||
self.start_lr.append(group["lr"])
|
||||
|
||||
self.step(self.num_iter)
|
||||
|
||||
def get_lr_warmup(self, num_iter, base_lr, warmup_iter) -> float:
|
||||
return base_lr * num_iter / warmup_iter
|
||||
|
||||
def get_lr_stable(self, num_iter, base_lr):
|
||||
return base_lr
|
||||
|
||||
def get_lr_drop(self, num_iter, base_lr):
|
||||
progress = (self.end_iter - num_iter) / self.drop_iter
|
||||
return base_lr * (0.1 + max(0.9 * progress, 0))
|
||||
|
||||
def get_lr(self, base_lr):
|
||||
assert self.num_iter >= 0
|
||||
if self.resume_step + self.resume_no_optimze > self.num_iter:
|
||||
return self.get_lr_warmup(self.num_iter - self.resume_step, base_lr, self.resume_no_optimze)
|
||||
|
||||
if self.num_iter < self.warmup_iter:
|
||||
return self.get_lr_warmup(self.num_iter, base_lr, self.warmup_iter)
|
||||
|
||||
if self.num_iter > self.end_iter - self.drop_iter:
|
||||
return self.get_lr_drop(self.num_iter, base_lr)
|
||||
|
||||
return self.get_lr_stable(self.num_iter, base_lr)
|
||||
|
||||
@property
|
||||
def current_lr(self):
|
||||
return self._current_lr
|
||||
|
||||
def step(self, num_iter=None) -> None:
|
||||
if num_iter is None:
|
||||
num_iter = self.num_iter + 1
|
||||
self.num_iter = num_iter
|
||||
|
||||
self._current_lr = self.get_lr(self._start_lr)
|
||||
for group, base_lr in zip(self.optimizer.param_groups, self.start_lr):
|
||||
group["lr"] = self.get_lr(base_lr)
|
||||
|
||||
def state_dict(self):
|
||||
return {
|
||||
"_start_lr": self._start_lr,
|
||||
"start_lr": self.start_lr,
|
||||
"warmup_iter": self.warmup_iter,
|
||||
"end_iter": self.end_iter,
|
||||
"num_iter": self.num_iter,
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self._start_lr = state_dict["_start_lr"]
|
||||
self.start_lr = state_dict["start_lr"]
|
||||
self.warmup_iter = state_dict["warmup_iter"]
|
||||
self.end_iter = state_dict["end_iter"]
|
||||
self.num_iter = state_dict["num_iter"]
|
||||
|
||||
self.step(self.num_iter)
|
||||
|
||||
|
||||
|
||||
class WarmupStableExp(bmt.lr_scheduler.WarmupLRScheduler):
|
||||
r"""
|
||||
After a warmup period during which the learning rate increases linearly between 0 and start_lr,
|
||||
the learning rate stays constant at start_lr until ``drop_begin`` and then decays exponentially with half-life ``drop_iter``: :math:`\text{lr}=\text{start_lr}\times \text{drop_rate}^{(\text{num_iter}-\text{drop_begin})/\text{drop_iter}}`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, optimizer, start_lr, warmup_iter, drop_begin=-1, drop_rate=0.5, drop_iter=0, num_iter=0, resume_no_optimze=0
|
||||
) -> None:
|
||||
|
||||
self.warmup_iter = warmup_iter
|
||||
self.drop_iter = drop_iter
|
||||
self.optimizer = optimizer
|
||||
self.num_iter = num_iter
|
||||
self._current_lr = None
|
||||
self._start_lr = start_lr
|
||||
self.start_lr = []
|
||||
self.resume_step = num_iter
|
||||
self.resume_no_optimze = resume_no_optimze
|
||||
self.drop_begin = drop_begin
|
||||
self.drop_iter = drop_iter # here drop_iter is half-life
|
||||
self.drop_rate = drop_rate
|
||||
for group in self.optimizer.param_groups:
|
||||
self.start_lr.append(group["lr"])
|
||||
|
||||
self.step(self.num_iter)
|
||||
|
||||
def get_lr_warmup(self, num_iter, base_lr, warmup_iter) -> float:
|
||||
return base_lr * num_iter / warmup_iter
|
||||
|
||||
def get_lr_stable(self, num_iter, base_lr):
|
||||
return base_lr
|
||||
|
||||
def get_lr_drop(self, num_iter, base_lr):
|
||||
if self.drop_iter <= 0:  # drop disabled (default drop_iter=0); also avoids a division by zero below
|
||||
return base_lr
|
||||
progress = (num_iter - self.drop_begin) / self.drop_iter
|
||||
return base_lr * (self.drop_rate ** progress)
|
||||
|
||||
def get_lr(self, base_lr):
|
||||
assert self.num_iter >= 0
|
||||
if self.resume_step + self.resume_no_optimze > self.num_iter:
|
||||
return self.get_lr_warmup(self.num_iter - self.resume_step, base_lr, self.resume_no_optimze)
|
||||
|
||||
if self.num_iter < self.warmup_iter:
|
||||
return self.get_lr_warmup(self.num_iter, base_lr, self.warmup_iter)
|
||||
|
||||
if self.num_iter > self.drop_begin:
|
||||
return self.get_lr_drop(self.num_iter, base_lr)
|
||||
|
||||
return self.get_lr_stable(self.num_iter, base_lr)
|
||||
|
||||
@property
|
||||
def current_lr(self):
|
||||
return self._current_lr
|
||||
|
||||
def step(self, num_iter=None) -> None:
|
||||
if num_iter is None:
|
||||
num_iter = self.num_iter + 1
|
||||
self.num_iter = num_iter
|
||||
|
||||
self._current_lr = self.get_lr(self._start_lr)
|
||||
for group, base_lr in zip(self.optimizer.param_groups, self.start_lr):
|
||||
group["lr"] = self.get_lr(base_lr)
|
||||
|
||||
def state_dict(self):
|
||||
return {
|
||||
"_start_lr": self._start_lr,
|
||||
"start_lr": self.start_lr,
|
||||
"warmup_iter": self.warmup_iter,
|
||||
"drop_begin": self.drop_begin,
|
||||
"num_iter": self.num_iter,
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self._start_lr = state_dict["_start_lr"]
|
||||
self.start_lr = state_dict["start_lr"]
|
||||
self.warmup_iter = state_dict["warmup_iter"]
|
||||
self.drop_begin = state_dict["drop_begin"]
|
||||
self.num_iter = state_dict["num_iter"]
|
||||
|
||||
self.step(self.num_iter)
|
|
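All three schedulers above share the same linear warmup and differ only in what follows it: cosine decay with a 0.1 floor (Cosine), a constant phase ending in a linear drop (WarmupStableDrop), or a constant phase ending in exponential decay (WarmupStableExp). A standalone sketch of the warmup-stable-drop shape, detached from bmtrain and the optimizer plumbing:

def wsd_lr(step, base_lr, warmup_iter, end_iter, drop_iter):
    """Linear warmup -> constant -> linear drop to 0.1 * base_lr (mirrors WarmupStableDrop.get_lr)."""
    if step < warmup_iter:
        return base_lr * step / warmup_iter
    if step > end_iter - drop_iter:
        return base_lr * (0.1 + max(0.9 * (end_iter - step) / drop_iter, 0))
    return base_lr

# With warmup_iter=100, end_iter=1000, drop_iter=200:
#   step 50   -> 0.50 * base_lr   (warming up)
#   step 500  -> 1.00 * base_lr   (stable)
#   step 900  -> 0.55 * base_lr   (halfway through the drop)
#   step 1000 -> 0.10 * base_lr   (end of training)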
@ -1,5 +0,0 @@
|
|||
from .log import logger
|
||||
from .log import LogManager
|
||||
from .object import allgather_objects
|
||||
from .config import Config
|
||||
from .gradient_shrink import gradient_shrink
|
|
@ -1,159 +0,0 @@
|
|||
import os
|
||||
import random
|
||||
|
||||
import bmtrain as bmt
|
||||
import numpy as np
|
||||
|
||||
|
||||
class BitSet:
|
||||
def __init__(self, size=1024**2):
|
||||
self.size = size
|
||||
self.bitset = np.zeros(self.size, dtype=bool)
|
||||
|
||||
def _ensure_capacity(self, num):
|
||||
"""Ensure the bitset has enough capacity to store the given number."""
|
||||
if num >= self.size:
|
||||
# Grow the bitset
|
||||
new_size = max(num + 1, self.size * 2)
|
||||
new_bitset = np.zeros(new_size, dtype=bool)
|
||||
new_bitset[: self.size] = self.bitset
|
||||
self.bitset = new_bitset
|
||||
self.size = new_size
|
||||
bmt.print_rank("enlarge size to {}".format(self.size))
|
||||
|
||||
def add(self, num):
|
||||
"""Add a number to the bitset."""
|
||||
self._ensure_capacity(num)
|
||||
self.bitset[num] = True
|
||||
|
||||
def remove(self, num):
|
||||
"""Remove a number from the bitset."""
|
||||
if num < self.size:
|
||||
self.bitset[num] = False
|
||||
|
||||
def contains(self, num):
|
||||
"""Check whether the bitset contains the given number."""
|
||||
return num < self.size and self.bitset[num]
|
||||
|
||||
def __contains__(self, num):
|
||||
return self.contains(num)
|
||||
|
||||
def update(self, iterable_or_bitset):
|
||||
"""Update this bitset with elements from an iterable or another BitSet."""
|
||||
if isinstance(iterable_or_bitset, BitSet):
|
||||
# If the argument is a BitSet, update via vectorized numpy ops
|
||||
self._ensure_capacity(iterable_or_bitset.size)
|
||||
self.bitset[: iterable_or_bitset.size] |= iterable_or_bitset.bitset
|
||||
else:
|
||||
# Otherwise iterate over the elements and add them one by one
|
||||
for num in iterable_or_bitset:
|
||||
self.add(num)
|
||||
|
||||
def __sub__(self, other):
|
||||
"""Set difference operator: build a new bitset efficiently with vectorized numpy ops."""
|
||||
# Create a new bitset instance
|
||||
result = BitSet(max(self.size, other.size))
|
||||
# Vectorized numpy logical ops
|
||||
result.bitset[: self.size] = self.bitset & ~other.bitset[: self.size]
|
||||
return result
|
||||
|
||||
def __isub__(self, other):
|
||||
"""In-place set difference using vectorized numpy ops for efficient removal."""
|
||||
# Only operate on the overlapping range of the two bitsets
|
||||
min_size = min(self.size, other.size)
|
||||
# Remove elements with vectorized numpy logical ops
|
||||
self.bitset[:min_size] &= ~other.bitset[:min_size]
|
||||
return self
|
||||
|
||||
def __str__(self):
|
||||
"""String representation listing the indices of all set bits."""
|
||||
# Indices of all set bits
|
||||
true_indices = np.where(self.bitset)[0]
|
||||
# Join the indices into a comma-separated string
|
||||
indices_str = ", ".join(map(str, true_indices))
|
||||
return f"BitSet({indices_str})"
|
||||
|
||||
def __len__(self):
|
||||
"""Number of set bits."""
|
||||
return self.bitset.sum()
|
||||
|
||||
def capacity(self):
|
||||
return self.size
|
||||
|
||||
def density(self):
|
||||
return len(self) / self.size
|
||||
|
||||
def memory_usage(self):
|
||||
"""Memory used by the bitset, formatted in B/KB/MB/GB."""
|
||||
bytes_usage = self.bitset.nbytes
|
||||
if bytes_usage < 1024:
|
||||
return f"{bytes_usage} B"
|
||||
elif bytes_usage < 1024**2:
|
||||
return f"{bytes_usage / 1024:.2f} KB"
|
||||
elif bytes_usage < 1024**3:
|
||||
return f"{bytes_usage / 1024**2:.2f} MB"
|
||||
else:
|
||||
return f"{bytes_usage / 1024**3:.2f} GB"
|
||||
|
||||
def to_list(self):
|
||||
"""Return a list of the indices of all set bits."""
|
||||
return list(np.where(self.bitset)[0])
|
||||
|
||||
def save(self, filename):
|
||||
"""Save the bitset to a file."""
|
||||
|
||||
def random_hash():
|
||||
"""Return a random hash value."""
|
||||
return random.randint(0, 2**64 - 1)
|
||||
|
||||
filename_with_suffix = filename + ".{}.npy".format(random_hash())
|
||||
dirname = os.path.dirname(filename_with_suffix)
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
np.save(filename_with_suffix, self.bitset)
|
||||
return os.path.basename(filename_with_suffix)  # return only the basename, without the directory prefix, to support the transfer project
|
||||
|
||||
@classmethod
|
||||
def load(cls, filename_with_suffix):
|
||||
"""Load a bitset from a file and create a new BitSet instance."""
|
||||
bitset_array = np.load(filename_with_suffix)
|
||||
bitset = cls(bitset_array.size)
|
||||
bitset.bitset = bitset_array
|
||||
return bitset
|
||||
|
||||
|
||||
def bitset_diff(normal_set, bitset):
|
||||
"""Return the elements that are in normal_set but not in bitset."""
|
||||
ret = {elem for elem in normal_set if not bitset.contains(elem)}
|
||||
return ret
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
bitset1 = BitSet(1024)
|
||||
bitset1.update([100, 200, 300, 1023])
|
||||
|
||||
bitset2 = BitSet(1024)
|
||||
bitset2.update([100, 400, 1023])
|
||||
|
||||
result_bitset = bitset1 - bitset2
|
||||
print(100 in result_bitset) # should print False
|
||||
print(200 in result_bitset) # should print True
|
||||
print(300 in result_bitset) # should print True
|
||||
print(1023 in result_bitset) # should print False
|
||||
|
||||
bitset1 -= bitset2
|
||||
print(result_bitset) # BitSet(200, 300)
|
||||
print(bitset1) # BitSet(200, 300)
|
||||
print(bitset2) # BitSet(100, 400, 1023)
|
||||
|
||||
bitsetlarge = BitSet(1024**3)
|
||||
print(len(bitsetlarge), bitsetlarge.capacity(), bitsetlarge.density(), bitset1.density())
|
||||
print("BitSet memory usage:", bitsetlarge.memory_usage())
|
||||
|
||||
print(bitset_diff({100, 200}, bitset2))
|
||||
|
||||
bitset1.update(bitset2)
|
||||
bitsetlarge.add(52260134)
|
||||
bitset2.update(bitsetlarge)
|
||||
print(bitset1) # BitSet(100, 200, 300, 400, 1023)
|
||||
print(bitset2) # BitSet(100, 400, 1023, 52260134)
|
|
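A save/load round trip for the BitSet above; the directory is hypothetical, and note that save() appends a random hash plus ``.npy`` and returns only the basename:

bs = BitSet(1024)
bs.update([3, 7, 1000])
name = bs.save("/tmp/ckpt/used")                 # writes /tmp/ckpt/used.<hash>.npy
restored = BitSet.load("/tmp/ckpt/" + name)
assert restored.to_list() == [3, 7, 1000]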
@ -1,531 +0,0 @@
|
|||
# COPIED from exporter, for scaling project special use.
|
||||
# Author: shengdinghu@gmail.com
|
||||
|
||||
import functools
|
||||
import gc
|
||||
import glob
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
import hashlib
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List
|
||||
|
||||
import bmtrain as bmt
|
||||
from bmtrain.distributed import all_reduce, all_gather
|
||||
import torch
|
||||
|
||||
from fm9g.utils import allgather_objects
|
||||
from fm9g.utils.bitset import BitSet
|
||||
|
||||
from .log import logger
|
||||
|
||||
lock = threading.Lock()
|
||||
|
||||
|
||||
|
||||
def _save_artifacts(
|
||||
model_state, dataloader, tokenizer, opt_state, global_step, args, model_config, log_ckpt=None, final_save=False
|
||||
):
|
||||
"""
|
||||
Export model artifacts. Mainly for the purpose of asynchronous saving.
|
||||
"""
|
||||
try:
|
||||
# platform_cfg = get_platform_cfg()
|
||||
raise ValueError("No platform_cfg")
|
||||
except ValueError:
|
||||
platform_cfg = None
|
||||
export_model_dir = os.path.join(args.save, str(global_step))
|
||||
else:
|
||||
export_model_dir = (
|
||||
args.save if final_save else platform_cfg.gen_export_ckpt_dir_for_step(args.save, global_step)
|
||||
)
|
||||
os.makedirs(export_model_dir, exist_ok=True)
|
||||
base_file_name = f"{args.save_name}-{global_step}" if global_step > -1 else args.save_name
|
||||
bmt.print_rank(f"start to export ckpt, save_dir={export_model_dir}, file prefix={base_file_name}")
|
||||
|
||||
# model checkpoint
|
||||
ckpt_path = os.path.join(export_model_dir, base_file_name + ".pt")
|
||||
# The .opt files are only used to resume training within the project; they are not exported as part of a model release
|
||||
opt_path = os.path.join(
|
||||
export_model_dir,
|
||||
args.save_name + ("-%d.rank-%d.opt" % (global_step, bmt.rank())),
|
||||
)
|
||||
|
||||
if bmt.rank() == 0:
|
||||
torch.save(model_state, ckpt_path)
|
||||
bmt.print_rank(f"Save checkpoint successfully, ckpt file path: {ckpt_path}")
|
||||
torch.save(opt_state, opt_path)
|
||||
print(f"Save optimizer state successfully, opt file path: {opt_path}")
|
||||
|
||||
del model_state
|
||||
del opt_state
|
||||
|
||||
# Save training statistics
|
||||
if log_ckpt is not None:
|
||||
bmt.print_rank("save log ckpt ...")
|
||||
with open(os.path.join(export_model_dir, base_file_name + ".log_ckpt"), "w") as log_ckpt_file:
|
||||
json.dump(log_ckpt, log_ckpt_file)
|
||||
|
||||
logger.info("Start saving dataset state.")
|
||||
|
||||
dataset_ckpt_path = os.path.join(export_model_dir, "dataset_ckpt")
|
||||
os.makedirs(dataset_ckpt_path, exist_ok=True)
|
||||
if bmt.config["tp_rank"] == 0:
|
||||
p_dataset = os.path.join(dataset_ckpt_path, f"dataset_{bmt.rank()}.data")
|
||||
dataloader.save_state_dict(p_dataset)
|
||||
|
||||
if bmt.rank() == 0:
|
||||
# Store the config and vocabs alongside the model files
|
||||
model_config.save_pretrained(export_model_dir)
|
||||
try:
|
||||
tokenizer.save_pretrained(export_model_dir)
|
||||
except:
|
||||
bmt.print_rank("No save pretrained method for tokenizer")
|
||||
shutil.copy(args.tokenizer_path, export_model_dir)
|
||||
|
||||
# Called after all files have been written
|
||||
if platform_cfg is not None:
|
||||
platform_cfg.finalize_model_save(export_model_dir, base_file_name)
|
||||
else:
|
||||
bmt.print_rank("No platform_cfg, skip finalize_model_save; there may be no .success file")
|
||||
logger.info(f"Successfully save model files: {os.listdir(export_model_dir)}")
|
||||
|
||||
# Each process writes a .save_done file under export_model_dir, used to check whether every process has finished saving
|
||||
# Plain file writes are sufficient here
|
||||
os.makedirs(os.path.join(export_model_dir, "save_status"), exist_ok=True)
|
||||
with open(os.path.join(export_model_dir, f"save_status/{bmt.rank()}.save_done"), "w") as f:
|
||||
f.write("done")
|
||||
|
||||
# Wait until every process has finished saving; bmt.synchronize() cannot be used here
|
||||
if bmt.rank() == 0:
|
||||
while True:
|
||||
if len(os.listdir(os.path.join(export_model_dir, "save_status"))) == bmt.world_size():
|
||||
break
|
||||
time.sleep(1)
|
||||
bmt.print_rank(f"All saved! Rank 0 Begin to merge dataset ckpt to {dataset_ckpt_path}/dataset_.data")
|
||||
merge_dataset_ckpts(export_model_dir, args.parallel_load_datastate//2)
|
||||
else:
|
||||
bmt.print_rank(f"rank-{bmt.rank()} done, wait for rank0 to merge dataset ckpt")
|
||||
|
||||
|
||||
|
||||
def export(
|
||||
model: torch.nn.Module,
|
||||
dataloader,
|
||||
tokenizer,
|
||||
optimizer: bmt.optim.AdamOffloadOptimizer,
|
||||
global_step,
|
||||
args,
|
||||
log_ckpt=None,
|
||||
final_save=False,
|
||||
async_save=False,
|
||||
):
|
||||
"""
|
||||
Layout of a single checkpoint save:
|
||||
/{args.save}/
|
||||
├── job_{job_id}_ckpt_{global_step}/ # when a checkpoint is exported as a model version, the files under job_{job_id}_ckpt_{global_step}/ are exported together to create a model-group version
|
||||
├── config.json
|
||||
├── vocabs.txt
|
||||
├── {args.save_name}-{global_step}.rank-0.opt
|
||||
├── {args.save_name}-{global_step}.rank-n.opt
|
||||
├── {args.save_name}-{global_step}.pt
|
||||
├── {args.save_name}-{global_step}.data
|
||||
├── {args.save_name}-{global_step}.data.json
|
||||
├── {args.save_name}-{global_step}.success
|
||||
└── {args.save_name}-{global_step}.log_ckpt
|
||||
|
||||
"""
|
||||
|
||||
bmt.synchronize()
|
||||
model_state = bmt.store._save_to_rank0(model)
|
||||
opt_state = deepcopy(optimizer.state_dict())
|
||||
model_config = model.config
|
||||
|
||||
if async_save:
|
||||
# Save artifacts asynchronously
|
||||
save_proc = mp.Process(
|
||||
target=_save_artifacts,
|
||||
args=(model_state, dataloader, tokenizer, opt_state, global_step, args, model_config, log_ckpt, final_save),
|
||||
)
|
||||
save_proc.start()
|
||||
else:
|
||||
_save_artifacts(
|
||||
model_state, dataloader, tokenizer, opt_state, global_step, args, model_config, log_ckpt, final_save
|
||||
)
|
||||
|
||||
|
||||
def load_model_ckpt(args, model):
|
||||
"""args.load is a path to /{args.save}/job_{job_id}_ckpt_{global_step}/"""
|
||||
if args.load.endswith(".pt"):
|
||||
checkpoint_file = args.load
|
||||
else:
|
||||
# a directory
|
||||
load_path = args.load
|
||||
checkpoint_files = [file for file in os.listdir(load_path) if file.endswith(".pt")]
|
||||
assert len(checkpoint_files) == 1, "None or multiple .pt found in {}".format(load_path)
|
||||
checkpoint_file = os.path.join(load_path, checkpoint_files[0])
|
||||
bmt.print_rank("args.load is not None, start to load checkpoints from " + checkpoint_file)
|
||||
|
||||
bmt.load(model, checkpoint_file)
|
||||
return model
|
||||
|
||||
|
||||
def _legacy_load_optimizer_ckpt(args, optimizer):
|
||||
bmt.print_rank("Use legacy optimizer ckpt!")
|
||||
if args.load.endswith(".pt"):
|
||||
optimizer_path = os.path.dirname(os.path.dirname(args.load))
|
||||
else:
|
||||
optimizer_path = os.path.dirname(args.load)
|
||||
bmt.print_rank(os.listdir(optimizer_path))
|
||||
start = time.time()
|
||||
bmt.print_rank(
|
||||
"{}".format(
|
||||
sum(
|
||||
[
|
||||
1
|
||||
if i.find(".opt") != -1 and i.find("-{}.rank".format(args.start_step % (args.save_iters * 5))) != -1
|
||||
else 0
|
||||
for i in os.listdir(optimizer_path)
|
||||
]
|
||||
)
|
||||
)
|
||||
)
|
||||
if (
|
||||
sum(
|
||||
[
|
||||
1
|
||||
if i.find(".opt") != -1 and i.find("-{}.rank".format(args.start_step % (args.save_iters * 5))) != -1
|
||||
else 0
|
||||
for i in os.listdir(optimizer_path)
|
||||
]
|
||||
)
|
||||
== bmt.world_size()
|
||||
):
|
||||
pattern = "-{}.rank-{}.opt".format(args.start_step % (args.save_iters * 5), bmt.rank())
|
||||
bmt.print_rank("Will load opt that matches pattern: {}".format(pattern))
|
||||
for file_name in os.listdir(optimizer_path):
|
||||
if file_name.find(pattern) != -1:
|
||||
bmt.print_rank("start to load grad ckpt {}".format(file_name))
|
||||
states = torch.load(os.path.join(optimizer_path, file_name))
|
||||
optimizer.load_state_dict(states)
|
||||
logger.info("load grad in {:.2f}s".format(time.time() - start))
|
||||
return optimizer
|
||||
|
||||
|
||||
def load_optimizer_ckpt(args, optimizer):
|
||||
if args.load.endswith(".pt"):
|
||||
optimizer_path = os.path.dirname(args.load)
|
||||
else:
|
||||
# a directory
|
||||
optimizer_path = args.load
|
||||
start = time.time()
|
||||
opt_num = sum(
|
||||
[1 if re.search(r"-{}.rank-\d+.opt".format(args.start_step), i) else 0 for i in os.listdir(optimizer_path)]
|
||||
)
|
||||
|
||||
bmt.print_rank(f"Opt file num: {opt_num}")
|
||||
if opt_num == 0:
|
||||
return _legacy_load_optimizer_ckpt(args, optimizer)
|
||||
|
||||
|
||||
if opt_num == bmt.world_size():
|
||||
file_name = os.path.join(
|
||||
optimizer_path,
|
||||
args.save_name + "-{}.rank-{}.opt".format(args.start_step, bmt.rank()),
|
||||
)
|
||||
|
||||
if os.path.exists(file_name):
|
||||
print("rank {} start to load grad ckpt {}".format(bmt.rank(), file_name))
|
||||
states = torch.load(file_name)
|
||||
optimizer.load_state_dict(states)
|
||||
# optimizer_external_load_state_dict(optimizer, states, args.grad_ckpt_num)
|
||||
logger.info("load grad in {:.2f}s".format(time.time() - start))
|
||||
return optimizer
|
||||
|
||||
|
||||
def load_dataloader_ckpt(args, mixed_dataset):
|
||||
load_success = _load_distributed_dataset_state(args, mixed_dataset)
|
||||
if not load_success:
|
||||
logger.info("load from distributed data state dict fail, try to load from single data state_dict")
|
||||
_load_dataloader_ckpt(args, mixed_dataset)
|
||||
|
||||
|
||||
def _load_dataloader_ckpt(args, mixed_dataset):
|
||||
"""args.load is a path to /{args.save}/job_{job_id}_ckpt_{global_step}/"""
|
||||
if args.load.endswith(".pt"):
|
||||
load_path = os.path.dirname(args.load)
|
||||
else:
|
||||
load_path = args.load
|
||||
dataset_states_path = [file for file in os.listdir(load_path) if file.endswith(".data")]
|
||||
assert len(dataset_states_path) == 1, "None or multiple .data found in {}, file list: {}".format(
|
||||
load_path, dataset_states_path
|
||||
)
|
||||
dataset_states_path = dataset_states_path[0]
|
||||
dataset_states_path = os.path.join(load_path, dataset_states_path)
|
||||
bmt.print_rank("args.load is not None, start to load data ckpt from " + dataset_states_path)
|
||||
dataset_states = torch.load(dataset_states_path)
|
||||
missing = mixed_dataset.load_state_dict(dataset_states)
|
||||
if len(missing) > 0:
|
||||
bmt.print_rank("missing keys when loading dataset states: ", missing)
|
||||
|
||||
|
||||
def load_trace_ckpt(args, dataloader):
|
||||
"""args.load is a path to /{args.save}/job_{job_id}_ckpt_{global_step}/"""
|
||||
return dataloader
|
||||
|
||||
|
||||
def load_log_ckpt(args):
|
||||
if args.load.endswith(".pt"):
|
||||
load_path = os.path.dirname(args.load)
|
||||
else:
|
||||
load_path = args.load
|
||||
log_ckpt_paths = [file for file in os.listdir(load_path) if file.endswith(".log_ckpt")]
|
||||
assert len(log_ckpt_paths) <= 1, "Multiple .log_ckpt found in {}".format(load_path)
|
||||
if len(log_ckpt_paths) == 0:
|
||||
bmt.print_rank("No log ckpt is found in {}".format(load_path))
|
||||
return {}
|
||||
log_ckpt_path = os.path.join(load_path, log_ckpt_paths[0])
|
||||
bmt.print_rank("args.load is not None, start to load log ckpt from " + log_ckpt_path)
|
||||
with open(log_ckpt_path, "r") as log_ckpt_file:
|
||||
log_ckpt = json.load(log_ckpt_file)
|
||||
return log_ckpt
|
||||
|
||||
|
||||
def _load_distributed_dataset_state(args, mixed_dataset):
|
||||
rank = bmt.rank()
|
||||
logger.info(f"rank-{rank} -> [start]loading dataset states")
|
||||
if args.load.endswith(".pt"):
|
||||
load_dir = os.path.dirname(args.load)
|
||||
else:
|
||||
load_dir = args.load
|
||||
|
||||
p_datasets = sorted(glob.glob(os.path.join(load_dir, "dataset_ckpt/dataset_*.data")))
|
||||
if len(p_datasets) == 0:  # backward compatibility
|
||||
bmt.print_rank("load_from_orginal_dataset_ckpt_folder")
|
||||
p_datasets = sorted(glob.glob(os.path.join(load_dir, "dataset_*.data")))
|
||||
|
||||
all_state_dict = dict()
|
||||
|
||||
def load_and_aggregate(p_dataset):
|
||||
"""Map func for loading and aggregating dataset states to all_state_dict"""
|
||||
|
||||
def new_key_init(all_state_dict, key, state_dict):
|
||||
all_state_dict[key] = {}
|
||||
for second_key in state_dict[key]:
|
||||
if second_key == "used":
|
||||
all_state_dict[key]["used"] = BitSet(1024)
|
||||
all_state_dict[key]["used"].update(state_dict[key]["used"])
|
||||
else:
|
||||
all_state_dict[key][second_key] = state_dict[key][second_key]
|
||||
|
||||
def load_state_dict_with_npy(p_dataset):
|
||||
state_dict = torch.load(p_dataset)
|
||||
for key in state_dict:
|
||||
if isinstance(state_dict[key]["used"], str):
|
||||
if os.path.basename(state_dict[key]["used"]) == state_dict[key]["used"]: # if only a file name is stored
|
||||
state_dict[key]["used"] = BitSet.load(os.path.join(load_dir, "dataset_ckpt", os.path.basename(state_dict[key]["used"])))
|
||||
else: # if a full path is stored (backward compatibility)
|
||||
state_dict[key]["used"] = BitSet.load(state_dict[key]["used"])
|
||||
return state_dict
|
||||
|
||||
print(f"Loading {p_dataset}...")
|
||||
state_dict = load_state_dict_with_npy(p_dataset)
|
||||
dataset_locks = {}
|
||||
for key in state_dict.keys():
|
||||
if key not in dataset_locks:
|
||||
with lock:
|
||||
if key not in dataset_locks:
|
||||
dataset_locks[key] = threading.Lock()
|
||||
|
||||
with dataset_locks[key]:
|
||||
if key in all_state_dict:
|
||||
if all_state_dict[key]["exhausted"]:
|
||||
continue
|
||||
elif state_dict[key]["exhausted"]:
|
||||
all_state_dict[key]["exhausted"] = True
|
||||
all_state_dict[key]["used"] = BitSet(1024)
|
||||
else:
|
||||
all_state_dict[key]["used"].update(state_dict[key]["used"])
|
||||
else:
|
||||
new_key_init(all_state_dict, key, state_dict)
|
||||
bmt.print_rank(f"[done]loaded dataset states: {p_dataset}")
|
||||
|
||||
if p_datasets:
|
||||
rank = bmt.rank()
|
||||
lst_time = time.time()
|
||||
if rank == 0:
|
||||
with ThreadPoolExecutor(max_workers=args.parallel_load_datastate) as executor:
|
||||
executor.map(load_and_aggregate, p_datasets)
|
||||
|
||||
logger.info(
|
||||
f"rank-{rank} -> load dataset from {len(p_datasets)} .data files. Time: {time.time() - lst_time:.2f}s"
|
||||
)
|
||||
# Broadcast the tensor to other process
|
||||
lst_time = time.time()
|
||||
all_state_dict = bmt.store.broadcast_object(all_state_dict, comm=bmt.config["comm"], src=0)
|
||||
logger.info(f"rank-{rank} -> broadcast dataset from rank-0 to other ranks. Time: {time.time() - lst_time:.2f}s")
|
||||
lst_time = time.time()
|
||||
|
||||
missing_keys = mixed_dataset.load_state_dict(all_state_dict)
|
||||
logger.info(f"rank-{rank} -> load mixed dataset state dict. Time: {time.time() - lst_time:.2f}s")
|
||||
if missing_keys:
|
||||
logger.info(
|
||||
f"rank-{rank} -> load dataset from {len(p_datasets)} .data files with {os.path.join(load_dir, 'dataset_ckpt', 'dataset*.pt')} : {p_datasets}, missing tasks: {missing_keys}"
|
||||
)
|
||||
else:
|
||||
state_info = {
|
||||
k: {
|
||||
"ave_tokens": s.get("ave_tokens", -1),
|
||||
"set_info": "mem:{}|density:{:.4f}|len:{}".format(
|
||||
s["used"].memory_usage(), s["used"].density(), len(s["used"])
|
||||
),
|
||||
}
|
||||
for k, s in all_state_dict.items()
|
||||
}
|
||||
logger.info(
|
||||
f"rank-{rank} -> load dataset from {len(p_datasets)} files with {os.path.join(load_dir, 'dataset*.pt')}. Info: {state_info}"
|
||||
)
|
||||
return True
|
||||
else:
|
||||
logger.info(f"No dataset*.data found. p_datasets: {p_datasets}")
|
||||
return False
|
||||
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def flatten_stats(stats, parent_key="", separator="/"):
|
||||
items = []
|
||||
for key, value in stats.items():
|
||||
new_key = f"{parent_key}{separator}{key}" if parent_key else key
|
||||
if isinstance(value, dict):
|
||||
items.extend(flatten_stats(value, new_key, separator).items())
|
||||
elif isinstance(value, list):
|
||||
items.append((new_key, json.dumps(value)))
|
||||
else:
|
||||
items.append((new_key, value))
|
||||
return OrderedDict(items)
|
||||
|
||||
|
||||
def save_every_step_stats(stats, path):
|
||||
flattened_stats = flatten_stats(stats)
|
||||
|
||||
os.makedirs(os.path.join(path, "train_stats/"), exist_ok=True)
|
||||
|
||||
# Function to get the current file ID and size
|
||||
def get_current_file_id_and_size(path):
|
||||
id = 0
|
||||
while True:
|
||||
file_path = os.path.join(path, f"train_stats/{id}.jsonl")
|
||||
if not os.path.exists(file_path):
|
||||
return id, 0
|
||||
else:
|
||||
size = os.path.getsize(file_path)
|
||||
if size > 19 * 1024 * 1024: # Size in bytes (20 MB)
|
||||
id += 1
|
||||
else:
|
||||
return id, size
|
||||
|
||||
# Get the current file id and its size
|
||||
current_id, file_size = get_current_file_id_and_size(path)
|
||||
|
||||
# Generate the file path
|
||||
file_path = os.path.join(path, f"train_stats/{current_id}.jsonl")
|
||||
|
||||
# Write the flattened stats to the file
|
||||
with open(file_path, "a") as json_file:
|
||||
json_file.write(json.dumps(flattened_stats) + "\n")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def merge_dataset_ckpts(load_dir, parallel_load_datastate):
|
||||
p_datasets = sorted(glob.glob(os.path.join(load_dir, "dataset_ckpt/dataset_*.data")))
|
||||
p_datasets = [x for x in p_datasets if "dataset_ckpt/dataset_.data" not in x]
|
||||
|
||||
bmt.print_rank(f"Files before merge (total num {len(p_datasets)}): {p_datasets}")
|
||||
|
||||
all_state_dict = dict()
|
||||
|
||||
def load_and_aggregate(p_dataset):
|
||||
"""Map func for loading and aggregating dataset states to all_state_dict"""
|
||||
|
||||
def new_key_init(all_state_dict, key, state_dict):
|
||||
all_state_dict[key] = {}
|
||||
for second_key in state_dict[key]:
|
||||
if second_key == "used":
|
||||
all_state_dict[key]["used"] = BitSet(1024)
|
||||
all_state_dict[key]["used"].update(state_dict[key]["used"])
|
||||
else:
|
||||
all_state_dict[key][second_key] = state_dict[key][second_key]
|
||||
|
||||
def load_state_dict_with_npy(p_dataset):
|
||||
state_dict = torch.load(p_dataset)
|
||||
for key in state_dict:
|
||||
if isinstance(state_dict[key]["used"], str):
|
||||
state_dict[key]["used"] = BitSet.load(os.path.join(load_dir, "dataset_ckpt", os.path.basename(state_dict[key]["used"])))
|
||||
return state_dict
|
||||
|
||||
print(f"Loading {p_dataset}...")
|
||||
state_dict = load_state_dict_with_npy(p_dataset)
|
||||
dataset_locks = {}
|
||||
|
||||
for key in state_dict.keys():
|
||||
if key not in dataset_locks:
|
||||
with lock:
|
||||
if key not in dataset_locks:
|
||||
dataset_locks[key] = threading.Lock()
|
||||
with dataset_locks[key]:
|
||||
if key in all_state_dict:
|
||||
if all_state_dict[key]["exhausted"]:
|
||||
continue
|
||||
elif state_dict[key]["exhausted"]:
|
||||
all_state_dict[key]["exhausted"] = True
|
||||
all_state_dict[key]["used"] = BitSet(1024)
|
||||
else:
|
||||
all_state_dict[key]["used"].update(state_dict[key]["used"])
|
||||
else:
|
||||
new_key_init(all_state_dict, key, state_dict)
|
||||
del state_dict
|
||||
print(f"Merged {p_dataset}...")
|
||||
|
||||
if p_datasets:
|
||||
# with ThreadPoolExecutor(max_workers=args.parallel_load_datastate) as executor:
|
||||
with ThreadPoolExecutor(max_workers=parallel_load_datastate) as executor: # smaller than normal load to avoid OOM
|
||||
|
||||
executor.map(load_and_aggregate, p_datasets)
|
||||
# load_and_aggregate(p_datasets[0])
|
||||
|
||||
# Broadcast the tensor to other process
|
||||
save_path = os.path.join(load_dir, "dataset_ckpt", "dataset_.data")
|
||||
|
||||
# save_path = os.path.join("dataset_.data")
|
||||
for key in all_state_dict:
|
||||
npy_path = all_state_dict[key]["used"].save(save_path)
|
||||
all_state_dict[key]["used"] = npy_path
|
||||
bmt.print_rank(f"All state_dict after merge {all_state_dict}")
|
||||
torch.save(all_state_dict, save_path)
|
||||
|
||||
# Find all files that match the pattern
|
||||
files_to_remove = glob.glob(os.path.join(load_dir, "dataset_ckpt", "dataset_*.data*"))
|
||||
|
||||
# Remove the files
|
||||
for file in files_to_remove:
|
||||
if "dataset_.data" not in file:
|
||||
os.remove(file)
|
||||
|
||||
files_after_merge = os.listdir(os.path.join(load_dir, "dataset_ckpt"))
|
||||
bmt.print_rank(f"Files after merge: {files_after_merge}")
|
|
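As a small illustration of the flatten_stats helper above: nested dicts become slash-separated keys and lists are JSON-encoded, so a stats record can be written as a single flat JSON line:

stats = {"loss": 2.3, "task_loss": {"code": 1.8, "math": 2.9}, "lr": [0.0001]}
print(flatten_stats(stats))
# OrderedDict([('loss', 2.3), ('task_loss/code', 1.8), ('task_loss/math', 2.9), ('lr', '[0.0001]')])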
@ -1,16 +0,0 @@
|
|||
import torch
|
||||
|
||||
|
||||
class OpGradientShrink(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x: torch.Tensor, alpha: float):
|
||||
ctx.alpha = alpha
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output * ctx.alpha, None
|
||||
|
||||
|
||||
def gradient_shrink(x: torch.Tensor, alpha: float = 0.1):
|
||||
return OpGradientShrink.apply(x, alpha)
|
|
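The autograd function above is an identity in the forward pass and scales the gradient by alpha in the backward pass; a quick sanity check:

import torch

x = torch.ones(3, requires_grad=True)
gradient_shrink(x, alpha=0.1).sum().backward()
print(x.grad)  # tensor([0.1000, 0.1000, 0.1000]) -- values unchanged, gradients scaled by 0.1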
@ -1,117 +0,0 @@
|
|||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time as time_
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
|
||||
def _get_logger():
|
||||
log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(log_level)
|
||||
log.propagate = False
|
||||
|
||||
node_name = os.getenv("NODE_NAME", "jeeves-hpc-gpu00")
|
||||
|
||||
fmt = f"[%(levelname)s][%(asctime)s][{node_name}][%(filename)s:%(lineno)d:%(process)d] - %(message)s"
|
||||
formatter = logging.Formatter(fmt, datefmt="%Y-%m-%d %H:%M:%S")
|
||||
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setLevel(log_level)
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
log.addHandler(handler)
|
||||
|
||||
return log
|
||||
|
||||
|
||||
# Logger handle
|
||||
logger = _get_logger()
|
||||
|
||||
|
||||
class LogManager:
|
||||
def __init__(self, path: str):
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
self.path = path
|
||||
|
||||
now = self.get_log_time()
|
||||
latest_log: Union[Dict[str, Any], None] = None
|
||||
for _ in range(15):
|
||||
log_name = self.get_log_name(now)
|
||||
if os.path.exists(log_name):
|
||||
with open(log_name, "r") as flog:
|
||||
lines = flog.readlines()
|
||||
if lines:
|
||||
latest_log = json.loads(lines[-1])
|
||||
break
|
||||
now -= datetime.timedelta(days=1)
|
||||
|
||||
if latest_log is None:
|
||||
self.global_token_pass = 0
|
||||
else:
|
||||
self.global_token_pass = latest_log["token pass"]
|
||||
|
||||
def get_log_time(self) -> datetime.datetime:
|
||||
return datetime.datetime.utcnow() + datetime.timedelta(hours=16)
|
||||
|
||||
def get_log_name(self, now: Optional[datetime.datetime] = None):
|
||||
if now is None:
|
||||
now = self.get_log_time()
|
||||
return os.path.join(self.path, "log.%s.txt" % now.strftime("%Y%m%d"))
|
||||
|
||||
def write(
|
||||
self,
|
||||
time: float,
|
||||
iteration: int,
|
||||
loss: float,
|
||||
lr: float,
|
||||
lr_scale: float,
|
||||
time_usage: Dict[str, float],
|
||||
mem_usage: Dict[str, Tuple[float, float]],
|
||||
avg_time: float,
|
||||
token_max: float,
|
||||
token_pass: float,
|
||||
throughout: float,
|
||||
grad_norm: float,
|
||||
mask_max: float,
|
||||
num_gpus: int,
|
||||
task_loss: Dict[str, float],
|
||||
model_inspect: Optional[Any] = None,
|
||||
):
|
||||
with open(self.get_log_name(), "a") as fp:
|
||||
while True:
|
||||
try:
|
||||
ret = {
|
||||
"time": time,
|
||||
"iter": iteration,
|
||||
"loss": loss,
|
||||
"lr": lr,
|
||||
"lr scale": int(lr_scale),
|
||||
"time usage": time_usage,
|
||||
"mem usage": mem_usage,
|
||||
"avg time (s)": avg_time,
|
||||
"token/max": token_max,
|
||||
"token pass": token_pass + self.global_token_pass,
|
||||
"throughout (token/s)": throughout,
|
||||
"grad_norm": grad_norm,
|
||||
"mask/max": mask_max,
|
||||
"num_gpus": num_gpus,
|
||||
"task_loss": task_loss,
|
||||
}
|
||||
if model_inspect is not None:
|
||||
ret["model_inspect"] = model_inspect
|
||||
print(ret)
|
||||
fp.write(json.dumps(ret) + "\n")
|
||||
break
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("Error: writing info list!")
|
||||
time_.sleep(10)
|
|
@ -1,29 +0,0 @@
|
|||
import pickle
|
||||
|
||||
import bmtrain as bmt
|
||||
import torch
|
||||
|
||||
|
||||
def allgather_objects(obj):
|
||||
if bmt.world_size() == 1:
|
||||
return [obj]
|
||||
|
||||
with torch.no_grad():
|
||||
data_bytes: bytes = pickle.dumps(obj)
|
||||
data_length: int = len(data_bytes)
|
||||
|
||||
gpu_data_length = torch.tensor([data_length], device="cuda", dtype=torch.long)
|
||||
gathered_length = bmt.distributed.all_gather(gpu_data_length).view(-1).cpu()
|
||||
max_data_length = gathered_length.max().item()
|
||||
|
||||
gpu_data_bytes = torch.zeros(max_data_length, dtype=torch.uint8, device="cuda")
|
||||
byte_storage = torch.ByteStorage.from_buffer(data_bytes)
|
||||
gpu_data_bytes[:data_length] = torch.ByteTensor(byte_storage)
|
||||
|
||||
gathered_data = bmt.distributed.all_gather(gpu_data_bytes).cpu()
|
||||
|
||||
ret = []
|
||||
for i in range(gathered_data.size(0)):
|
||||
data_bytes = gathered_data[i, : gathered_length[i].item()].numpy().tobytes()
|
||||
ret.append(pickle.loads(data_bytes))
|
||||
return ret
|
|
@ -1,72 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
|
||||
#
|
||||
# @author: hsd9026 <shengdinghu@gmail.com>
|
||||
# @date: 2023/07/07
|
||||
#
|
||||
|
||||
|
||||
import copy
|
||||
|
||||
from .log import logger
|
||||
|
||||
|
||||
def num_parameters(model):
|
||||
"""Return the number of parameters of a model"""
|
||||
total_params = 0
|
||||
for param_name, param in model.state_dict().items():
|
||||
# print(param_name, param.numel())
|
||||
total_params += param.numel()
|
||||
return total_params
|
||||
|
||||
|
||||
def num_non_embedding_parameters(model):
|
||||
"""Return the number of parameters of a model"""
|
||||
total_params = 0
|
||||
for param_name, param in model.state_dict().items():
|
||||
if ("embed" in param_name) or ("lm_head" in param_name):
|
||||
continue
|
||||
# print(param_name, param.numel())
|
||||
total_params += param.numel()
|
||||
return total_params
|
||||
|
||||
|
||||
def estimate_parameters(config):
|
||||
"""Estimate the number of parameters of a model given its config, should be equal to `num_parameters(model)`"""
|
||||
# embedding parameters
|
||||
embedding_params = config.vocab_size * config.dim_model
|
||||
self_attn_params = 4 * config.dim_model * config.dim_head * config.num_heads * config.num_layers
|
||||
ff_params = 3 * config.dim_model * config.dim_ff * config.num_layers
|
||||
layernorm_parameters = 2 * config.dim_model * config.num_layers
|
||||
output_layernorm_parameters = config.dim_model
|
||||
positional_bias = 32
|
||||
total_params = (
|
||||
embedding_params
|
||||
+ self_attn_params
|
||||
+ ff_params
|
||||
+ layernorm_parameters
|
||||
+ output_layernorm_parameters
|
||||
+ positional_bias
|
||||
)
|
||||
|
||||
params_without_embeddings = total_params - embedding_params
|
||||
|
||||
return total_params, params_without_embeddings
|
||||
|
||||
|
||||
def get_flops_per_token(config):
|
||||
"""An estimated version of pfdays per token, i.e., the 6N in equation: Computation = 6 N * D."""
|
||||
|
||||
_, N = estimate_parameters(config)
|
||||
|
||||
logger.info(">>>>>> pfdays_per_token >>>> {:,.0f}".format(N))
|
||||
|
||||
# evaluating a forward pass
|
||||
C_forward = 6 * N # + 2 * n_layer * n_ctx * d_model
|
||||
# unit = 10**15 * 3600 * 24
|
||||
|
||||
# C_forward = C_forward / unit
|
||||
|
||||
return C_forward
|
|
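Plugging in the 4096-dim config that appears later in this diff (dim_model 4096, dim_ff 11008, dim_head 128, num_heads 32, num_layers 32; the vocab size below is read off the tokenizer diff and is an assumption), the estimate works out to roughly 6.5 B non-embedding parameters and about 3.9e10 FLOPs per token:

# Hypothetical config mirroring estimate_parameters(); vocab_size is an assumption.
cfg = dict(vocab_size=119808, dim_model=4096, dim_ff=11008, dim_head=128, num_heads=32, num_layers=32)

emb  = cfg["vocab_size"] * cfg["dim_model"]                                            # ~0.49e9
attn = 4 * cfg["dim_model"] * cfg["dim_head"] * cfg["num_heads"] * cfg["num_layers"]   # ~2.15e9
ffn  = 3 * cfg["dim_model"] * cfg["dim_ff"] * cfg["num_layers"]                        # ~4.33e9
ln   = (2 * cfg["num_layers"] + 1) * cfg["dim_model"]                                  # negligible

non_emb = attn + ffn + ln
print(f"{(non_emb + emb):.3e} total, {non_emb:.3e} non-embedding, ~{6 * non_emb:.2e} FLOPs/token")
# 6.967e+09 total, 6.476e+09 non-embedding, ~3.89e+10 FLOPs/token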
@ -1,34 +0,0 @@
|
|||
from functools import partial
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
|
||||
def van_der_corput(n, base=2):
|
||||
"""Generate the n-th value in the Van der Corput sequence."""
|
||||
vdc, denom = 0, 1
|
||||
while n:
|
||||
denom *= base
|
||||
n, remainder = divmod(n, base)
|
||||
vdc += remainder / denom
|
||||
return vdc
|
||||
|
||||
|
||||
def van_der_corput_sampling_gen(vdc_values):
|
||||
"""Generator function for sampling indices based on weights using the Van der Corput sequence."""
|
||||
|
||||
def gen(weights, vdc_value):
|
||||
cdf = np.cumsum(weights)
|
||||
sample = np.searchsorted(cdf, vdc_value)
|
||||
return sample
|
||||
|
||||
sample_index = 0
|
||||
# Pre-generate Van der Corput sequence
|
||||
max_samples = 100000 # or any number that you find suitable
|
||||
|
||||
while True:
|
||||
# Generate the next value in the Van der Corput sequence
|
||||
vdc_value = vdc_values[sample_index % max_samples]
|
||||
# Generate a sample index based on the Van der Corput value and the CDF
|
||||
yield partial(gen, vdc_value=vdc_value)
|
||||
sample_index += 1
|
|
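The first Van der Corput values in base 2 are 0.5, 0.25, 0.75, 0.125, ...; the sampling generator inverts the CDF of the weights at each of these values, which spreads samples across datasets far more evenly than i.i.d. draws. A small sketch:

vdc_values = [van_der_corput(n) for n in range(1, 9)]
print(vdc_values)   # [0.5, 0.25, 0.75, 0.125, 0.625, 0.375, 0.875, 0.0625]

weights = [0.7, 0.3]                                     # 70% dataset 0, 30% dataset 1
sampler = van_der_corput_sampling_gen(vdc_values)
print([int(next(sampler)(weights)) for _ in range(8)])   # [0, 0, 1, 0, 0, 0, 1, 0]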
@ -0,0 +1,3 @@
|
|||
cd FM9G-V
|
||||
pip install -r requirements.txt
|
||||
python chat_model.py
|
|
@ -0,0 +1,21 @@
|
|||
torch==2.0.1
|
||||
torchvision==0.15.2
|
||||
transformers==4.31.0
|
||||
tokenizers>=0.12.1,<0.14
|
||||
sentencepiece==0.1.99
|
||||
shortuuid
|
||||
peft==0.4.0
|
||||
bitsandbytes==0.41.0
|
||||
pydantic<2,>=1
|
||||
markdown2[all]
|
||||
numpy
|
||||
scikit-learn==1.2.2
|
||||
gradio==3.35.2
|
||||
gradio_client==0.2.9
|
||||
requests
|
||||
httpx==0.24.0
|
||||
uvicorn
|
||||
fastapi
|
||||
einops==0.6.1
|
||||
einops-exts==0.0.4
|
||||
timm==0.9.8
|
Binary file not shown.
After Width: | Height: | Size: 128 KiB |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,7 @@
|
|||
[
|
||||
{ "data_source_name": "laion_coco", "data_source_weight": 200 },
|
||||
{ "data_source_name": "cc12m", "data_source_weight": 15 },
|
||||
{ "data_source_name": "cc3m", "data_source_weight": 30 },
|
||||
{ "data_source_name": "coco", "data_source_weight": 5 },
|
||||
{ "data_source_name": "vg", "data_source_weight": 5 }
|
||||
]
|
|
@ -0,0 +1,3 @@
|
|||
[
|
||||
{ "data_source_name": "pretrain_eval_eval", "data_source_weight": 1 }
|
||||
]
|
|
@ -0,0 +1,3 @@
|
|||
[
|
||||
{ "data_source_name": "pretrain_eval_train", "data_source_weight": 1 }
|
||||
]
|
|
@ -0,0 +1,33 @@
|
|||
{
|
||||
"train_micro_batch_size_per_gpu": 16,
|
||||
"gradient_accumulation_steps": 16,
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": 5e-5,
|
||||
"betas": [
|
||||
0.9,
|
||||
0.98
|
||||
],
|
||||
"weight_decay": 0.01
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": 1e-6,
|
||||
"warmup_max_lr": 1e-5,
|
||||
"warmup_num_steps": 500
|
||||
}
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": true,
|
||||
"initial_scale_power": 10,
|
||||
"auto_cast": true
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 2
|
||||
},
|
||||
"steps_per_print": 50,
|
||||
"gradient_clipping": 1.0
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
{
|
||||
"train_micro_batch_size_per_gpu": 4,
|
||||
"gradient_accumulation_steps": 16,
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": 1e-5,
|
||||
"betas": [
|
||||
0.9,
|
||||
0.98
|
||||
],
|
||||
"weight_decay": 0.01
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": 1e-6,
|
||||
"warmup_max_lr": 1e-5,
|
||||
"warmup_num_steps": 500
|
||||
}
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": true,
|
||||
"initial_scale_power": 10,
|
||||
"auto_cast": true
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 2
|
||||
},
|
||||
"steps_per_print": 50,
|
||||
"gradient_clipping": 1.0
|
||||
}
|
|
@ -3,25 +3,14 @@
|
|||
"dropout_p": 0.0,
|
||||
"eps": 1e-05,
|
||||
"half": true,
|
||||
"half_type": "bf16",
|
||||
"use_flash_attn": true,
|
||||
"flash_attn_mask_shape": "2d",
|
||||
"dim_model": 4096,
|
||||
"dim_ff": 14336,
|
||||
"dim_ff": 11008,
|
||||
"dim_head": 128,
|
||||
"num_heads": 32,
|
||||
"num_kv_heads": 32,
|
||||
"num_layers": 32,
|
||||
"activate_fn": "silu",
|
||||
"init_std": 0.10,
|
||||
"scale": false,
|
||||
"scale_emb": 12,
|
||||
"scale_depth": -1,
|
||||
"model_type": "fm9g",
|
||||
"architectures": [
|
||||
"FM9GForCausalLM"
|
||||
],
|
||||
"qk_norm": false,
|
||||
"tie_lm_head": false,
|
||||
"ffn_gated": true
|
||||
"scale": false
|
||||
}
|
|
@ -119687,10 +119687,10 @@
|
|||
"𠳐"
|
||||
"𥻗"
|
||||
"𬉼"
|
||||
"<|im_start|>"
|
||||
"<|im_end|>"
|
||||
"<pad_2>"
|
||||
"<pad_3>"
|
||||
"<pad_4>"
|
||||
"<pad_5>"
|
||||
"<pad_6>"
|
||||
"<image>"
|
||||
"</image>"
|
||||
"<ref>"
|
||||
"</ref>"
|
||||
"<box>"
|
||||
"</box>"
|
||||
"<quad>"
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,231 @@
|
|||
import io
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
|
||||
import numpy
|
||||
import base64
|
||||
|
||||
import os.path as op
|
||||
import torch.utils.data as torch_data
|
||||
|
||||
from PIL import Image
|
||||
from typing import List, Iterator
|
||||
from muffin.data.tsv_file import TSVFile
|
||||
from muffin.data.data_processors import register_data_processor
|
||||
from vis_fm9g.dataset.itembuilder import ItemBuilder
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
class MultimodalQADataset(torch_data.Dataset):
|
||||
def __init__(self, qa_file, question_process):
|
||||
'''
|
||||
qa_file: jsonl file that each line is a dict like {
|
||||
'image': b64img,
|
||||
'question': question_text
|
||||
}
|
||||
'''
|
||||
super().__init__()
|
||||
|
||||
self.qa_file = qa_file
|
||||
self.qa_data = [json.loads(line) for line in open(self.qa_file)]
|
||||
if isinstance(self.qa_data[0], list):
|
||||
self.qa_data = self.qa_data[0] # unwrap one-line json question file
|
||||
|
||||
self.question_process = question_process
|
||||
|
||||
def __getitem__(self, index):
|
||||
item = self.qa_data[index]
|
||||
|
||||
img_b64 = item['image']
|
||||
image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert('RGB')
|
||||
|
||||
raw_question = item['question']
|
||||
question_text = self.question_process(raw_question)
|
||||
return {
|
||||
'image': image,
|
||||
'raw_question': raw_question,
|
||||
'question': question_text
|
||||
}
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.qa_data)
|
||||
|
||||
|
||||
class SingleDataSourceDataset(torch_data.Dataset):
|
||||
def __init__(self, ds_name, item_builder: ItemBuilder, data_dir, tsv_filenames: List[str], intent='sft') -> None:
|
||||
super().__init__()
|
||||
|
||||
self.data_dir = data_dir
|
||||
self.filenames = tsv_filenames
|
||||
self.ds_name = ds_name
|
||||
|
||||
self.sizes = []
|
||||
for filename in self.filenames:
|
||||
try:
|
||||
size = int(filename[:-4].split('-')[-1])
|
||||
except:
|
||||
raise ValueError(f'TSV Data File {filename} is not valid, the last component separated by `-` must be the number of samples in this file')
|
||||
self.sizes.append(size)
|
||||
|
||||
self.file_border_index = []
|
||||
self.prepare_border_index()
|
||||
|
||||
self.item_builder = item_builder
|
||||
self.files = self.filenames[:]
|
||||
self.intent = intent
|
||||
|
||||
|
||||
def prepare_border_index(self):
|
||||
self.file_border_index = [0]
|
||||
|
||||
temp_sum = 0
|
||||
for size in self.sizes:
|
||||
temp_sum += size
|
||||
self.file_border_index.append(temp_sum)
|
||||
|
||||
|
||||
def get_file_idx_and_row_idx(self, index):
|
||||
found = False
|
||||
file_idx = -1
|
||||
|
||||
for border_idx, border in enumerate(self.file_border_index):
|
||||
if index < border:
|
||||
file_idx = border_idx - 1
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
raise ValueError(f'Index {index} out of range for dataset of length {self.file_border_index[-1]}')
|
||||
|
||||
offset = self.file_border_index[file_idx]
|
||||
row_idx = index - offset
|
||||
return file_idx, row_idx
|
||||
|
||||
def __len__(self):
|
||||
return self.file_border_index[-1]
|
||||
|
||||
def __getitem__(self, index):
|
||||
file_idx, row_idx = self.get_file_idx_and_row_idx(index)
|
||||
try:
|
||||
sample = self.fetch_sample(file_idx, row_idx)
|
||||
item = self.item_builder.build_item(sample)
|
||||
except Exception as e:
|
||||
logger.warning(f"data fetch error: {e}")
|
||||
return self.__getitem__(random.randint(0, len(self) - 1))
|
||||
return item
|
||||
|
||||
def fetch_sample(self, file_idx, row_idx):
|
||||
file = self.files[file_idx]
|
||||
if isinstance(file, str):
|
||||
self.prepare_file(file_idx)
|
||||
file = self.files[file_idx]
|
||||
|
||||
assert isinstance(file, TSVFile), f'Expecting TSVFile but get {file} as {type(file)}'
|
||||
|
||||
# tsv line as tuple
|
||||
sample = file[row_idx]
|
||||
ds_name, *values = sample
|
||||
|
||||
# data dict
|
||||
sample = register_data_processor[self.ds_name](*values, intent=self.intent)
|
||||
|
||||
if row_idx + 1 == len(file):
|
||||
del file
|
||||
self.files[file_idx] = self.filenames[file_idx]
|
||||
|
||||
return sample
|
||||
|
||||
def prepare_file(self, idx):
|
||||
filename = self.filenames[idx]
|
||||
file = TSVFile(op.join(self.data_dir, filename))
|
||||
self.files[idx] = file
|
||||
|
||||
|
||||
class IterableSingleDataSourceDataset(torch_data.IterableDataset):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MultiDataSourceDataset(torch_data.Dataset):
|
||||
def __init__(self, data_sources: List[SingleDataSourceDataset], data_source_weights: List[int]):
|
||||
super().__init__()
|
||||
|
||||
self.ds_list = data_sources
|
||||
|
||||
self.sum_weight = sum(data_source_weights)
|
||||
self.ds_weights = data_source_weights
|
||||
for weight in self.ds_weights:
|
||||
assert isinstance(weight, int), 'weight must be integer'
|
||||
|
||||
self.offset2ds = {}
|
||||
self.offset2wt = {}
|
||||
self.offset2pd = {}
|
||||
self.prepare_offset2ds()
|
||||
|
||||
ds_loops = []
|
||||
for ds, wt in zip(self.ds_list, self.ds_weights):
|
||||
ds_loop = len(ds) // wt
|
||||
ds_loops.append(ds_loop)
|
||||
max_loop = max(ds_loops)
|
||||
self.size = max_loop * self.sum_weight
|
||||
|
||||
def prepare_offset2ds(self):
|
||||
offset = 0
|
||||
for ds, weight in zip(self.ds_list, self.ds_weights):
|
||||
pd = offset
|
||||
for _ in range(weight):
|
||||
self.offset2ds[offset] = ds
|
||||
self.offset2wt[offset] = weight
|
||||
self.offset2pd[offset] = pd
|
||||
offset += 1
|
||||
|
||||
def __getitem__(self, index):
|
||||
n_loop = index // self.sum_weight
|
||||
offset = index % self.sum_weight
|
||||
|
||||
ds = self.offset2ds[offset]
|
||||
ds_inner_idx = n_loop * self.offset2wt[offset] + offset - self.offset2pd[offset]
|
||||
ds_inner_idx = ds_inner_idx % len(ds)
|
||||
|
||||
return ds[ds_inner_idx]
|
||||
|
||||
def __len__(self):
|
||||
return self.size
|
||||
|
||||
|
||||
class IterableMultiDataSourceDataset(torch_data.IterableDataset):
|
||||
def __init__(self, data_sources, data_source_weights):
|
||||
super().__init__()
|
||||
|
||||
self.ds_list = data_sources
|
||||
|
||||
sum_weight = sum(data_source_weights)
|
||||
self.ds_weights = [x / sum_weight for x in data_source_weights]
|
||||
|
||||
self.ds_consumption = [0 for _ in self.ds_list]
|
||||
self.ds_sizes = [len(ds) for ds in self.ds_list]
|
||||
|
||||
def __next__(self):
|
||||
ds_idx = numpy.random.choice(range(len(self.ds_list)), 1, p=self.ds_weights)[0]
|
||||
data_source = self.ds_list[ds_idx]
|
||||
|
||||
self.ds_consumption[ds_idx] += 1
|
||||
if self.ds_consumption[ds_idx] % self.ds_sizes[ds_idx] == 0:
|
||||
self.report_consumption()
|
||||
|
||||
sample = next(data_source)
|
||||
return sample
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
return self
|
||||
|
||||
def __len__(self):
|
||||
return sum(self.ds_sizes)
|
||||
|
||||
def report_consumption(self):
|
||||
for ds, consumption, size in zip(self.ds_list, self.ds_consumption, self.ds_sizes):
|
||||
print(f'Data {ds} consumption: {consumption / size:.2f} epoch', flush=True)
|
||||
|
||||
|
|
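The weighted mixing in MultiDataSourceDataset walks the sources block by block: every run of sum_weight consecutive indices contains exactly `weight` samples from each source, wrapping around short sources. A toy sketch (plain lists stand in for SingleDataSourceDataset here, since only __len__ and __getitem__ are used by the mixer):

src_a = [f"a{i}" for i in range(6)]
src_b = [f"b{i}" for i in range(3)]
mixed = MultiDataSourceDataset([src_a, src_b], data_source_weights=[2, 1])

print(len(mixed), [mixed[i] for i in range(len(mixed))])
# 9 ['a0', 'a1', 'b0', 'a2', 'a3', 'b1', 'a4', 'a5', 'b2']  -- deterministic 2:1 interleave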
@ -0,0 +1,309 @@
|
|||
import io
|
||||
import json
|
||||
from typing import Dict, Tuple, List, Any
|
||||
|
||||
import torch
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from PIL import Image, PngImagePlugin
|
||||
from torch.utils.data import default_collate
|
||||
|
||||
from utils.logger import init_logger
|
||||
from vis_fm9g.dataset.utils import convert_data_to_id
|
||||
from vis_fm9g.dataset.utils import convert_conversation_data_to_id
|
||||
from vis_fm9g.dataset.utils import pad
|
||||
import random
|
||||
|
||||
from vis_fm9g.tokenizer.fm9g_tokenizer import FM9GTokenizer
|
||||
from vis_fm9g.utils.constants import usr_indicator, bot_indicator
|
||||
from vis_fm9g.dataset.prompts import caption_zh, caption_en
|
||||
|
||||
LARGE_ENOUGH_NUMBER = 100
|
||||
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
|
||||
|
||||
logger = init_logger()
|
||||
|
||||
def is_contain_chinese(check_str):
|
||||
"""
|
||||
Check whether the string contains any Chinese characters.
|
||||
:param check_str: {str} the string to check
|
||||
:return: {bool} True if it contains Chinese characters, False otherwise
|
||||
"""
|
||||
for ch in check_str:
|
||||
if u'\u4e00' <= ch <= u'\u9fff':
|
||||
return True
|
||||
|
||||
|
||||
def maybe_select_text(raw_text):
|
||||
candidates = raw_text.split('<cap_sep>')
|
||||
return random.choice(candidates)
|
||||
|
||||
|
||||
def maybe_parse_json(raw_text: str):
|
||||
# VG raw
|
||||
if raw_text.startswith('[{') and raw_text.endswith('}]'):
|
||||
try:
|
||||
data = json.loads(raw_text)
|
||||
text_list = [x['phrase'] for x in data if x['height'] > 160 and x['width'] > 160]
|
||||
if len(text_list) == 0:
|
||||
return max(data, key=lambda x: len(x['phrase'].split()))['phrase']
|
||||
else:
|
||||
return random.choice(text_list)
|
||||
except:
|
||||
return raw_text
|
||||
else:
|
||||
return raw_text
|
||||
|
||||
def clean_text(raw_text):
|
||||
text = raw_text.replace('<PERSON>', '')
|
||||
text = maybe_parse_json(maybe_select_text(text))
|
||||
return text
|
||||
|
||||
|
||||
def check_text_valid(raw_text):
|
||||
if pd.isna(raw_text):
|
||||
return False
|
||||
if not is_contain_chinese(raw_text) and len(raw_text.split()) <= 3:
|
||||
return False
|
||||
if '<img' in raw_text or '<a href' in raw_text:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_image_placeholder(tokenizer, query_len, use_im_start_end=False):
|
||||
if use_im_start_end:
|
||||
return tokenizer.im_start + tokenizer.unk_token * query_len + tokenizer.im_end
|
||||
else:
|
||||
return tokenizer.unk_token * query_len
|
||||
|
||||
|
||||
class ItemBuilder():
|
||||
def __init__(self, transform=None):
|
||||
self.transform = transform
|
||||
|
||||
def build_item(self, data):
|
||||
if self.transform is not None:
|
||||
return self.transform(data)
|
||||
return data
|
||||
|
||||
|
||||
# --------------------- FM9G ---------------------
|
||||
class FM9GBuilder(ItemBuilder):
|
||||
def __init__(self, tokenizer: FM9GTokenizer, max_len, transform=None, skip_overlength=False):
|
||||
super().__init__(transform)
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.max_len = max_len
|
||||
self.skip_overlength = skip_overlength
|
||||
|
||||
def convert_data(self, inp_dicts: List[Dict], raw_data):
|
||||
res = []
|
||||
for inp_dict in inp_dicts:
|
||||
input_ids, context = convert_data_to_id(self.tokenizer, data=inp_dict)
|
||||
if len(input_ids) > self.max_len:
|
||||
if self.skip_overlength:
|
||||
if random.random() > 0.95:
|
||||
logger.warn(f"overlength={len(input_ids)}, raw_inp={inp_dict}, skip data")
|
||||
else:
|
||||
logger.warn(f"overlength={len(input_ids)}, skip data")
|
||||
continue
|
||||
|
||||
input_ids = input_ids[: self.max_len]
|
||||
context = context[: self.max_len]
|
||||
|
||||
res.append({
|
||||
'input_ids': torch.from_numpy(input_ids).unsqueeze(0),
|
||||
'context': torch.from_numpy(context).unsqueeze(0),
|
||||
'raw_data': raw_data,
|
||||
})
|
||||
|
||||
return res
|
||||
|
||||
def convert_conversation_data(self, conversation_list: List[List]):
|
||||
res = []
|
||||
for conversation in conversation_list:
|
||||
input_ids, context, raw = convert_conversation_data_to_id(self.tokenizer, data=conversation, predict_roles={bot_indicator})
|
||||
if len(input_ids) > self.max_len:
|
||||
if self.skip_overlength:
|
||||
if random.random() > 0.95:
|
||||
logger.warn(f"overlength={len(input_ids)}, raw_inp={conversation}, skip data")
|
||||
else:
|
||||
logger.warn(f"overlength={len(input_ids)}, skip data")
|
||||
continue
|
||||
|
||||
input_ids = input_ids[: self.max_len]
|
||||
context = context[: self.max_len]
|
||||
|
||||
res.append({
|
||||
'input_ids': torch.from_numpy(input_ids).unsqueeze(0),
|
||||
'context': torch.from_numpy(context).unsqueeze(0),
|
||||
'raw_data': raw,
|
||||
})
|
||||
return res
|
||||
|
||||
def build_image_bound(self, res, images):
|
||||
return_res = []
|
||||
if isinstance(images, List) and len(images) > 0:
|
||||
images = torch.stack(images)
|
||||
for r in res:
|
||||
# r['input_ids'] (1, len)
|
||||
image_start_tokens = torch.where(r['input_ids'][0] == self.tokenizer.encoder[self.tokenizer.im_start])[0]
|
||||
# 跳过 im_start
|
||||
image_start_tokens += 1
|
||||
image_end_tokens = torch.where(r['input_ids'][0] == self.tokenizer.encoder[self.tokenizer.im_end])[0]
|
||||
|
||||
if len(image_start_tokens) != len(image_end_tokens) or len(image_start_tokens) > len(images):
|
||||
continue
|
||||
|
||||
image_bound = torch.hstack([image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)])
|
||||
|
||||
r['pixel_values'] = images[:len(image_start_tokens)]
|
||||
r['image_bound'] = image_bound
|
||||
return_res.append(r)
|
||||
|
||||
return return_res
|
||||
|
||||
|
||||
def build_item(self, data):
|
||||
NotImplementedError("build_item is not implemented.")
|
||||
|
||||
|
||||
class FM9GImageTextBuilder(FM9GBuilder):
|
||||
def __init__(self, tokenizer: FM9GTokenizer, max_len, transform=None, query_len=64, min_resolution=0, skip_overlength=False):
|
||||
super().__init__(tokenizer, max_len, transform, skip_overlength)
|
||||
self.query_len = query_len
|
||||
self.min_resolution = min_resolution
|
||||
|
||||
def build_item(self, data):
|
||||
text = data['conversations']
|
||||
image = data['image']
|
||||
source = data.get('metainfo', {}).get('origin_dataset', 'unk')
|
||||
|
||||
image = self.transform(image)
|
||||
|
||||
raw_data = {'text': text}
|
||||
image_placeholder = get_image_placeholder(self.tokenizer, self.query_len, use_im_start_end=True)
|
||||
messages = []
|
||||
for i in range(len(text)):
|
||||
role = text[i]['from']
|
||||
role = usr_indicator if role == 'human' else bot_indicator
|
||||
value = self.tokenizer.escape(text[i]['value'])
|
||||
if '<image>' in value:
|
||||
value = value.replace('<image>', image_placeholder)
|
||||
messages.append((role, value))
|
||||
res = self.convert_conversation_data([messages])
|
||||
self.build_image_bound(res, images=[image])
|
||||
for r in res:
|
||||
r['source'] = source
|
||||
|
||||
return res[0]
|
||||
|
||||
class FM9GCollater:
|
||||
def __init__(self, tokenizer: FM9GTokenizer, max_len: int, unpad: bool = False, unilm: bool = False):
|
||||
self.tokenizer = tokenizer
|
||||
self._max_length = max_len
|
||||
self._unpad = unpad
|
||||
self._unilm = unilm
|
||||
self.pad_keys = ['input_ids', 'context']
|
||||
|
||||
def __call__(self, batch):
|
||||
batch_cnt = len(batch)
|
||||
if self._unpad: # for flash_attention cuda
|
||||
max_length = self._max_length * batch_cnt
|
||||
batch_size = 1
|
||||
else:
|
||||
max_length = self._max_length
|
||||
batch_size = batch_cnt
|
||||
|
||||
inputs = np.zeros((batch_size, max_length), dtype=np.int32)
|
||||
context_origin = np.zeros((batch_size, max_length), dtype=np.int8)
|
||||
context = np.zeros((batch_size, max_length), dtype=np.int8)
|
||||
tgt = np.full((batch_size, max_length), -100, dtype=np.int32)
|
||||
|
||||
spans = np.zeros((batch_size, max_length), dtype=np.int32)
|
||||
length = np.zeros((batch_size,), dtype=np.int32)
|
||||
position_ids = np.zeros((batch_size, max_length), dtype=np.int32)
|
||||
|
||||
if self._unpad: # for flash_attention cuda, force batch_size=1
|
||||
flatten_input_ids = np.concatenate([batch[i]['input_ids'][0] for i in range(batch_cnt)], axis=0)
|
||||
flatten_context = np.concatenate([batch[i]['context'][0] for i in range(batch_cnt)], axis=0)
|
||||
instance_length = flatten_input_ids.shape[0]
|
||||
inputs[0, : instance_length] = flatten_input_ids
|
||||
context_origin[0, : instance_length] = flatten_context
|
||||
length[0] = instance_length
|
||||
if self._unilm:
|
||||
context[0, : instance_length] = flatten_context
|
||||
# flatten batch
|
||||
_spans = [list(np.cumsum([batch[i]['input_ids'][0].shape[0] for i in range(batch_cnt)]))]
|
||||
|
||||
else:
|
||||
for i in range(batch_cnt):
|
||||
instance_length = batch[i]['input_ids'][0].shape[0]
|
||||
inputs[i, :instance_length] = batch[i]['input_ids'][0]
|
||||
context_origin[i, : instance_length] = batch[i]['context'][0]
|
||||
length[i] = instance_length
|
||||
if self._unilm:
|
||||
context[i, :instance_length] = batch[i]['context'][0]
|
||||
_spans = [[batch[i]['input_ids'][0].shape[0]] for i in range(batch_cnt)]
|
||||
|
||||
|
||||
# cu_seqlens 和 max_seqlen 在 flash_attention cuda 时需要
|
||||
if _spans[0][-1] != max_length:
|
||||
cu_seqlens = np.array([0] + _spans[0] + [max_length], dtype=np.int32)
|
||||
else:
|
||||
cu_seqlens = np.array([0] + _spans[0], dtype=np.int32)
|
||||
max_seqlen = int(np.max(cu_seqlens[1:] - cu_seqlens[:-1]))
|
||||
|
||||
raw_data_list: List[Any] = [batch[i]['raw_data'] for i in range(batch_cnt)]
|
||||
source_list: List[Any] = [batch[i].get('source', 'unk') for i in range(batch_cnt)]
|
||||
|
||||
for i in range(batch_size):
|
||||
instance_length = length[i]
|
||||
span_begin = 0
|
||||
for span_id, span_end in enumerate(_spans[i]):
|
||||
spans[i, span_begin: span_end] = span_id
|
||||
position_ids[i, span_begin:span_end] = np.arange(span_end - span_begin)
|
||||
span_begin = span_end
|
||||
for j in range(instance_length):
|
||||
idx = inputs[i][j]
|
||||
if j > 1:
|
||||
if context_origin[i][j] == 0:
|
||||
if idx != self.tokenizer.bos_id and inputs[i][j - 1] != self.tokenizer.eos_id:
|
||||
tgt[i, j - 1] = idx
|
||||
if context_origin[i][j] == 1 and context_origin[i][j-1] == 0:
|
||||
if idx != self.tokenizer.bos_id and inputs[i][j - 1] != self.tokenizer.eos_id:
|
||||
tgt[i, j - 1] = self.tokenizer.eos_id
|
||||
|
||||
|
||||
data = {}
|
||||
# image
|
||||
if 'pixel_values' in batch[0]:
|
||||
if self._unpad:
|
||||
data['pixel_values'] = [torch.vstack([i['pixel_values'] for i in batch])]
|
||||
else:
|
||||
data['pixel_values'] = [i['pixel_values'] for i in batch]
|
||||
|
||||
|
||||
# image_bound
|
||||
if 'image_bound' in batch[0]:
|
||||
if self._unpad:
|
||||
image_bounds = []
|
||||
for i in range(batch_cnt):
|
||||
offset = _spans[0][i-1] if i > 0 else 0
|
||||
image_bounds.append(batch[i]['image_bound'] + offset)
|
||||
data['image_bound'] = [torch.vstack(image_bounds)]
|
||||
else:
|
||||
data['image_bound'] = [i['image_bound'] for i in batch]
|
||||
|
||||
data['input_ids'] = torch.from_numpy(inputs)
|
||||
data['context'] = torch.from_numpy(context) > 0
|
||||
data['length'] = torch.from_numpy(length)
|
||||
data['spans'] = torch.from_numpy(spans)
|
||||
data['cu_seqlens'] = torch.from_numpy(cu_seqlens)
|
||||
data['max_seqlen'] = max_seqlen
|
||||
data['position_ids'] = torch.from_numpy(position_ids)
|
||||
data['target'] = torch.from_numpy(tgt)
|
||||
data['raw_data'] = raw_data_list
|
||||
data['source'] = source_list
|
||||
|
||||
return data
|
|
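
A minimal sketch of how the builder and collater above might be chained into a torch DataLoader; the tokenizer instance, `image_transform`, and `samples` list are placeholder assumptions, not part of this commit:

from torch.utils.data import DataLoader

# Hypothetical setup: `tokenizer` (an FM9GTokenizer), `image_transform`, and
# `samples` (dicts with 'conversations' and 'image') are assumed to exist.
builder = FM9GImageTextBuilder(tokenizer, max_len=2048, transform=image_transform, query_len=64)
collater = FM9GCollater(tokenizer, max_len=2048, unpad=True, unilm=True)

items = [builder.build_item(s) for s in samples]
loader = DataLoader(items, batch_size=4, collate_fn=collater)
batch = next(iter(loader))  # input_ids, target, cu_seqlens, pixel_values, image_bound, ...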
@ -0,0 +1,24 @@
caption_en = [
    'Describe the image concisely',
    'Provide a brief description of the given image',
    'Offer a succinct explanation of the picture presented',
    'Summarize the visual content of the image',
    'Share a concise interpretation of the image provided',
    'Present a compact description of the photo’s key features',
    'Relay a brief and clear account of the picture shown',
    'Render a clear and concise summary of the photo',
    'Write a terse but informative summary of the picture',
    'Create a compact narrative representing the image presented',
]

# Chinese counterparts of the English captioning prompts above.
caption_zh = [
    '简明扼要地描述图像',
    '提供给定图像的简短描述',
    '对所示的图片进行简要的解释',
    '总结图像的视觉内容',
    '对所提供的图像进行简要的解释',
    '简明扼要并清楚地说明所示图片',
    '对这张照片作一个简明扼要的总结',
    '写一篇简洁但内容丰富的图片摘要',
    '创造一个紧凑的叙事来代表所呈现的图像',
]
@ -0,0 +1,170 @@
import importlib.machinery
import importlib.util
import types
from typing import Any, Set
from typing import Dict
from typing import List
from typing import Union

import torch
import numpy as np
from numpy.typing import NDArray
from torch.utils.data import BatchSampler
from typing_extensions import TypedDict

from vis_fm9g.tokenizer.fm9g_tokenizer import FM9GTokenizer
from vis_fm9g.utils.constants import SYSTEM

FM9GInputType = Union[str, Dict[str, "FM9GInputType"]]


class _TransformFuncDict(TypedDict):
    loader: importlib.machinery.SourceFileLoader
    module: types.ModuleType
    last_m: float


class FM9GBatch(TypedDict):
    inputs: NDArray[np.int32]
    length: NDArray[np.int32]
    context: NDArray[np.bool_]
    sample_ids: NDArray[np.int32]
    spans: NDArray[np.int32]
    target: NDArray[np.int32]
    task_ids: NDArray[np.int32]
    task_names: List[str]
    raw_data: List[Any]


def convert_data_to_id(tokenizer: FM9GTokenizer, data: Any):
    """
    data: {
        'input': xxx,
        'output': xxx
    }
    """
    input_ids = tokenizer.encode(data["input"])
    output_ids = tokenizer.encode(data["output"])
    ids = [tokenizer.bos_id] + input_ids + output_ids + [tokenizer.eos_id]
    ids = np.array(ids, dtype=np.int32)
    context = np.zeros((ids.shape[0],), dtype=np.int8)

    # positions with context == 0 become prediction targets
    context[: len(input_ids) + 1] = 1
    return ids, context
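
An illustrative call (a sketch only, assuming `tokenizer` is a constructed FM9GTokenizer):

ids, context = convert_data_to_id(tokenizer, {"input": "1+1=", "output": "2"})
# ids  = [bos] + input tokens + output tokens + [eos]
# context == 1 over [bos] + input tokens (conditioning only),
# context == 0 over output tokens + [eos] (the positions turned into targets by the collater).
assert ids[0] == tokenizer.bos_id and ids[-1] == tokenizer.eos_id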
def convert_conversation_data_to_id(tokenizer: FM9GTokenizer, data: Any, predict_roles: Set):
    """
    predict_roles: {'<AI>'}
    data: [
        ('<用户>', xxxx),
        ('<AI>', xxxx)
    ]
    """
    assert (set([i[0] for i in data]) & predict_roles)

    if SYSTEM:
        system = tokenizer.bos_token + SYSTEM + '\n'
    else:
        system = tokenizer.bos_token
    sys_idx = tokenizer.encode(system)
    ret = system

    input_ids = [sys_idx] if sys_idx else []
    context = [np.ones((len(sys_idx),), dtype=np.int8)]

    for idx, (role, message) in enumerate(data):
        prefix = role
        # append eos to the last message
        if idx == len(data) - 1:
            message = message + tokenizer.eos_token

        prefix_ids = tokenizer.encode(prefix)
        message_ids = tokenizer.encode(message)

        input_ids.append(prefix_ids)
        input_ids.append(message_ids)
        context.append(np.ones((len(prefix_ids),), dtype=np.int8))

        # only messages from predict_roles contribute to the loss
        if role in predict_roles:
            context.append(np.zeros((len(message_ids),), dtype=np.int8))
        else:
            context.append(np.ones((len(message_ids),), dtype=np.int8))

        ret += (prefix + message)

    ids = np.hstack(input_ids)
    context = np.hstack(context)

    return ids, context, ret
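
An illustrative conversation call (a sketch, assuming `tokenizer` exists and the bot indicator is '<AI>'):

conversation = [("<用户>", "你好"), ("<AI>", "你好！")]
ids, context, raw = convert_conversation_data_to_id(tokenizer, conversation, predict_roles={"<AI>"})
# Only the tokens of the '<AI>' message (plus the trailing eos) get context == 0;
# the system prefix, role indicators, and the user message all stay at context == 1.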
def pad(orig_items, key, max_length=None, padding_value=0, padding_side="left"):
    """Pad the `key` tensors of a list of items to a common length along the last dimension."""
    items = []
    if isinstance(orig_items[0][key], list):
        assert isinstance(orig_items[0][key][0], torch.Tensor)
        for it in orig_items:
            for tr in it[key]:
                items.append({key: tr})
    else:
        assert isinstance(orig_items[0][key], torch.Tensor)
        items = orig_items

    batch_size = len(items)
    shape = items[0][key].shape
    dim = len(shape)
    assert dim <= 3
    if max_length is None:
        max_length = 0
    max_length = max(max_length, max(item[key].shape[-1] for item in items))
    min_length = min(item[key].shape[-1] for item in items)
    dtype = items[0][key].dtype

    if dim == 1:
        return torch.cat([item[key] for item in items], dim=0)
    elif dim == 2:
        if max_length == min_length:
            return torch.cat([item[key] for item in items], dim=0)
        tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
    else:
        tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value

    for i, item in enumerate(items):
        if dim == 2:
            if padding_side == "left":
                tensor[i, -len(item[key][0]):] = item[key][0].clone()
            else:
                tensor[i, : len(item[key][0])] = item[key][0].clone()
        elif dim == 3:
            if padding_side == "left":
                tensor[i, -len(item[key][0]):, :] = item[key][0].clone()
            else:
                tensor[i, : len(item[key][0]), :] = item[key][0].clone()

    return tensor
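
A small worked example of the left-padding path (illustrative only, not part of the original file):

items = [
    {"input_ids": torch.tensor([[1, 2, 3]])},
    {"input_ids": torch.tensor([[4, 5]])},
]
padded = pad(items, "input_ids", padding_value=0, padding_side="left")
# padded == tensor([[1, 2, 3],
#                   [0, 4, 5]])  with shape (2, 3)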
class SkipBatchSampler(BatchSampler):
    """
    A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
    """

    def __init__(self, batch_sampler, skip_batches=0):
        self.batch_sampler = batch_sampler
        self.skip_batches = skip_batches
        self.first_epoch = True

    def __iter__(self):
        for index, samples in enumerate(self.batch_sampler):
            # Skip only during the first epoch; later epochs yield every batch
            # (the original condition yielded nothing after the first epoch).
            if index >= self.skip_batches or not self.first_epoch:
                yield samples
        self.first_epoch = False

    @property
    def total_length(self):
        return len(self.batch_sampler)

    def __len__(self):
        return len(self.batch_sampler) - self.skip_batches
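
A usage sketch for resuming training part-way through an epoch (values are illustrative):

from torch.utils.data import BatchSampler, SequentialSampler

base = BatchSampler(SequentialSampler(range(1000)), batch_size=8, drop_last=False)
sampler = SkipBatchSampler(base, skip_batches=100)
# The first pass yields batches 100..124 (resuming after 100 consumed batches);
# subsequent passes yield all 125 batches.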
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,112 @@
import torch
import torch.nn.functional as F


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
    # This function has been mostly taken from huggingface conversational ai code at
    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313

    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    batch_size = logits.size()[0]
    if top_p > 0.0:
        logits = logits.view(batch_size, -1).contiguous()
        for index in range(len(logits)):
            sorted_logits, sorted_indices = torch.sort(logits[index].view(-1), descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the indices to the right to keep also the first token above the threshold
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits[index][indices_to_remove] = filter_value

        logits = logits.view(batch_size, -1).contiguous()

    return logits
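
A sketch of one sampling step using the filter above (the logits tensor is random, purely for illustration):

logits = torch.randn(2, 32000)                          # (batch, vocab_size)
filtered = top_k_top_p_filtering(logits.clone(), top_k=50, top_p=0.9)
probs = F.softmax(filtered, dim=-1)                     # filtered positions get probability 0
next_tokens = torch.multinomial(probs, num_samples=1)   # (batch, 1)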
def apply_repetition_penalty(
    logits,
    batch_size,
    num_beams,
    prev_output_tokens,
    repetition_penalty,
    start_idx=None,
    end_idx=None,
    window_size=None,
):
    # only conduct repetition penalty for the output
    assert repetition_penalty >= 1, "repetition penalty coefficient should >= 1"
    # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
    for i in range(batch_size * num_beams):
        if start_idx is None or end_idx is None:
            output_tokens = prev_output_tokens[i].tolist()
        else:
            if end_idx >= start_idx:
                if window_size:
                    output_tokens = prev_output_tokens[i][
                        max(start_idx, end_idx + 1 - window_size): end_idx + 1
                    ].tolist()
                else:
                    output_tokens = prev_output_tokens[i][start_idx: end_idx + 1].tolist()
            else:
                output_tokens = []
        for previous_token in set(output_tokens):
            # if the score is < 0, the repetition penalty has to be
            # multiplied to reduce the previous token probability
            if logits[i, previous_token] < 0:
                logits[i, previous_token] *= repetition_penalty
            else:
                logits[i, previous_token] /= repetition_penalty
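
An illustrative in-place call during decoding (a sketch; the tensors below are made up for the example):

logits = torch.randn(2, 32000)
generated_ids = torch.tensor([[101, 7, 7, 9, 12], [101, 3, 5, 5, 5]])  # ids emitted so far
apply_repetition_penalty(
    logits, batch_size=2, num_beams=1,
    prev_output_tokens=generated_ids,
    repetition_penalty=1.2,
    start_idx=1, end_idx=4,   # penalise only the generated span, not the prompt
)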
class BeamHypotheses:
    def __init__(self, n_hyp, max_len, length_penalty, early_stopping):
        """
        Initialize n-best list of hypotheses.
        """
        self.max_len = max_len
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.n_hyp = n_hyp
        self.hyp = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.hyp)

    def add(self, hyp, sum_logprobs):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / len(hyp) ** self.length_penalty

        if len(self) < self.n_hyp or score > self.worst_score:
            self.hyp.append((score, hyp))
            if len(self) > self.n_hyp:
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
                del self.hyp[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs, cur_len):
        """
        If there are enough hypotheses and none of the hypotheses being generated
        can become better than the worst one in the heap, then we are done with this sentence.
        """
        if len(self) < self.n_hyp:
            return False
        elif self.early_stopping:
            return True
        else:
            return self.worst_score >= best_sum_logprobs / cur_len**self.length_penalty
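
A brief sketch of how the container is typically used inside a beam-search loop (illustrative values only):

hyps = BeamHypotheses(n_hyp=3, max_len=64, length_penalty=1.0, early_stopping=False)
hyps.add(torch.tensor([5, 17, 2]), sum_logprobs=-1.2)     # a finished beam and its log-prob sum
hyps.add(torch.tensor([5, 9, 11, 2]), sum_logprobs=-1.0)
done = hyps.is_done(best_sum_logprobs=-2.5, cur_len=4)    # False until n_hyp hypotheses are stored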
Some files were not shown because too many files have changed in this diff.