Compare commits


19 Commits

Author SHA1 Message Date
xiaonans b0d030d0de [fix] fix rope op test failing 2024-04-23 13:51:10 +08:00
xiaonans d000f9750c add shape information to the kvcache attention operator 2024-04-11 14:52:39 +08:00
xiaonans 4a5b9572bb add test scripts for llama2 and 9G models 2024-04-10 16:23:02 +08:00
xiaonans 159642d6ae merge master 2024-04-10 10:03:11 +08:00
xiaonans c01e64db50 rope and attention ops support multiple batchs/sequences. 2024-04-09 09:16:42 +08:00
xiaonans eb3a2d123d accelerate cuda attention 2024-03-28 09:07:30 +08:00
xiaonans 4bdd33522b accelerate cuda fp32 matmul 2024-03-26 11:37:54 +08:00
xiaonans 0740d26f43 clean up 2024-03-21 10:17:06 +08:00
xiaonans fc3d38f80e attention support fp16 2024-03-20 14:56:15 +08:00
xiaonans d43364ac60 inter-block communication is fp16 2024-03-19 11:21:14 +08:00
xiaonans db053e32a4 kv register is fp16 2024-03-18 17:25:57 +08:00
xiaonans 1e797d4ffe cache is fp16 2024-03-18 15:51:19 +08:00
xiaonans 80412ae162 fix bugs when blocksize==64 2024-03-18 15:31:52 +08:00
xiaonans 83be7fa373 fix bugs in rmsnorm op 2024-02-20 10:59:53 +08:00
xiaonans 0f1c04d864 add fp16 support to silu cuda op 2024-02-19 11:39:21 +08:00
xiaonans 936797b960 support rmsnorm 2024-02-08 14:58:47 +08:00
xiaonans 17bd98d453 modify rope op 2024-02-06 17:04:05 +08:00
xiaonans 8cc6af0a83 modify code to pass the cuda_all_reduce test 2024-02-06 10:53:32 +08:00
xiaonans c04910f118 [feature] add cudagraph support 2024-02-05 16:19:58 +08:00
37 changed files with 1594 additions and 1164 deletions

.gitmodules

@@ -13,6 +13,3 @@
 [submodule "example"]
 path = examples/NNmodel
 url = git@github.com:wanghailu0717/NNmodel.git
-[submodule "examples/distributed/onnxsim_large_model"]
-path = examples/distributed/onnxsim_large_model
-url = git@github.com:luchangli03/onnxsim_large_model.git

@@ -1 +1 @@
-Subproject commit 51d3105277f3774ed31c02ed4cd11fa92925af77
+Subproject commit b896cec2dba5b8522b141ac4f89eb43074ee1b98


@@ -1,7 +1,5 @@
 # Distributed scripts
-## Running on the NVIDIA platform
 #### 1. Run the PyTorch model and generate inputs and reference outputs (optionally export ONNX)
 Use `--export_onnx` to set the directory the ONNX file is exported to (defaults to the current path `./`); without this flag, only the computation and the input/output generation are performed.
@@ -17,23 +15,3 @@ python run_pytorch.py --model gpt2 --batch_size 1 --length 1 --export_onnx ./
 ```bash
 python cuda_launch.py --model "/XXX/XXX.onnx" --nproc_per_node 4
 ```
-## Running on the Cambricon platform
-**The scripts `run_pytorch.py` and `cuda_launch.py` above have been adapted for the Cambricon platform; see `run_pytorch_mlu.py` and `bang_launch.py` for details.**
-#### 1. Run the PyTorch model and generate inputs and reference outputs (optionally export ONNX)
-Use `--export_onnx` to set the directory the ONNX file is exported to (defaults to the current path `./`); without this flag, only the computation and the input/output generation are performed.
-```bash
-python run_pytorch_mlu.py --model gpt2 --batch_size 1 --length 1 --export_onnx ./
-```
-This generates the input and output files `test_inputs.npy` and `test_results.npy` in the current directory; only a single input and a single output are supported for now.
-#### 2. Run the InfiniTensor distributed script
-```bash
-python bang_launch.py --model "/XXX/XXX.onnx" --nproc_per_node 4
-```
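
As a quick reference, a minimal end-to-end sketch of the NVIDIA workflow this README describes; the model name, batch size, sequence length, process count, and the exported ONNX filename are placeholders (the actual filename depends on the export settings in the scripts):

```bash
# Sketch only: adjust the placeholders to your setup.
# 1) Run the PyTorch reference and export ONNX (this also writes the test inputs/outputs).
python run_pytorch.py --model gpt2 --batch_size 1 --length 1 --export_onnx ./
# 2) Run the InfiniTensor distributed script on the exported model, 4 processes per node.
python cuda_launch.py --model "./gpt2_1_1.onnx" --nproc_per_node 4
```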


@@ -1,249 +0,0 @@
import argparse
import torch
import torch_mlu
from transformers import BertModel, BertConfig
from transformers import GPT2Model, GPT2Config
from transformers import OPTModel, OPTConfig
from transformers import AlbertModel, AlbertConfig
from transformers import LlamaModel, LlamaConfig
import time
import numpy as np
import onnx
import sys
import os
from onnx.external_data_helper import convert_model_to_external_data
from onnxsim import simplify
def parse_args():
parser = argparse.ArgumentParser(description="Run pytorch gpt2/bert/opt and optionally export onnx.")
parser.add_argument(
"--model", type=str, choices=["gpt2", "bert", "opt", "llama", "albert"], required=True, help="model type"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
parser.add_argument("--length", type=int, default=1, help="sequence length.")
parser.add_argument(
"--export_onnx",
type=str,
nargs="?",
default=None,
const="./",
help="whether and where to export onnx file",
)
parser.add_argument(
"--type", type=str, choices=["fp32", "fp16", "tf32"], required=True, help="model data type"
)
args = parser.parse_args()
print("arg setting: ", args)
return (
args.model,
args.batch_size,
args.length,
args.export_onnx,
args.type
)
def get_model(modelname):
match modelname:
case "albert":
model = AlbertModel.from_pretrained("albert/albert-base-v2")
voc_size = AlbertConfig().vocab_size
case "bert":
model = BertModel.from_pretrained("bert-base-uncased", add_pooling_layer=False, hidden_act="gelu_new") # erf is not impl by infini
voc_size = BertConfig().vocab_size
case "gpt2":
model = GPT2Model.from_pretrained("GPT2")
voc_size = GPT2Config().vocab_size
case "opt":
model = OPTModel.from_pretrained("facebook/opt-125m")
voc_size = OPTConfig().vocab_size
case "llama":
model = LlamaModel.from_pretrained("meta-llama/Llama-2-7b-hf")
voc_size = LlamaConfig().vocab_size
case _:
raise KeyError(modelname)
model = model.eval()
return model, voc_size
def run_pytorch(torch_model, voc_size, batchsize, len, dtype="fp32"):
data = np.random.randint(0, voc_size, (batchsize, len), dtype=np.int32)
os.makedirs(os.path.dirname("./data/"), exist_ok=True)
np.save("./data/input_0", data)
inputs = torch.from_numpy(data).to("mlu")
torch_model = torch_model.to("mlu")
if dtype == "fp16":
torch_model = torch_model.half()
n_iter = 20
with torch.no_grad():
for _ in range(10):
outputs = torch_model(inputs)
torch.mlu.synchronize()
begin = time.time()
with torch.no_grad():
for _ in range(n_iter):
torch.mlu.synchronize()
outputs = torch_model(inputs)
torch.mlu.synchronize()
torch.mlu.synchronize()
end = time.time()
avg_time = (end - begin) / n_iter
outputs = outputs.last_hidden_state.to("cpu")
print("outputs abs mean:", abs(np.array(outputs)).mean())
print(f"average time: {avg_time}")
# torch.mlu.memory.empty_cache()
np.save("./data/output", np.array(outputs))
print("Save input & output into ./data.")
def export_onnx(modelname, model, data, path, extern=False, dtype="fp32"):
data = data.to("mlu")
model = model.to("mlu")
if dtype == "fp16":
model = model.half()
torch.onnx.export(model, data, path, verbose=False, do_constant_folding=True)
if modelname != "llama":
# use onnxsim to simplify
onnx_model = onnx.load(path)
onnx_model, check = simplify(onnx_model, skipped_optimizers=['eliminate_duplicate_initializer'])
# onnx_model, check = simplify(onnx_model, skipped_optimizers=['fuse_qkv', 'eliminate_duplicate_initializer'])
assert check
add_value_info_for_constants(onnx_model)
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
if extern:
extern_path = path.replace('.onnx', '.pb')
if os.path.exists(extern_path):
os.remove(extern_path)
extern_path = extern_path.split("/")[-1]
convert_model_to_external_data(
onnx_model,
all_tensors_to_one_file=True,
location=extern_path,
size_threshold=1024,
convert_attribute=False,
)
onnx.save(onnx_model, path)
else:
# use third party tool to simplify llama
# reference: https://github.com/luchangli03/onnxsim_large_model/
sys.path.append("onnxsim_large_model")
from onnx_utils import set_onnx_input_shape
from compress_model import SIZE_1MB, compress_onnx_model, uncompress_onnx_model
in_model_path = path
out_model_path = path
if not out_model_path:
out_model_path = in_model_path[:-5] + ".sim.onnx"
if os.path.isdir(out_model_path):
out_model_path = os.path.join(out_model_path, os.path.basename(in_model_path))
onnx_model = onnx.load(in_model_path)
print(f"load model from {in_model_path} success")
size_th_bytes = 1024 * 1024
onnx_model, removed_inits = compress_onnx_model(onnx_model, size_th_bytes=size_th_bytes)
print(f"compress model success")
onnx_model = set_onnx_input_shape(onnx_model, "")
tensor_size_threshold = f"1024KB"
skipped_optimizers = []
skipped_optimizers.append("eliminate_duplicate_initializer")
onnx_model, check = simplify(onnx_model, skipped_optimizers=skipped_optimizers,
tensor_size_threshold=tensor_size_threshold)
if not check:
raise ValueError(f"simplify compressed model {in_model_path} failed")
print(f"simplify model success")
onnx_model = uncompress_onnx_model(onnx_model, removed_inits)
print(f"uncompress model success")
add_value_info_for_constants(onnx_model)
onnx.save(onnx_model, out_model_path, save_as_external_data=True)
def add_value_info_for_constants(model : onnx.ModelProto):
"""
Currently onnx.shape_inference doesn't use the shape of initializers, so add
that info explicitly as ValueInfoProtos.
Mutates the model.
Args:
model: The ModelProto to update.
"""
# All (top-level) constants will have ValueInfos before IRv4 as they are all inputs
if model.ir_version < 4:
return
def add_const_value_infos_to_graph(graph : onnx.GraphProto):
inputs = {i.name for i in graph.input}
existing_info = {vi.name: vi for vi in graph.value_info}
for init in graph.initializer:
# Check it really is a constant, not an input
if init.name in inputs:
continue
# The details we want to add
elem_type = init.data_type
shape = init.dims
# Get existing or create new value info for this constant
vi = existing_info.get(init.name)
if vi is None:
vi = graph.value_info.add()
vi.name = init.name
# Even though it would be weird, we will not overwrite info even if it doesn't match
tt = vi.type.tensor_type
if tt.elem_type == onnx.TensorProto.UNDEFINED:
tt.elem_type = elem_type
if not tt.HasField("shape"):
# Ensure we set an empty list if the const is scalar (zero dims)
tt.shape.dim.extend([])
for dim in shape:
tt.shape.dim.add().dim_value = dim
# Handle subgraphs
for node in graph.node:
for attr in node.attribute:
# Ref attrs refer to other attrs, so we don't need to do anything
if attr.ref_attr_name != "":
continue
if attr.type == onnx.AttributeProto.GRAPH:
add_const_value_infos_to_graph(attr.g)
if attr.type == onnx.AttributeProto.GRAPHS:
for g in attr.graphs:
add_const_value_infos_to_graph(g)
return add_const_value_infos_to_graph(model.graph)
def main():
torch.backends.mlu.matmul.allow_tf32 = False
torch.backends.cnnl.allow_tf32 = False
modelname, batchsize, seqlen, export_path, dtype = parse_args()
if dtype == "tf32":
torch.backends.mlu.matmul.allow_tf32 = True
else:
os.environ["CAMBRICON_TF32_OVERRIDE"] = "0"
model, voc_size = get_model(modelname)
if export_path is not None:
filename = "{}_{}_{}_{}.onnx".format(modelname, batchsize, seqlen, dtype)
path = os.path.join(export_path, filename)
if not os.path.exists(path):
param = torch.zeros((batchsize, seqlen), dtype=torch.int)
export_onnx(modelname, model, param, path, True, dtype)
else:
print("Onnx path exists, skipping export.")
run_pytorch(model, voc_size, batchsize, seqlen, dtype)
if __name__ == "__main__":
main()


@@ -1,39 +1,35 @@
-import sys
-sys.path.append('../')
 import argparse
 import os
 import time
 import multiprocessing as mp
 from pyinfinitensor.onnx import OnnxStub, backend
 import onnx
-from onnx.external_data_helper import convert_model_to_external_data
 from onnx.shape_inference import infer_shapes_path
 import numpy as np
 from parallel_opt import parallel_model
 def parse_args():
 parser = argparse.ArgumentParser(description="launch distributed infinitensor")
 parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
 parser.add_argument(
-"--nproc_per_node", type=int, default=1, help="number of processes per node"
+"--nproc_per_node", type=int, default=2, help="number of processes per node"
 )
 parser.add_argument(
 "--name", type=str, default="test", help="name of this instance."
 )
 parser.add_argument(
-"--model", type=str, required=True, help="path to the ONNX model file."
+"--model", type=str, default="/data/onnx_models/llama2/llama_bs1_seq1024.onnx",
+help="path to the ONNX model file."
 )
 parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
 parser.add_argument("--length", type=int, default=1, help="sequence length.")
 parser.add_argument(
 "--gen_std",
+default=False,
 action="store_true",
 help="whether to generate the standard results.",
 )
-parser.add_argument(
-"--type", type=str, choices=["fp32", "fp16", "tf32"], default="fp32", help="data type"
-)
 args = parser.parse_args()
 print("arg setting: ", args)
 return (
@@ -44,46 +40,39 @@ def parse_args():
 args.batch_size,
 args.length,
 args.gen_std,
-args.type,
 )
-def run_model(model, runtime, world_size=1, rank=0, n=10, data_type="default"):
-stub = OnnxStub(model, runtime, matmul_compute_type=data_type)
+def run_model(model, runtime, world_size=1, rank=0, n=10):
+stub = OnnxStub(model, runtime)
 load_inputs(stub, world_size, rank)
 # stub.tune()
 stub.run()
 # get outputs
+time.sleep(0.01)
 outputs = next(stub.outputs.values().__iter__()).copyout_numpy()
 # bench
+begin = time.time()
 for _ in range(n):
 stub.run()
-begin = time.time()
-for _ in range(n * 2):
-stub.run()
 end = time.time()
-avg_time = (end - begin) / (n * 2)
+avg_time = (end - begin) / n
 print(f"average time: {avg_time}")
 return outputs
-def load_inputs(stub, world_size=1, rank=0):
-for i, (name, tensor) in enumerate(stub.inputs.items()):
-input = np.load(f"./data/input_{i}.npy")
-if all(x == y for x,y in zip(input.shape,tensor.shape())):
-tensor.copyin_numpy(input)
-else:
-tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
-def run_and_compare(name, model, runtime, world_size=1, rank=0, data_type="default"):
+def run_and_compare(name, model, runtime, world_size=1, rank = 0):
 results = np.load(f"./data/output.npy")
-outputs = run_model(model, runtime, world_size, rank, data_type=data_type)
+outputs = run_model(model, runtime, world_size, rank)
-print("outputs abs mean:", abs(outputs).mean())
-print("max abs diff:", abs(outputs - results).max())
+print("answer argmax:", np.argmax(results))
+print("output argmax:", np.argmax(outputs))
+#np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-3)
+getDiff(results, outputs)
 def start_worker(
-name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto, data_type: str
+name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
 ):
 dist_name = name + "_dist"
 model = parallel_model(model, world_size, rank)
@@ -96,7 +85,7 @@ def start_worker(
 save_as_external_data=True,
 location=extern_path,
 )
-#infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
+infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
 runtime = backend.BangRuntime(local_rank)
 # print("init comm")
 runtime.init_comm(
@@ -104,12 +93,13 @@ def start_worker(
 world_size,
 rank,
 )
-run_and_compare(name, model, runtime, world_size, rank, data_type)
+run_and_compare(name, model, runtime, world_size, rank)
-def start_single(name, model, data_type):
+def start_single(name, model):
 runtime = backend.BangRuntime(0)
-run_and_compare(name, model, runtime, data_type=data_type)
+run_and_compare(name, model, runtime)
 def generate_input_output(model):
 os.makedirs(os.path.dirname("./data/"), exist_ok=True)
@@ -142,36 +132,55 @@ def generate_input_output(model):
 np.save(f"./data/output", output)
+def load_inputs(stub, world_size=1, rank=0):
+for i, (name, tensor) in enumerate(stub.inputs.items()):
+input = np.load(f"./data/input_{i}.npy")
+if all(x == y for x,y in zip(input.shape,tensor.shape())):
+tensor.copyin_numpy(input)
+else:
+tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
+def getDiff(base, test):
+absolute_diff = np.abs(np.subtract(base, test))
+max_absolute_diff = np.max(absolute_diff)
+baseCopy = base.astype(np.float64).ravel()
+testCopy = test.astype(np.float64).ravel()
+upValue = np.sum(np.abs(baseCopy - testCopy))
+downValue = np.sum(np.abs(baseCopy)) + np.float64(1e-9)
+max_relative_diff = upValue / downValue
+print(f"Max absolute difference: {max_absolute_diff}\n"
+f"Max relative difference: {max_relative_diff}")
+return max_absolute_diff, max_relative_diff
 def main():
-nnodes, nproc_per_node, name, model_path, bs, length, gen_std, data_type = parse_args()
-data_type = "default" if data_type == "fp32" else data_type
+nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args()
 model = onnx.load(model_path)
 # generate standart output
 if gen_std:
-print(f"generate standard data for {name}.")
-# a small vocabulary size to fit all LLM.
-generate_input_output(model)
+print("Generate inputs and outputs.")
+p = mp.Process(target=generate_input_output, args=[model])
+p.start()
+p.join()
 return
-if nproc_per_node == 1:
-# run single process.
-# use standalone process to isolate bang.
-print("run model by single MLU.")
-# p = mp.Process(target=start_single, args=(name, model, data_type))
-# p.start()
-# p.join()
-start_single(name, model, data_type)
-return
+# run single process.
+# use standalone process to isolate cuda.
+print("run model by single MLU.")
+p = mp.Process(target=start_single, args=(name, model))
+p.start()
+p.join()
 # run distributed parallel.
 world_size = nnodes * nproc_per_node
-print(f"run model by {world_size} MLU in parallel.")
+print(f"run model by {world_size} MLUs in parallel.")
 workers = [
 mp.Process(
 target=start_worker,
-args=(name, world_size, rank, rank % nproc_per_node, model, data_type),
+args=(name, world_size, rank, rank % nproc_per_node, model),
 )
 for rank in range(world_size)
 ]


@@ -1,14 +0,0 @@
export HF_ENDPOINT=https://hf-mirror.com
models=("bert" "gpt2" "llama")
batch_size=(1 32)
seq_len=(100 500)
nproc=(1 2 4)
for model in "${models[@]}"; do
for bs in "${batch_size[@]}"; do
for len in "${seq_len[@]}"; do
python run_pytorch.py --model "$model" --batch_size "$bs" --length "$len" --export_onnx ../models/"$model" --export_only
done
done
done


@@ -1,280 +0,0 @@
import sys
sys.path.append('../')
import argparse
import os
import time
import multiprocessing as mp
from pyinfinitensor.onnx import OnnxStub, backend
import onnx
from onnx.external_data_helper import convert_model_to_external_data
from onnx.shape_inference import infer_shapes_path
import numpy as np
from parallel_opt import parallel_model
from functools import wraps
def parse_args():
parser = argparse.ArgumentParser(description="launch distributed infinitensor")
parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
parser.add_argument(
"--nproc_per_node", type=int, default=2, help="number of processes per node"
)
parser.add_argument(
"--name", type=str, choices=["gpt2", "bert", "llama"], help="name of model."
)
parser.add_argument(
"--model", type=str, default="", help="path to the ONNX model file."
)
parser.add_argument(
"--gen_std",
default=False,
action="store_true",
help="whether to generate the standard results.",
)
parser.add_argument(
"--run_single",
default=False,
action="store_true",
help="whether run model with single process with standard inputs"
)
parser.add_argument(
"--input_dir",
default="./",
help="path to save model input data"
)
parser.add_argument(
"--result_dir",
default="./",
help="path to save model standard output"
)
parser.add_argument(
"--internal_model_dir",
default="./",
help="path to save internal onnx model for parallel run"
)
args = parser.parse_args()
# check path, mkdir if not exist
check_exists(args.input_dir)
check_exists(args.result_dir)
check_exists(args.internal_model_dir)
print("arg setting: ", args)
return (
args.num_nodes,
args.nproc_per_node,
args.name,
args.model,
args.gen_std,
args.run_single,
args.input_dir,
args.result_dir,
args.internal_model_dir
)
"""
utils function for this scripts
"""
def check_exists(path: str):
if not os.path.exists(path):
os.makedirs(path)
def np_assert(base, test, rtol=1e-2, atol=1e-1):
# np.testing.assert_allclose(test, base, rtol, atol)
print("max abs diff:", abs(base - test).max())
"""
Perf wrapper, run function n times
then average
"""
def perf_it(n):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
# warmup
for _ in range(n):
func(*args, **kwargs)
t_total = 0
for _ in range(n):
t0 = time.time()
func(*args, **kwargs)
t1 = time.time()
t_total += t1 - t0
avg_time = (t_total) / n
print(f"Avg runtime of {n} time is {avg_time:.6f} seconds")
return avg_time
return wrapper
return decorator
"""
Run InfiniTensor model with Standard input
check=True: check with standard output gen by pytorch
perf=True: run n times to get avg time
"""
def run_model(task_name,
model,
runtime,
world_size=1,
rank=0,
n=10,
check=True,
perf=True):
stub = OnnxStub(model, runtime,
use_naive_allocator=True \
if task_name == "llama" else False)
# load in Onnx model inputs
def load_inputs(stub: OnnxStub):
# check exists
inputs = []
for i, (name, tensor) in enumerate(stub.inputs.items()):
input_path = os.path.join(input_dir, \
f"{task_name}_input_{i}.npy")
print(input_path)
if os.path.exists(input_path):
input = np.load(input_path)
else :
raise KeyError(f"{i} th input of model not exists")
# check shape
if all(x == y for x,y in zip(input.shape, tensor.shape())):
tensor.copyin_numpy(input)
else:
tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
load_inputs(stub)
# stub.tune()
stub.run()
time.sleep(0.01)
output = next(stub.outputs.values().__iter__()).copyout_numpy()
# check output results with standard output
if check:
st_output_path = os.path.join(result_dir, \
f"{task_name}_output.npy")
assert os.path.exists(st_output_path) , \
"standard output not exists"
st_output = np.load(st_output_path)
if np.isnan(output).any():
print("Nan in output")
exit()
np_assert(st_output, output)
# perf
if perf:
@perf_it(n)
def perf_infinitensor(stub: OnnxStub):
stub.run()
perf_infinitensor(stub)
return output
"""
Start a worker in Parallel
"""
def start_worker(name: str,
world_size: int,
rank: int,
local_rank: int,
model: onnx.ModelProto):
dist_name = name + "_dist"
# partial a onnx model to world_size part
model = parallel_model(model, world_size, rank)
onnx.save(model, os.path.join(internal_model_dir, \
f"{dist_name}_rank{rank}.onnx"), save_as_external_data=True)
runtime = backend.KUNLUNRuntime(local_rank)
# print("init comm")
runtime.init_comm(
dist_name,
world_size,
rank,
)
run_model(name, model, runtime, world_size, rank)
"""
generate standard input/output with
sigle card run
"""
def gen_standard(task_name: str, model: onnx.ModelProto):
runtime = backend.KUNLUNRuntime(0)
stub = OnnxStub(model, runtime)
position_id = 0
# generate random input for model
for i, (name, tensor) in enumerate(stub.inputs.items()):
input = tensor.copyout_numpy()
if np.issubdtype(input.dtype, np.integer):
if input.size == 1:
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
else:
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
elif input.dtype == np.bool_:
input = np.random.randint(0,2,size=input.shape) > 0
else:
if i == 0:
input = np.ones(input.shape).astype(input.dtype)
position_id = input.shape[-1] - 1
else:
input = np.random.rand(*input.shape).astype(input.dtype)
tensor.copyin_numpy(input)
np.save(os.path.join(input_dir, \
f"{task_name}_input_{i}.npy"), input)
stub.run()
# print(stub.outputs)
output = next(stub.outputs.values().__iter__()).copyout_numpy()
if np.isnan(output).any():
print("Nan in output")
exit()
np.save(os.path.join(result_dir, f"{task_name}_output.npy"), output)
def main():
global input_dir, result_dir, internal_model_dir
nnodes, nproc_per_node, task_name, \
model_path, gen_std, run_single, \
input_dir, result_dir, internal_model_dir = parse_args()
# load input onnx model
model = onnx.load(model_path)
# generate standart output
if gen_std:
print("Generate inputs and outputs.")
gen_standard(task_name, model)
return
if run_single:
print("Run model by one GPU card.")
runtime = backend.KUNLUNRuntime(0)
run_model(task_name, model, runtime)
return
# run distributed parallel.
world_size = nnodes * nproc_per_node
print(f"Run model by {world_size} GPU in parallel.")
workers = [
mp.Process(
target=start_worker,
args=(task_name, world_size, rank, rank % nproc_per_node, model),
)
for rank in range(world_size)
]
for w in workers:
w.start()
for w in workers:
w.join()
if __name__ == "__main__":
main()


@@ -1,36 +0,0 @@
export HF_ENDPOINT=https://hf-mirror.com
# models=("bert" "gpt2" "llama")
models=("bert" "gpt2")
batch_size=(1 32)
seq_len=(100 500)
nproc=(1 2 4)
results_dir="results"
if [ -d "$results_dir" ]; then
echo "directory ./$results_dir exists"
else
mkdir -p "$results_dir"
echo "mkdir $results_dir, logs saved there"
fi
for model in "${models[@]}"; do
for bs in "${batch_size[@]}"; do
for len in "${seq_len[@]}"; do
# run pytorch model
echo "Run pytorch $model with batch_size=$bs length=$len ."
python run_pytorch.py --model "$model" --batch_size "$bs" --length "$len" #> results/"$model"_"$bs"_"$len"_pytorch
for n in "${nproc[@]}"; do
# run infinitensor
echo "Run $n parallel infinitensor "$model" with batch_size=$bs and length=$len ."
python kunlun_launch.py --name "$model" --model ../models/"$model"/"$model"_"$bs"_"$len".onnx --nproc_per_node=$n # >> results/"$model"_"$bs"_"$len"_infini
# delete internal files
find ./ -type f -name "*.onnx" -delete
find ./ -type f -name "*.pb" -delete
done
find ./ -type f -name "*.npy" -delete
done
done
done


@@ -1,35 +0,0 @@
export HF_ENDPOINT=https://hf-mirror.com
# models=("bert" "gpt2" "llama")
models=("llama")
batch_size=(1 )
seq_len=(100 500)
nproc=(1 2 4)
results_dir="results"
if [ -d "$results_dir" ]; then
echo "directory ./$results_dir exists"
else
mkdir -p "$results_dir"
echo "mkdir $results_dir, logs saved there"
fi
for model in "${models[@]}"; do
for bs in "${batch_size[@]}"; do
for len in "${seq_len[@]}"; do
echo "Run pytorch llama with batch_size="$bs" and length="$len""
python run_pytorch.py --model "$model" --batch_size "$bs" --length "$len"
for n in "${nproc[@]}"; do
# run pytorch model
echo "Run infinitensor llama with batch_size="$bs" and length="$len" and nproc="$n"."
python kunlun_launch.py --name llama --model ../models/llama/llama_"$bs"_"$len"_fp32.onnx --nproc_per_node=$n
# delete internal files
find ./ -type f -name "*.onnx" -delete
find ./ -type f -name "*0c" -delete
done
find ./ -type f -name "*.npy" -delete
done
done
done


@@ -1,245 +0,0 @@
import argparse
import torch
from transformers import BertModel, BertConfig
from transformers import GPT2Model, GPT2Config
from transformers import OPTModel, OPTConfig
from transformers import LlamaModel, LlamaConfig
import time
import numpy as np
import onnx
import os
import sys
from onnx.external_data_helper import convert_model_to_external_data
from onnxsim import simplify
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
def parse_args():
parser = argparse.ArgumentParser(description="Run pytorch gpt2/bert/opt and optionally export onnx.")
parser.add_argument(
"--model", type=str, choices=["gpt2", "bert", "opt", "llama"], required=True, help="model type"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
parser.add_argument("--length", type=int, default=1, help="sequence length.")
parser.add_argument(
"--export_onnx",
type=str,
nargs="?",
default=None,
const="./",
help="whether and where to export onnx file",
)
parser.add_argument(
"--input_dir",
type=str,
default="./",
help="path to save pytorch model input data"
)
parser.add_argument(
"--result_dir",
type=str,
default="./",
help="path to save pytorch model output data"
)
parser.add_argument(
"--export_only",
action="store_true"
)
args = parser.parse_args()
print("arg setting: ", args)
return (
args.model,
args.batch_size,
args.length,
args.export_onnx,
args.input_dir,
args.result_dir,
args.export_only
)
def get_model(modelname):
if modelname == "bert":
model = BertModel.from_pretrained("bert-base-uncased", add_pooling_layer=False, hidden_act="gelu_new") # erf is not impl by infini
voc_size = BertConfig().vocab_size
elif modelname == "gpt2":
model = GPT2Model.from_pretrained("gpt2")
voc_size = GPT2Config().vocab_size
elif modelname == "opt":
model = OPTModel.from_pretrained("./opt-125m")
voc_size = OPTConfig().vocab_size
elif modelname == "llama":
model = LlamaModel.from_pretrained("meta-llama/Llama-2-7b-hf")
voc_size = LlamaConfig().vocab_size
else :
raise KeyError(modelname)
model = model.eval()
return model, voc_size
def run_pytorch(torch_model, voc_size, batchsize, len, model_name):
data = np.random.randint(0, voc_size, (batchsize, len), dtype=np.int32)
np.save(os.path.join(input_dir, f"{model_name}_input_0.npy"), data)
inputs = torch.from_numpy(data).to("cuda")
torch_model = torch_model.to("cuda")
n_iter = 10
with torch.no_grad():
for _ in range(10):
outputs = torch_model(inputs)
torch.cuda.synchronize()
begin = time.time()
with torch.no_grad():
for _ in range(n_iter):
torch.cuda.synchronize()
outputs = torch_model(inputs)
#
torch.cuda.synchronize()
torch.cuda.synchronize()
end = time.time()
avg_time = (end - begin) / n_iter
outputs = outputs.last_hidden_state.to("cpu")
print("outputs abs mean:", abs(np.array(outputs)).mean())
print(f"average time: {avg_time}")
torch.cuda.memory.empty_cache()
np.save(os.path.join(result_dir, f"{model_name}_output.npy"), \
np.array(outputs))
print(f"Save input & output as {model_name}_input_0.npy and {model_name}_output.npy")
def export_onnx(model_name, model, data, path, extern=False):
# torch.onnx.export(model, data, path, verbose=False, do_constant_folding=True)
if model_name != "llama":
onnx_model = onnx.load(path)
onnx_model, check = simplify(onnx_model,
skipped_optimizers=['fuse_qkv', 'eliminate_duplicate_initializer'])
# skipped_optimizers=['fuse_qkv'])
assert check
add_value_info_for_constants(onnx_model)
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
if extern:
extern_path = path.replace('.onnx', '.pb')
if os.path.exists(extern_path):
os.remove(extern_path)
convert_model_to_external_data(
onnx_model,
all_tensors_to_one_file=True,
location=extern_path.split("/")[-1],
size_threshold=1024,
convert_attribute=False,
)
onnx.save(onnx_model, path)
else:
sys.path.append("onnxsim_large_model")
from onnx_utils import set_onnx_input_shape
from compress_model import SIZE_1MB, compress_onnx_model, uncompress_onnx_model
in_model_path = path
out_model_path = in_model_path[:-5] + ".sim.onnx"
onnx_model = onnx.load(in_model_path)
print(f"load model from {in_model_path} success")
size_th_bytes = 1024 * 1024
onnx_model, removed_inits = compress_onnx_model(onnx_model, size_th_bytes=size_th_bytes)
print("compress model success")
onnx_model = set_onnx_input_shape(onnx_model, "")
tensor_size_threshold = f"1024KB"
skipped_optimizers = []
skipped_optimizers.append("eliminate_duplicate_initializer")
onnx_model, check = simplify(onnx_model, skipped_optimizers=skipped_optimizers,
tensor_size_threshold=tensor_size_threshold)
if not check:
raise ValueError(f"simplify compressed model {in_model_path} failed")
print(f"simplify model success")
onnx_model = uncompress_onnx_model(onnx_model, removed_inits)
print(f"uncompress model success")
add_value_info_for_constants(onnx_model)
onnx.save(onnx_model, out_model_path, save_as_external_data=True)
def add_value_info_for_constants(model : onnx.ModelProto):
"""
Currently onnx.shape_inference doesn't use the shape of initializers, so add
that info explicitly as ValueInfoProtos.
Mutates the model.
Args:
model: The ModelProto to update.
"""
# All (top-level) constants will have ValueInfos before IRv4 as they are all inputs
if model.ir_version < 4:
return
def add_const_value_infos_to_graph(graph : onnx.GraphProto):
inputs = {i.name for i in graph.input}
existing_info = {vi.name: vi for vi in graph.value_info}
for init in graph.initializer:
# Check it really is a constant, not an input
if init.name in inputs:
continue
# The details we want to add
elem_type = init.data_type
shape = init.dims
# Get existing or create new value info for this constant
vi = existing_info.get(init.name)
if vi is None:
vi = graph.value_info.add()
vi.name = init.name
# Even though it would be weird, we will not overwrite info even if it doesn't match
tt = vi.type.tensor_type
if tt.elem_type == onnx.TensorProto.UNDEFINED:
tt.elem_type = elem_type
if not tt.HasField("shape"):
# Ensure we set an empty list if the const is scalar (zero dims)
tt.shape.dim.extend([])
for dim in shape:
tt.shape.dim.add().dim_value = dim
# Handle subgraphs
for node in graph.node:
for attr in node.attribute:
# Ref attrs refer to other attrs, so we don't need to do anything
if attr.ref_attr_name != "":
continue
if attr.type == onnx.AttributeProto.GRAPH:
add_const_value_infos_to_graph(attr.g)
if attr.type == onnx.AttributeProto.GRAPHS:
for g in attr.graphs:
add_const_value_infos_to_graph(g)
return add_const_value_infos_to_graph(model.graph)
def main():
global input_dir, result_dir
modelname, batchsize, seqlen, \
export_path, input_dir, result_dir, export_only = parse_args()
model, voc_size = get_model(modelname) # pytorch model
if export_path is not None:
os.makedirs(export_path, exist_ok=True)
filename = "{}_{}_{}.onnx".format(modelname, batchsize, seqlen)
path = os.path.join(export_path, filename)
param = torch.zeros((batchsize, seqlen), dtype=torch.int)
export_onnx(modelname, model, param, path, True) # export pytorch model to onnx model
if export_only:
return
run_pytorch(model, voc_size, batchsize, seqlen, modelname)
if __name__ == "__main__":
main()


@@ -0,0 +1,213 @@
import argparse
import os
import time
import multiprocessing as mp
from pyinfinitensor.onnx import OnnxStub, backend
import onnx
from onnx.external_data_helper import convert_model_to_external_data
from onnx.shape_inference import infer_shapes_path
import numpy as np
from parallel_opt import parallel_model
st_input_dir = "standard/inputs/"
st_output_dir = "standard/outputs/"
def parse_args():
parser = argparse.ArgumentParser(description="launch distributed infinitensor")
parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
parser.add_argument(
"--nproc_per_node", type=int, default=2, help="number of processes per node"
)
parser.add_argument(
"--name", type=str, default="test", help="name of this instance."
)
parser.add_argument(
"--model", type=str, default="/data1/shared/panzezhong/llama/fp32/my_llama_fp32.sim.onnx", help="path to the ONNX model file."
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
parser.add_argument("--length", type=int, default=1, help="sequence length.")
parser.add_argument(
"--gen_std",
default=False,
action="store_true",
help="whether to generate the standard results.",
)
parser.add_argument(
"--run_single",
default=False,
action="store_true",
help="whether run model with single process with standard inputs"
)
args = parser.parse_args()
print("arg setting: ", args)
return (
args.num_nodes,
args.nproc_per_node,
args.name,
args.model,
args.batch_size,
args.length,
args.gen_std,
args.run_single
)
def run_model(model, runtime, world_size=1, rank=0, n=10):
stub = OnnxStub(model, runtime)
load_inputs(stub, world_size, rank)
# stub.tune()
stub.run()
# get outputs
time.sleep(0.01)
outputs = next(stub.outputs.values().__iter__()).copyout_numpy()
# bench
begin = time.time()
for _ in range(n):
stub.run()
end = time.time()
avg_time = (end - begin) / n
print(f"average time: {avg_time}")
return outputs
def run_and_compare(name, model, runtime, world_size=1, rank = 0):
results = np.load(os.path.join(st_output_dir,f"output.npy"))
outputs = run_model(model, runtime, world_size, rank)
print(outputs[:100])
if np.isnan(outputs).any():
print("Nan in output")
print("answer argmax:", np.argmax(results))
print("output argmax:", np.argmax(outputs))
#np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-3)
getDiff(results, outputs)
def start_worker(
name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
):
dist_name = name + "_dist"
model = parallel_model(model, world_size, rank)
extern_path = f"./{dist_name}_rank{rank}.pb"
if os.path.exists(extern_path):
os.remove(extern_path)
onnx.save_model(
model,
f"./{dist_name}_rank{rank}.onnx",
save_as_external_data=True,
location=extern_path,
)
infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
runtime = backend.KUNLUNRuntime(local_rank)
# print("init comm")
runtime.init_comm(
dist_name,
world_size,
rank,
)
run_and_compare(name, model, runtime, world_size, rank)
def start_single(name, model):
runtime = backend.KUNLUNRuntime(0)
run_and_compare(name, model, runtime)
def generate_input_output(model):
runtime = backend.KUNLUNRuntime(0)
stub = OnnxStub(model, runtime)
position_id = 0
for i, (name, tensor) in enumerate(stub.inputs.items()):
input = tensor.copyout_numpy()
if np.issubdtype(input.dtype, np.integer):
if input.size == 1:
# input = np.array([position_id])
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
else:
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
elif input.dtype == np.bool_:
input = np.random.randint(0,2,size=input.shape) > 0
else:
if i == 0:
input = np.ones(input.shape).astype(input.dtype)
position_id = input.shape[-1] - 1
else:
input = np.random.rand(*input.shape).astype(input.dtype)
tensor.copyin_numpy(input)
np.save(os.path.join(st_input_dir, f"input_{i}"), input)
stub.run()
# print(stub.outputs)
time.sleep(0.01)
output = next(stub.outputs.values().__iter__()).copyout_numpy()
print(output[:100])
if np.isnan(output).any():
print("Nan in output")
np.save(os.path.join(st_output_dir, f"output"), output)
def load_inputs(stub, world_size=1, rank=0):
for i, (name, tensor) in enumerate(stub.inputs.items()):
input = np.load(os.path.join(st_input_dir, f"input_{i}.npy"))
if all(x == y for x,y in zip(input.shape,tensor.shape())):
tensor.copyin_numpy(input)
else:
tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
def getDiff(base, test):
absolute_diff = np.abs(np.subtract(base, test))
max_absolute_diff = np.max(absolute_diff)
baseCopy = base.astype(np.float64).ravel()
testCopy = test.astype(np.float64).ravel()
upValue = np.sum(np.abs(baseCopy - testCopy))
downValue = np.sum(np.abs(baseCopy)) + np.float64(1e-9)
max_relative_diff = upValue / downValue
print(f"Max absolute difference: {max_absolute_diff}\nMax relative difference: {max_relative_diff}")
return max_absolute_diff, max_relative_diff
def main():
nnodes, nproc_per_node, name, model_path, bs, length, gen_std, run_single = parse_args()
model = onnx.load(model_path)
# generate standart output
if gen_std:
print("Generate inputs and outputs.")
p = mp.Process(target=generate_input_output, args=[model])
p.start()
p.join()
return
# # run single process.
# # use standalone process to isolate cuda.
if run_single:
print("run model by single GPU.")
p = mp.Process(target=start_single, args=(name, model))
p.start()
p.join()
return
# run distributed parallel.
world_size = nnodes * nproc_per_node
print(f"run model by {world_size} GPU in parallel.")
workers = [
mp.Process(
target=start_worker,
args=(name, world_size, rank, rank % nproc_per_node, model),
)
for rank in range(world_size)
]
for w in workers:
w.start()
for w in workers:
w.join()
if __name__ == "__main__":
main()

@@ -1 +0,0 @@
-Subproject commit cbcf3fbf985a00494b0f136c92eaccd42031bf65


@@ -110,6 +110,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
 s_dim = 0
 elif in_plc.dim == 2:
 s_dim = 1
 assert s_dim != -1
 assert out_dims[s_dim] % tp_world_size == 0, out_dims
 out_dims[s_dim] //= tp_world_size

examples/python/test_9G.py

@@ -0,0 +1,512 @@
import os
from pyinfinitensor.onnx import OnnxStub, backend
import numpy as np
import onnx
import torch
from tqdm import tqdm
import onnx_graphsurgeon as gs
import time
import nvtx
import argparse
from mpi4py import MPI
from pytrie import StringTrie
import io
import json
import re
from typing import (
Dict,
List,
IO,
)
parser = argparse.ArgumentParser(description='')
parser.add_argument('--batchsize', dest='batchsize', type=int, default=1)
parser.add_argument('--layer', dest='n_layers', type=int, default=48)
parser.add_argument("--num_nodes", dest='num_nodes',
type=int, default=1, help="number of nodes")
parser.add_argument("--world_size", dest="world_size",
type=int, default=1, help="")
parser.add_argument("--nproc_per_node", dest="nproc_per_node",
type=int, default=1, help="number of processes per node")
parser.add_argument("--n_max_length", dest="n_max_length",
type=int, default=1024, help="number of processes per node")
parser.add_argument("--vocab_size", dest="vocab_size",
type=int, default=119696, help="vocabulary size")
parser.add_argument("--hidden_size", dest="hidden_size",
type=int, default=4096, help="vocabulary size")
parser.add_argument('--rank', dest='rank', type=int, default=0)
parser.add_argument('--speedup', action='store_true')
parser.add_argument('--no_cudagraph', action='store_true')
parser.add_argument('--fp16', action='store_true')
args = parser.parse_args()
comm = MPI.COMM_WORLD
args.rank = comm.Get_rank()
args.nproc_per_node = comm.Get_size()
args.world_size = args.num_nodes * args.nproc_per_node
ONNX_MODEL_PATH = "/data3/shared/xnsong/9G/dist/9g_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.onnx".format(
args.batchsize, args.n_layers, 16 if args.fp16 else 32, args.world_size, args.rank)
weight_path = "9g_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.pb".format(
args.batchsize, args.n_layers, 16 if args.fp16 else 32, args.world_size, args.rank)
model_dir = "/data1/shared/9G-Infer/models/11B-Chat-QY-epoch-8/cpm9g-11b-sft.pt"
@gs.Graph.register()
def RMSNorm(self, a, b):
return self.layer(op="RMSNorm", inputs=a, outputs=b)
@gs.Graph.register()
def RoPE(self, a, b):
return self.layer(op="RoPE", inputs=a, outputs=b)
@gs.Graph.register()
def AttentionKVCache(self, a, b):
return self.layer(op="AttentionKVCache", inputs=a, outputs=b)
def to_numpy(dict):
ret = dict
if args.fp16:
ret = np.float16(ret)
else:
ret = np.float32(ret)
return ret
def parallel(array, split='replicate'):
if args.world_size > 1 and split == 'partial_column':
return np.hsplit(array, args.world_size)[args.rank]
elif args.world_size > 1 and split == 'partial_row':
return np.vsplit(array, args.world_size)[args.rank]
return array
def generate_onnx(ONNX_MODEL_PATH):
state_dict = torch.load(f'{model_dir}', map_location='cpu')
new_state_dict = {name: param.cpu().numpy()
for name, param in state_dict.items()
}
operators = []
graph = gs.Graph(nodes=operators)
gather_input = gs.Variable(name="gather_input.0", dtype=np.int64, shape=(1,1))
pos_input = gs.Variable(name="pos_input.0", dtype=np.int64, shape=(1,1))
embedding_weight = gs.Constant(name="embedding.weight", values=to_numpy(new_state_dict["input_embedding.weight"]))
gather_output = gs.Variable(name="gather_output.0")
gather = gs.Node(op="Gather", inputs=[embedding_weight, gather_input], outputs=[gather_output])
operators.append(gather)
input = gather_output
graph.inputs=[gather_input, pos_input]
graph.outputs=[]
for i in tqdm(range(args.n_layers)):
# global input
attn_kcache_input = gs.Variable(name="/layers." + str(i) + "/attn/kcache_input", dtype=np.float32, shape=(1,32,1023,128))
attn_vcache_input = gs.Variable(name="/layers." + str(i) + "/attn/vcache_input", dtype=np.float32, shape=(1,32,1023,128))
graph.inputs.append(attn_kcache_input)
graph.inputs.append(attn_vcache_input)
# weight
layernorm_0_mul_weight = gs.Constant(name="/layers." + str(i) + "/layernorm.0/mul_weight",
values=to_numpy(new_state_dict["encoder.layers." + str(i) + ".self_att.layernorm_before_attention.weight"]))
attn_qproj_weight = gs.Constant(name="/layers." + str(i) + "/attn/qproj_weight",
values=parallel(
np.transpose(
to_numpy(
new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.project_q.weight"]))
, 'partial_column'))
attn_kproj_weight = gs.Constant(name="/layers." + str(i) + "/attn/kproj_weight",
values=parallel(
np.transpose(
to_numpy(
new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.project_k.weight"]))
, 'partial_column'))
attn_vproj_weight = gs.Constant(name="/layers." + str(i) + "/attn/vproj_weight",
values=parallel(
np.transpose(
to_numpy(
new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.project_v.weight"]))
, 'partial_column'))
attn_outmatmul_input = gs.Constant(name="/layers." + str(i) + "/attn/outmatmul_weight",
values=parallel(
np.transpose(
to_numpy(
new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.attention_out.weight"]))
, 'partial_row'))
layernorm_1_mul_weight = gs.Constant(name="/layers." + str(i) + "/layernorm.1/mul_weight",
values=to_numpy(new_state_dict["encoder.layers." + str(i) + ".ffn.layernorm_before_ffn.weight"]))
ffn_matmul_0_input = gs.Constant(name="/layers." + str(i) + "/ffn/matmul_0_weight",
values=parallel(
np.transpose(
to_numpy(
new_state_dict["encoder.layers." + str(i) + ".ffn.ffn.w_in.w_0.weight"]))
, 'partial_column'))
ffn_matmul_1_input = gs.Constant(name="/layers." + str(i) + "/ffn/matmul_1_weight",
values=parallel(
np.transpose(
to_numpy(
new_state_dict["encoder.layers." + str(i) + ".ffn.ffn.w_in.w_1.weight"]))
, 'partial_column'))
ffn_matmul_out_input = gs.Constant(name="/layers." + str(i) + "/ffn/matmul_out_weight",
values=parallel(
np.transpose(
to_numpy(
new_state_dict["encoder.layers." + str(i) + ".ffn.ffn.w_out.weight"]))
, 'partial_row'))
attn_qrope_output = gs.Variable(name="/layers." + str(i) + "/attn/qrope_output")
attn_krope_output = gs.Variable(name="/layers." + str(i) + "/attn/krope_output")
attn_kvcache_output = gs.Variable(name="/layers." + str(i) + "/attn/kvcache_output")
layernorm_0_mul_output_1 = gs.Variable(name="/layers." + str(i) + "/layernorm.0/mul_output_1")
layernorm_1_mul_output_1 = gs.Variable(name="/layers." + str(i) + "/layernorm.1/mul_output_1")
attn_qproj_output = gs.Variable(name="/layers." + str(i) + "/attn/qproj_output")
attn_kproj_output = gs.Variable(name="/layers." + str(i) + "/attn/kproj_output")
attn_vproj_output = gs.Variable(name="/layers." + str(i) + "/attn/vproj_output")
attn_outmatmul_output = gs.Variable(name="/layers." + str(i) + "/attn/outmatmul_output")
attn_outadd_output = gs.Variable(name="/layers." + str(i) + "/attn/outadd_output")
ffn_matmul_0_output = gs.Variable(name="/layers." + str(i) + "/ffn/matmul_0_output")
ffn_silu_output = gs.Variable(name="/layers." + str(i) + "/ffn/silu_output")
ffn_matmul_1_output = gs.Variable(name="/layers." + str(i) + "/ffn/matmul_1_output")
ffn_mul_output = gs.Variable(name="/layers." + str(i) + "/ffn/mul_output")
ffn_matmul_out_output = gs.Variable(name="/layers." + str(i) + "/ffn/matmul_out_output")
ffn_add_output = gs.Variable(name="/layers." + str(i) + "/ffn/add_output")
graph.RMSNorm([input, layernorm_0_mul_weight], [layernorm_0_mul_output_1])
attn_qproj = gs.Node(op="MatMul", inputs=[layernorm_0_mul_output_1, attn_qproj_weight], outputs=[attn_qproj_output])
operators.append(attn_qproj)
attn_kproj = gs.Node(op="MatMul", inputs=[layernorm_0_mul_output_1, attn_kproj_weight], outputs=[attn_kproj_output])
operators.append(attn_kproj)
attn_vproj = gs.Node(op="MatMul", inputs=[layernorm_0_mul_output_1, attn_vproj_weight], outputs=[attn_vproj_output])
operators.append(attn_vproj)
graph.RoPE([pos_input, attn_qproj_output], [attn_qrope_output])
graph.RoPE([pos_input, attn_kproj_output], [attn_krope_output])
graph.AttentionKVCache([attn_kcache_input, attn_vcache_input, attn_qrope_output, attn_krope_output, attn_vproj_output, pos_input],[attn_kvcache_output])
attn_outproj = gs.Node(op="MatMul", inputs=[attn_kvcache_output, attn_outmatmul_input], outputs=[attn_outmatmul_output])
operators.append(attn_outproj)
attn_reduce_sum_output = gs.Variable(name="/layers." + str(i) + "/attn/reducesum_output")
if args.world_size > 1:
reduce_sum = gs.Node(op="ReduceSum", name="/layers." + str(i) + "/attn/reducesum",
inputs=[attn_outmatmul_output], outputs=[attn_reduce_sum_output],
attrs={"noop_with_empty_axes":1, "communicator":0})
graph.nodes.append(reduce_sum)
attn_outadd = gs.Node(op="Add", inputs=[input, attn_outmatmul_output if args.world_size == 1 else attn_reduce_sum_output], outputs=[attn_outadd_output])
operators.append(attn_outadd)
graph.RMSNorm([attn_outadd_output, layernorm_1_mul_weight], [layernorm_1_mul_output_1])
ffn_matmul_0 = gs.Node(op="MatMul", inputs=[layernorm_1_mul_output_1, ffn_matmul_0_input], outputs=[ffn_matmul_0_output])
operators.append(ffn_matmul_0)
ffn_silu = gs.Node(op="Silu", inputs=[ffn_matmul_0_output], outputs=[ffn_silu_output])
operators.append(ffn_silu)
ffn_matmul_1 = gs.Node(op="MatMul", inputs=[layernorm_1_mul_output_1, ffn_matmul_1_input], outputs=[ffn_matmul_1_output])
operators.append(ffn_matmul_1)
ffn_mul = gs.Node(op="Mul", inputs=[ffn_silu_output, ffn_matmul_1_output], outputs=[ffn_mul_output])
operators.append(ffn_mul)
ffn_matmul_out = gs.Node(op="MatMul", inputs=[ffn_mul_output, ffn_matmul_out_input], outputs=[ffn_matmul_out_output])
operators.append(ffn_matmul_out)
ffn_reduce_sum_output = gs.Variable(name="/layers." + str(i) + "/ffn/reducesum_output")
if args.world_size > 1:
reduce_sum = gs.Node(op="ReduceSum", name="/layers." + str(i) + "/ffn/reducesum",
inputs=[ffn_matmul_out_output], outputs=[ffn_reduce_sum_output],
attrs={"noop_with_empty_axes":1, "communicator":0})
graph.nodes.append(reduce_sum)
ffn_add = gs.Node(op="Add", inputs=[attn_outadd_output, ffn_matmul_out_output if args.world_size == 1 else ffn_reduce_sum_output], outputs=[ffn_add_output])
operators.append(ffn_add)
input = ffn_add_output
layernorm_mul_weight = gs.Constant(name="/output/layernorm/mul_weight", values=to_numpy(new_state_dict["encoder.output_layernorm.weight"]))
layernorm_mul_output_1 = gs.Variable(name="/output/layernorm/mul_output_1")
graph.RMSNorm([input, layernorm_mul_weight], [layernorm_mul_output_1])
lm_head_weight = gs.Constant(name="/output/lm_head/weight", values=np.transpose(to_numpy(new_state_dict["lm_head.weight"])))
lm_head_output = gs.Variable(name="/output/lm_head/output")
lm_head = gs.Node(op="MatMul", inputs=[layernorm_mul_output_1, lm_head_weight], outputs=[lm_head_output])
operators.append(lm_head)
if args.fp16:
final_cast_output = gs.Variable(name="/output/cast/output", dtype=np.float32, shape=(1,1,args.vocab_size))
final_cast = gs.Node(op="Cast", inputs=[lm_head_output], outputs=[final_cast_output])
final_cast.attrs["to"] = np.float32
operators.append(final_cast)
graph.outputs.append(final_cast_output)
else:
lm_head_output.dtype=np.float32
lm_head_output.shape=(1,1,args.vocab_size)
graph.outputs.append(lm_head_output)
onnx.save(gs.export_onnx(graph), ONNX_MODEL_PATH, save_as_external_data=True, location=weight_path)
return
def load_vocab(fp: IO[bytes]) -> Dict[str, int]:
"""Loads a vocabulary file into a dictionary."""
vocab: Dict[str, int] = {}
reader = io.TextIOWrapper(fp, encoding="utf-8")
for token in reader.readlines():
token = token.strip()
if len(token) == 0:
continue
token = json.loads(token)
vocab[token] = len(vocab)
return vocab
class CPM9GTokenizer(object):
def __init__(self, path):
self.unk_token = "<unk>"
self.bos_token = "<s>"
self.eos_token = "</s>"
self.byte_list = ["<0x0{}>".format(hex(i).upper()[2:]) for i in range(0x10)] + [
"<0x{}>".format(hex(i).upper()[2:]) for i in range(0x10, 0x100)
]
self._special_token_set = set([self.unk_token, self.bos_token, self.eos_token] + self.byte_list)
all_tokens = load_vocab(io.FileIO(path, "rb"))
self.encoder: Dict[str, int] = {}
self._special_encoder: Dict[str, int] = {}
for token, token_id in all_tokens.items():
if token in self._special_token_set:
self._special_encoder[token] = token_id
else:
self.encoder[token] = token_id
self.decoder = {v: k for k, v in self.encoder.items()}
self._byte_decoder = {self._special_encoder[token]: i for i, token in enumerate(self.byte_list)}
self._max_word_len = max([len(x) for x in self.encoder.keys()])
self._len_word_first = {}
for x in self.encoder.keys():
if not x[0] in self._len_word_first:
self._len_word_first[x[0]] = 1
if len(x) > self._len_word_first[x[0]]:
self._len_word_first[x[0]] = len(x)
self.tencoder = StringTrie(self.encoder)
def get_piece(self, text: str) -> str:
if text[0] in self._len_word_first:
text = text[: self._len_word_first[text[0]]]
len_text = len(text)
for i in range(len(text)):
sub = text[: len_text - i]
if sub in self.encoder:
return sub
return text[0]
@property
def vocab_size(self):
return len(self)
@property
def eos_id(self):
return self._special_encoder[self.eos_token]
@property
def bos_id(self):
return self._special_encoder[self.bos_token]
@property
def unk_id(self):
return self._special_encoder[self.unk_token]
def __len__(self):
return len(self.encoder) + len(self._special_encoder)
def tokenize(self, text: str) -> List[str]:
output_tokens: List[str] = []
st = 0
while st < len(text):
piece = self.get_piece(text[st:])
output_tokens.append(piece)
st += len(piece)
return output_tokens
@staticmethod
def escape(text: str) -> str:
return text
@staticmethod
def unescape(text: str) -> str:
return text
def encode(self, text: str, with_bos = True) -> List[int]:
ret = []
if with_bos:
ret.append(self.bos_id)
for x in self.tokenize(text):
if x in self.encoder:
ret.append(self.encoder[x])
else:
ret.extend(self._encode_unicode(x))
return ret
def decode(self, tokens: List[int]):
"""Decode ids into a string."""
ret = []
st = 0
while st < len(tokens):
if tokens[st] in self.decoder:
ret.append(self.decoder[tokens[st]])
st += 1
elif tokens[st] in self._byte_decoder:
first = self._byte_decoder[tokens[st]]
length = 1 if first < 128 else len(re.search('^1+0', bin(first)[2:])[0])-1
code = 0
try:
for j in range(length):
code = code << 8 | self._byte_decoder[tokens[st + j]]
code = int.to_bytes(code, length, "big").decode("utf-8")
ret.append(code)
except:
pass
st = st + length
elif tokens[st] == self.eos_id:
ret.append(self.eos_token)
st += 1
elif tokens[st] == self.bos_id:
ret.append(self.bos_token)
st += 1
else:
ret.append(self.unk_token)
st += 1
return "".join(ret)
def _encode_unicode(self, token):
# wrap unicode encoding into a helper function
ids = []
utf8_id = token.encode("utf-8")
for _id in utf8_id:
ids.append(self._special_encoder[self.byte_list[_id]])
return ids
def next_token(self, text):
# fast next token matching
token, token_id = self.tencoder.longest_prefix_item(text, (None, None))
if token is None:
token = text[0]
token_ids = self._encode_unicode(token)
else:
token_ids = [token_id]
return token, token_ids
def start_worker(
world_size: int, rank: int, local_rank: int, model: onnx.ModelProto, query
):
model = onnx.load(ONNX_MODEL_PATH)
runtime = backend.CudaRuntime(local_rank)
if args.nproc_per_node > 1:
runtime.init_comm(
"9g",
world_size,
rank,
)
print("[{}] comm init.".format(rank))
stub = OnnxStub(model, runtime)
print("[{}] stub init.".format(rank))
for i in range(10):
if args.no_cudagraph:
stub.run()
else:
stub.run_with_cudagraph()
print("[{}] stub warmup.".format(rank))
tokenizer = CPM9GTokenizer("/data1/shared/9G-Infer/models/11B-Chat-QY-epoch-8/vocabs.txt")
query = tokenizer.encode(query)
output_tokens = []
for i in range(len(query)):
q = np.array(query[i])
(list(stub.inputs.items()))[0][1].copyin_int64(q.reshape(-1).tolist())
pos = i * np.ones((args.batchsize, 1), dtype=np.int64)
(list(stub.inputs.items()))[1][1].copyin_int64(pos.reshape(-1).tolist())
if args.no_cudagraph:
stub.run()
else:
stub.run_with_cudagraph()
if i == len(query) - 1:
output = np.array((list(stub.outputs.items()))[-1][1].copyout_float())
q = np.argmax(output)
output_tokens.append(q)
avg_time = 0
count = 0
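# Decode loop: feed the previously generated token back one position at a
# time and take a greedy argmax over the logits; token id 2 (presumably the
# vocabulary's eos_id) stops generation, and the loop is capped at 1000 steps.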
while i < 1000:
count = count + 1
torch.cuda.synchronize()
with nvtx.annotate("gen {}-th token".format(i), color="red"):
i = i + 1
(list(stub.inputs.items()))[0][1].copyin_int64(q.reshape(-1).tolist())
pos = i * np.ones((args.batchsize, 1), dtype=np.int64)
(list(stub.inputs.items()))[1][1].copyin_int64(pos.reshape(-1).tolist())
t0 = time.time()
if args.no_cudagraph:
stub.run()
else:
stub.run_with_cudagraph()
t1 = time.time()
avg_time += t1 - t0
output = np.array((list(stub.outputs.items()))[-1][1].copyout_float())
# print(output)
with nvtx.annotate("argmax", color="green"):
q = np.argmax(output)
if q == 2:
break
output_tokens.append(q)
avg_time = avg_time / count
print("avg_time_cost =", avg_time*1000, "ms")
text = tokenizer.decode(output_tokens)
return text
if __name__ == "__main__":
comm = MPI.COMM_WORLD
args.rank = comm.Get_rank()
args.nproc_per_node = comm.Get_size()
world_size = args.num_nodes * args.nproc_per_node
if not os.path.exists(ONNX_MODEL_PATH):
print("exporting onnx graph")
generate_onnx(ONNX_MODEL_PATH)
else:
print("will use exsiting onnx graph")
onnx_model = onnx.load(ONNX_MODEL_PATH)
print("data loaded")
#query = '''Beijing is the capital'''
#query = '''什么是PTX'''
#query = '''生病了怎么办?'''
#query = '''Happy'''
query = '''def gcd(a, b):'''
####################################
# infinitensor dist
####################################
# run distributed parallel.
pred = start_worker(world_size, args.rank, args.rank %
args.nproc_per_node, onnx_model, query)
if args.rank == 0:
print("输入:\n\n", query, "\n")
print("输出:", pred)

View File

@ -0,0 +1,491 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
from tqdm import tqdm
import argparse
import torch
import onnx
import onnx_graphsurgeon as gs
import os
import numpy as np
from pyinfinitensor.onnx import OnnxStub, backend
import time
import nvtx
from mpi4py import MPI
parser = argparse.ArgumentParser(description='')
parser.add_argument('--batchsize', dest='batchsize', type=int, default=1)
parser.add_argument('--layer', dest='n_layers', type=int, default=32)
parser.add_argument("--num_nodes", dest='num_nodes',
type=int, default=1, help="number of nodes")
parser.add_argument("--nproc_per_node", dest="nproc_per_node",
type=int, default=1, help="number of processes per node")
parser.add_argument("--world_size", dest="world_size",
type=int, default=1, help="")
parser.add_argument("--n_max_length", dest="n_max_length",
type=int, default=1024, help="")
parser.add_argument("--vocab_size", dest="vocab_size",
type=int, default=32000, help="vocabulary size")
parser.add_argument("--hidden_size", dest="hidden_size",
type=int, default=4096)
parser.add_argument("--head_size", dest="head_size",
type=int, default=32)
parser.add_argument("--head_dim", dest="head_dim",
type=int, default=128)
parser.add_argument('--rank', dest='rank', type=int, default=0)
parser.add_argument('--no_cudagraph', action='store_true')
parser.add_argument('--fp16', action='store_true')
parser.add_argument('--is_1st_graph', action='store_true')
parser.add_argument('--speedup', action='store_true')
args = parser.parse_args()
comm = MPI.COMM_WORLD
args.rank = comm.Get_rank()
args.nproc_per_node = comm.Get_size()
args.world_size = args.num_nodes * args.nproc_per_node
PRETRAINED_LLAMA_PATH = "/data0/shared/data/public/opensource_models/meta-llama/Llama-2-7b-hf/"
ONNX_MODEL_PATH = "/data3/shared/xnsong/llama2/" + ("1st" if args.is_1st_graph else "2nd")
ONNX_MODEL_ORIGIN_PATH = ONNX_MODEL_PATH + "/origin/llama2_origin_bs{}_layer{}.onnx".format(
args.batchsize, args.n_layers)
ONNX_MODEL_SIM_PATH = ONNX_MODEL_PATH + "/sim/llama2_sim_bs{}_layer{}.onnx".format(
args.batchsize, args.n_layers)
ONNX_MODEL_FUSION_PATH = ONNX_MODEL_PATH + "/fusion/llama2_fusion_bs{}_layer{}.onnx".format(
args.batchsize, args.n_layers)
ONNX_MODEL_SPECIAL_PATH = ONNX_MODEL_PATH + "/special/llama2_special_bs{}_layer{}.onnx".format(
args.batchsize, args.n_layers)
ONNX_MODEL_FP16_PATH = ONNX_MODEL_PATH + "/fp16/llama2_fp16_bs{}_layer{}.onnx".format(
args.batchsize, args.n_layers)
ONNX_MODEL_DIST_PATH = ONNX_MODEL_PATH + "/dist/llama2_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.onnx".format(
args.batchsize, args.n_layers, 16 if args.fp16 else 32, args.world_size, args.rank)
def parallel_model(onnx_model, world_size, rank):
graph = gs.import_onnx(onnx_model)
tmap = graph.tensors()
for i in range(args.n_layers):
tmap[graph.inputs[2+i*2].name].shape[1] = tmap[graph.inputs[2+i*2].name].shape[1]//world_size
tmap[graph.inputs[3+i*2].name].shape[1] = tmap[graph.inputs[3+i*2].name].shape[1]//world_size
for node in graph.nodes:
if node.name == "/model/layers." + str(i) + "/self_attn/q_proj/MatMul":
node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
elif node.name == "/model/layers." + str(i) + "/self_attn/k_proj/MatMul":
node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
elif node.name == "/model/layers." + str(i) + "/self_attn/v_proj/MatMul":
node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
elif node.name == "/model/layers." + str(i) + "/self_attn/o_proj/MatMul":
node.inputs[1].values = np.vsplit(node.inputs[1].values, world_size)[rank]
reduce_sum_output = gs.Variable("reduce_sum_output_" + str(i) + "_0",
dtype=np.float32)
reduce_sum = gs.Node(op="ReduceSum", name="reduce_sum_"+str(i)+"_0",
inputs=node.outputs, outputs=[reduce_sum_output],
attrs={"noop_with_empty_axes":1, "communicator":0})
graph.nodes.append(reduce_sum)
next_node = node.outputs[0].outputs[0]
next_node.inputs[1] = reduce_sum_output
elif node.name == "/model/layers." + str(i) + "/self_attn/Reshape_0" or \
node.name == "/model/layers." + str(i) + "/self_attn/Reshape_1":
node.inputs[1].values = np.array(
[1, 1,
args.head_size//world_size,
args.hidden_size//args.head_size])
elif node.name == "/model/layers." + str(i) + "/self_attn/Reshape_2":
node.inputs[1] = gs.Constant(name="/model/layers."+str(i)+"/self_attn/vreshape_input",
values=np.array(
[1, 1,
args.head_size//world_size,
args.hidden_size//args.head_size]))
elif node.name == "/model/layers." + str(i) + "/self_attn/Reshape_3":
node.inputs[1] = gs.Constant(name="/model/layers." + str(i) + "/self_attn/Reshape_3_shape",
values=np.array(
[1, 1, args.hidden_size//world_size]))
elif node.name == "/model/layers." + str(i) + "/mlp/up_proj/MatMul":
node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
elif node.name == "/model/layers." + str(i) + "/mlp/gate_proj/MatMul":
node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
elif node.name == "/model/layers." + str(i) + "/mlp/down_proj/MatMul":
node.inputs[1].values = np.vsplit(node.inputs[1].values, world_size)[rank]
reduce_sum_output_1 = gs.Variable("reduce_sum_output_" + str(i) + "_1",
dtype=np.float32)
reduce_sum_1 = gs.Node(op="ReduceSum", inputs=node.outputs, outputs=[reduce_sum_output_1],
attrs={"noop_with_empty_axes":1, "communicator":0})
graph.nodes.append(reduce_sum_1)
next_node = node.outputs[0].outputs[0]
next_node.inputs[1] = reduce_sum_output_1
# new_out_1 = tmap["/model/layers.0/mlp/down_proj/MatMul_output_0"] #reduce_sum_output
# new_out_1.dtype = np.float32
# new_out_1.shape = [1,1,4096]
# graph.outputs.append(new_out_1)
graph.cleanup(True).toposort()
return gs.export_onnx(graph)
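# parallel_model() applies Megatron-style tensor parallelism: q/k/v/up/gate
# projection weights are split column-wise (np.hsplit) so each rank produces a
# slice of the hidden dimension, while o_proj/down_proj are split row-wise
# (np.vsplit) and followed by a "ReduceSum" node that the runtime executes as
# an all-reduce over the communicator. A rough numpy-only sketch of why the
# row-parallel step needs that reduction (W_up, W_down, all_reduce_sum and the
# shapes are illustrative assumptions, non-linearities omitted):
#
#   x = np.random.randn(1, 4096)
#   w_col = np.hsplit(W_up, world_size)[rank]     # (4096, inter/world_size)
#   w_row = np.vsplit(W_down, world_size)[rank]   # (inter/world_size, 4096)
#   partial = (x @ w_col) @ w_row                 # rank-local partial result
#   full = all_reduce_sum(partial)                # sums to x @ W_up @ W_down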
def simplify(onnx_model):
graph = gs.import_onnx(onnx_model)
for node in graph.nodes:
if node.op == "Cast":
inp_node = node.i()
inp_node.outputs = node.outputs
node.outputs.clear()
for i in range(args.n_layers):
nodename = "/model/layers." + str(i) + "/self_attn/Add_2"
node = [node for node in graph.nodes if node.name == nodename][0]
inp_node = node.i()
inp_node.outputs = node.outputs
node.outputs.clear()
graph.cleanup().toposort()
return gs.export_onnx(graph)
@gs.Graph.register()
def replace_with_RMSNorm(self, inputs, outputs):
inputs[0].outputs.pop(0)
inputs[0].outputs.pop(0)
for out in outputs:
out.inputs.clear()
return self.layer(op="RMSNorm", inputs=inputs, outputs=outputs, name="rmsnorm")
@gs.Graph.register()
def replace_with_silu(self, inputs, outputs):
for inp in inputs:
inp.outputs.clear()
for out in outputs:
out.inputs.clear()
return self.layer(op="Silu", inputs=inputs, outputs=outputs, name="silu")
@gs.Graph.register()
def replace_with_RoPE(self, a, b):
return self.layer(op="RoPE", inputs=a, outputs=b, name="rope")
@gs.Graph.register()
def replace_with_attention(self, inputs, outputs, inputs_added, outputs_removed):
for inp in inputs:
inp.outputs.clear()
for out in outputs:
out.inputs.clear()
for inp in inputs_added:
inputs.append(inp)
for out in outputs_removed:
out.inputs.clear()
return self.layer(op="AttentionKVCache", inputs=inputs, outputs=outputs, name="attention")
def fusion(model):
graph = gs.import_onnx(model)
tmap = graph.tensors()
tmap["onnx::Reshape_1"].outputs.clear()
inputs = [tmap["/model/layers.0/input_layernorm/Cast_output_0"], tmap["model.layers.0.input_layernorm.weight"]]
rmsnorm_outputs = [tmap["/model/layers.0/input_layernorm/Mul_1_output_0"]]
graph.replace_with_RMSNorm(inputs, rmsnorm_outputs)
for i in range(args.n_layers):
# rotary embedding op
tmap["/model/layers." + str(i) + "/self_attn/Add_output_0"].inputs.clear()
tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"].inputs.clear()
attn_qreshape_input = gs.Constant(name="/model/layers." + str(i) + "/self_attn/qreshape_input",
values=np.array([1,1,args.head_size,args.hidden_size//args.head_size]))
attn_kreshape_input = gs.Constant(name="/model/layers." + str(i) + "/self_attn/kreshape_input",
values=np.array([1,1,args.head_size,args.hidden_size//args.head_size]))
attn_qrope_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/qrope_output")
attn_krope_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/krope_output")
attn_qreshape_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/qreshape_output")
attn_kreshape_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/kreshape_output")
attn_qreshape = gs.Node(op="Reshape", name = "/model/layers." + str(i) + "/self_attn/Reshape_0", inputs=[attn_qrope_output, attn_qreshape_input], outputs=[attn_qreshape_output])
attn_kreshape = gs.Node(op="Reshape", name = "/model/layers." + str(i) + "/self_attn/Reshape_1", inputs=[attn_krope_output, attn_kreshape_input], outputs=[attn_kreshape_output])
attn_qtrans = gs.Node(op="Transpose", attrs={"perm":np.array([0,2,1,3])}, inputs=[attn_qreshape_output],
outputs=[tmap["/model/layers." + str(i) + "/self_attn/Add_output_0"]])
attn_ktrans = gs.Node(op="Transpose", attrs={"perm":np.array([0,2,1,3])}, inputs=[attn_kreshape_output],
outputs=[tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"]])
graph.nodes.append(attn_qreshape)
graph.nodes.append(attn_kreshape)
graph.nodes.append(attn_qtrans)
graph.nodes.append(attn_ktrans)
inputs = [tmap["onnx::Reshape_1"], tmap["/model/layers." + str(i) + "/self_attn/q_proj/MatMul_output_0"]]
graph.replace_with_RoPE(inputs, [attn_qrope_output])
inputs = [tmap["onnx::Reshape_1"], tmap["/model/layers." + str(i) + "/self_attn/k_proj/MatMul_output_0"]]
graph.replace_with_RoPE(inputs, [attn_krope_output])
# rms-norm op
inputs = [tmap["/model/layers." + str(i) + "/post_attention_layernorm/Cast_output_0"], \
tmap["model.layers." + str(i) + ".post_attention_layernorm.weight"]]
outputs = [tmap["/model/layers." + str(i) + "/post_attention_layernorm/Mul_1_output_0"]]
graph.replace_with_RMSNorm(inputs, outputs)
inputs = [tmap["/model/layers." + str(i+1) + "/input_layernorm/Cast_output_0"] if i != args.n_layers-1 else \
tmap["/model/norm/Cast_output_0"], \
tmap["model.layers." + str(i+1) + ".input_layernorm.weight"] if i != args.n_layers-1 else \
tmap["model.norm.weight"]]
outputs = [tmap["/model/layers."+ str(i+1) + "/input_layernorm/Mul_1_output_0"]] if i != args.n_layers-1 else \
[tmap["/model/norm/Mul_1_output_0"]]
graph.replace_with_RMSNorm(inputs, outputs)
# silu op
inputs = [tmap["/model/layers." + str(i) + "/mlp/gate_proj/MatMul_output_0"]]
outputs = [tmap["/model/layers." + str(i) + "/mlp/act_fn/Mul_output_0"]]
graph.replace_with_silu(inputs, outputs)
inputs = [
tmap["onnx::Concat_" + str((i+1)*2)],
tmap["onnx::Concat_" + str((i+1)*2+1)],
tmap["/model/layers." + str(i) + "/self_attn/Add_output_0"],
tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"],
tmap["/model/layers." + str(i) + "/self_attn/Transpose_2_output_0"]]
outputs = [
tmap["/model/layers." + str(i) + "/self_attn/MatMul_1_output_0"],]
inputs_added = [graph.inputs[1]]
outputs_removed = []
graph.replace_with_attention(
inputs, outputs, inputs_added, outputs_removed)
graph.outputs = [tmap[graph.outputs[0].name]]
graph.cleanup(True).toposort()
return gs.export_onnx(graph)
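# fusion() replaces whole exported subgraphs with the runtime's custom ops:
# RMSNorm for the layernorm pattern, RoPE for the rotary-embedding arithmetic,
# Silu for the activation, and AttentionKVCache for the attention block that
# reads and updates the per-layer k/v cache inputs ("onnx::Concat_*").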
def special_pass(model):
graph = gs.import_onnx(model)
tmap = graph.tensors()
for node in graph.nodes:
if node.op == "Transpose" or node.op == "Reshape":
inp_node = node.i()
inp_node.outputs = node.outputs
node.outputs.clear()
graph.cleanup(True).toposort()
return gs.export_onnx(graph)
def convert_to_fp16(model):
graph = gs.import_onnx(model)
for node in graph.nodes:
if node.op == "Gather" and node.name == "/model/embed_tokens/Gather":
node.inputs[0].values = np.float16(node.inputs[0].values)
if node.op == "RMSNorm":
node.inputs[1].values = np.float16(node.inputs[1].values)
if node.op == "MatMul":
node.inputs[1].values = np.float16(node.inputs[1].values)
if node.name == "/lm_head/MatMul":
cast_1_out = gs.Variable(node.name+"_cast_out_output_0", dtype=np.float32, shape=node.outputs[0].shape)
cast_1 = gs.Node(op="Cast", inputs=[node.outputs[0]], outputs=[cast_1_out])
cast_1.attrs["to"] = np.float32
cast_1.name = node.name+"_cast_out_0"
graph.nodes.append(cast_1)
graph.outputs[0] = cast_1_out
node.outputs[0].dtype = np.float16
graph.cleanup(True).toposort()
return gs.export_onnx(graph)
def export_onnx(model: AutoModelForCausalLM):
if not os.path.exists(ONNX_MODEL_ORIGIN_PATH):
print("exporting origin onnx model...")
with torch.no_grad():
param = torch.zeros(
(args.batchsize, model.config.max_position_embeddings-1), dtype=torch.long)
logits = model(param, past_key_values=None)
if not args.is_1st_graph:
param_kvcache = torch.zeros((args.batchsize, 1), dtype=torch.long)
torch.onnx.export(model, (param_kvcache, {"past_key_values": logits.past_key_values,
"position_ids": param_kvcache}), \
ONNX_MODEL_ORIGIN_PATH, verbose=False,
do_constant_folding=True,)
else:
position_ids = torch.tile(torch.arange(0, model.config.max_position_embeddings-1), (args.batchsize, 1))
attention_mask = torch.ones((args.batchsize, model.config.max_position_embeddings-1), dtype=torch.bool)
torch.onnx.export(model, (param, {"attention_mask": attention_mask,
"position_ids": position_ids}),\
ONNX_MODEL_ORIGIN_PATH, verbose=False,
do_constant_folding=True,)
print("export origin onnx finished.")
if not args.is_1st_graph and not os.path.exists(ONNX_MODEL_SIM_PATH):
print("exporting sim onnx model...")
onnx_model = onnx.load(ONNX_MODEL_ORIGIN_PATH)
onnx_model = simplify(onnx_model)
onnx.save(onnx_model, ONNX_MODEL_SIM_PATH, save_as_external_data=True, \
location="llama2_sim_bs{}_layer{}.pb".format(args.batchsize, args.n_layers))
print("exporting sim onnx model finished.")
if not args.is_1st_graph and not os.path.exists(ONNX_MODEL_FUSION_PATH):
print("exporting fusion onnx model...")
onnx_model = onnx.load(ONNX_MODEL_SIM_PATH)
onnx_model = fusion(onnx_model)
onnx.save(onnx_model, ONNX_MODEL_FUSION_PATH, save_as_external_data=True, \
location="llama2_fusion_bs{}_layer{}.pb".format(args.batchsize, args.n_layers))
print("exporting fusion onnx model finished.")
if not args.is_1st_graph and not os.path.exists(ONNX_MODEL_SPECIAL_PATH):
print("exporting special onnx model...")
onnx_model = onnx.load(ONNX_MODEL_FUSION_PATH)
onnx_model = special_pass(onnx_model)
onnx.save(onnx_model, ONNX_MODEL_SPECIAL_PATH, save_as_external_data=True, \
location="llama2_special_bs{}_layer{}.pb".format(args.batchsize, args.n_layers))
print("exporting special onnx model finished.")
if not args.is_1st_graph and args.fp16 and not os.path.exists(ONNX_MODEL_FP16_PATH):
print("exporting fp16 onnx model...")
onnx_model = onnx.load(ONNX_MODEL_SPECIAL_PATH)
onnx_model = convert_to_fp16(onnx_model)
onnx.save(onnx_model, ONNX_MODEL_FP16_PATH, save_as_external_data=True, \
location="llama2_fp16_bs{}_layer{}.pb".format(args.batchsize, args.n_layers))
print("exporting fp16 onnx model finished.")
print("world_size =", args.world_size)
if not args.is_1st_graph and args.world_size > 1 and not os.path.exists(ONNX_MODEL_DIST_PATH):
print("exporting dist onnx model...")
onnx_model = onnx.load(ONNX_MODEL_FP16_PATH) if args.fp16 else onnx.load(ONNX_MODEL_SPECIAL_PATH)
onnx_model = parallel_model(onnx_model, args.world_size, args.rank)
onnx.save(onnx_model, ONNX_MODEL_DIST_PATH, save_as_external_data=True, \
location="llama2_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.pb".format(
args.batchsize, args.n_layers,
16 if args.fp16 else 32, args.world_size, args.rank))
print("exporting dist onnx model finished.")
def get_it_logit(onnx_model, input_ids):
# initialization
runtime = backend.CudaRuntime(args.rank)
runtime.init_comm(
"dist",
args.world_size,
args.rank,
)
print("[{}] comm init.".format(args.rank))
stub = OnnxStub(onnx_model, runtime)
print("[{}] stub init.".format(args.rank))
# warm up
for i in range(10):
if args.no_cudagraph:
stub.run()
else:
stub.run_with_cudagraph()
print("[{}] stub warmup.".format(args.rank))
logits = np.zeros((args.batchsize, args.n_max_length, args.vocab_size), dtype=np.float32)
output_ids = np.zeros((args.batchsize, args.n_max_length), dtype=np.int64)
avg_inference_time = 0
t0 = time.time()
for i in tqdm(range(0, args.n_max_length)):
with nvtx.annotate("seq_length = {}".format(i), color="red"):
assert input_ids.shape[0] == args.batchsize
input_id = input_ids[:, i] if i < input_ids.shape[1] else output_ids[:, i-1]
position_id = i*np.ones((args.batchsize, 1), dtype=np.int32)
# copyin input
with nvtx.annotate("[it] copyin", color="blue"):
(list(stub.inputs.items()))[0][1].copyin_int64(
input_id.reshape(-1).tolist())
(list(stub.inputs.items()))[1][1].copyin_int64(
position_id.reshape(-1).tolist())
# run
t10 = time.time()
with nvtx.annotate("[it] run", color="green"):
if args.no_cudagraph:
stub.run()
else:
stub.run_with_cudagraph()
t11 = time.time()
avg_inference_time += (t11 - t10)
# copyout output
if not args.speedup:
with nvtx.annotate("[it] copyout", color="blue"):
logits[:,i, :] = np.array((list(stub.outputs.items()))[0][1].copyout_float()).reshape(args.batchsize, -1)
output_ids[:, i] = np.argmax(logits[:, i, :], -1).astype(np.int64)
t1 = time.time()
if args.rank == 0:
result = "[it] e2e: {} gpus, {} layers, e2e time: {:.2f}s, average inference time: {:.2f}ms"\
.format(args.num_nodes * args.nproc_per_node, args.n_layers, t1-t0, \
avg_inference_time*1000/args.n_max_length)
print(result)
del stub
return output_ids
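# get_it_logit() runs prefill and decode in one loop of n_max_length steps:
# prompt tokens are fed while i is inside the prompt, afterwards the previous
# argmax is fed back; output_ids holds the greedy prediction at every
# position, and only the slice after the prompt is decoded as the answer.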
if __name__ == "__main__":
torch_model = LlamaForCausalLM.from_pretrained(
PRETRAINED_LLAMA_PATH, num_hidden_layers=int(args.n_layers)).eval()
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_LLAMA_PATH)
#prompt = "Hey, are you conscious? Can you talk to me?"
#prompt = "What is PTX?"
#prompt = "Tell me a joke."
#prompt = "What are the key principles of smart investing?"
prompt = "What is DeepSpeed?"
prompts=[prompt]*args.batchsize
inputs = tokenizer(prompts, return_tensors="pt")
input_ids = inputs.input_ids
print("prompt ids =", input_ids)
##########################################################
# inference with InfiniTensor
##########################################################
print("exporting onnx...")
export_onnx(torch_model)
print("exporting onnx finished.")
onnx_to_run_path = ONNX_MODEL_DIST_PATH if args.world_size > 1 else \
(ONNX_MODEL_FP16_PATH if args.fp16 else ONNX_MODEL_SPECIAL_PATH)
print("loading onnx", onnx_to_run_path, "...")
onnx_model = onnx.load(onnx_to_run_path)
print("loading onnx finished.")
output_ids_it = get_it_logit(onnx_model, input_ids)
it_output_text = tokenizer.batch_decode(output_ids_it[:, input_ids.shape[-1]:output_ids_it.shape[-1]])
if args.rank == 0:
for i in range(args.batchsize):
print("prompt: ", prompts[i])
print("answer: [it]", it_output_text[i])
##########################################################
# validation with pytorch
##########################################################
"""
generate_ids = torch_model.generate(inputs.input_ids, max_length=args.n_max_length)#, num_beams=4, do_sample=True)
outputs = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"""
if not args.speedup and not args.is_1st_graph:
kvcache_torch = None
output_ids_pt = torch.zeros(args.batchsize, args.n_max_length).int() # + input_ids.shape[-1] - 1).int()
if args.fp16:
torch_model = torch_model.half()
torch_model = torch_model.cuda()
# print(torch.cuda.memory_summary())
avg_inference_time = 0
with torch.no_grad():
t0 = time.time()
for i in range(args.n_max_length):
input_id = input_ids[:,i] if i < input_ids.shape[1] else out_token
input_id = input_id.view(args.batchsize,1).cuda()
t00 = time.time()
outputs = torch_model(input_id, past_key_values=kvcache_torch)
t01 = time.time()
avg_inference_time += (t01-t00)
logits = outputs['logits']
kvcache_torch = outputs['past_key_values']
out_token = torch.argmax(logits, dim=-1)
output_ids_pt[:, i:i+1] = out_token
t1 = time.time()
avg_inference_time /= args.n_max_length
result = "[pt] e2e time: {:.2f}s, average inference time: {:.2f}ms"\
.format(t1-t0, avg_inference_time*1000)
if args.rank == 0:
print(result)
pt_output_text = tokenizer.batch_decode(output_ids_pt[:,input_ids.shape[-1]:args.n_max_length])
for i in range(args.batchsize):
print("[pt]", args.rank, pt_output_text[i])
if not args.is_1st_graph:
assert(output_ids_it.shape[-1] == args.n_max_length)
np.testing.assert_equal(output_ids_pt[:, input_ids.shape[-1]:args.n_max_length], output_ids_it[:,input_ids.shape[-1]:args.n_max_length])
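# Example launch (script name and sizes are illustrative, not taken from this
# repository):
#   mpirun -n 2 python llama2_dist.py --layer 32 --fp16 --n_max_length 128
# Each MPI rank drives one GPU with its own per-rank dist graph; rank 0 prints
# the InfiniTensor answer and, unless --speedup is set, a greedy PyTorch
# reference decode is run and compared token-by-token against it.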

View File

@ -3,14 +3,17 @@
#include <cstdio> #include <cstdio>
struct AttentionKVCacheMetadata { struct AttentionKVCacheMetadata {
int dimSize[4]; int head_dim;
int stride[4]; int num_heads;
int num_seqs;
int max_kv_seqlen;
}; };
namespace infini { namespace infini {
void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache, void attention_kvcache_kernel(int dType, void *input_k_cache,
float *input_q, float *input_k, float *input_v, void *input_v_cache, void *input_q, void *input_k,
int *position_id, float *output_matmul, void *input_v, int64_t *position_id,
void *output_matmul,
const AttentionKVCacheMetadata &compMeta, const AttentionKVCacheMetadata &compMeta,
float *output_O_temp, float *output_sum_temp); float *output_O_temp, float *output_sum_temp);

View File

@ -5,8 +5,7 @@
namespace infini { namespace infini {
void rope_kernel(int dType, int *pos, void *input, void *output, int size, void rope_kernel(int dType, int64_t *pos, void *input, void *output,
int dim_model, int dim_head, int hidden_stride, int dim_model, int dim_head, int batchsize, int pos_stride);
int pos_stride);
}; // namespace infini }; // namespace infini

View File

@ -21,7 +21,7 @@ class KUNLUNRuntimeObj : public RuntimeObj {
ctx = xdnn::create_context(); ctx = xdnn::create_context();
// 10GB for Longformer // 10GB for Longformer
// size_t longformerNum = 3lu * (1 << 30); // size_t longformerNum = 3lu * (1 << 30);
size_t workspaceSize = 2llu << 30; // 2 GB size_t workspaceSize = 3llu << 30; // 3 GB
KUNLUNPtr wkspacePtr = alloc(workspaceSize); KUNLUNPtr wkspacePtr = alloc(workspaceSize);
workspace = workspace =
make_ref<WorkspaceObj<KUNLUNPtr>>(wkspacePtr, workspaceSize); make_ref<WorkspaceObj<KUNLUNPtr>>(wkspacePtr, workspaceSize);
@ -42,7 +42,7 @@ class KUNLUNRuntimeObj : public RuntimeObj {
KUNLUNPtr alloc(size_t size) override { KUNLUNPtr alloc(size_t size) override {
void *ptr; void *ptr;
checkKUNLUNError( checkKUNLUNError(
xpu_malloc((void **)&ptr, size, XPUMemoryKind::XPU_MEM_HBM)); xpu_malloc_ex((void **)&ptr, size, XPUMemoryKind::XPU_MEM_MAIN));
return ptr; return ptr;
} }
void dealloc(void *ptr) override { xpu_free(ptr); } void dealloc(void *ptr) override { xpu_free(ptr); }

View File

@ -34,8 +34,8 @@ class XcclCommunicatorObj final : public CommunicatorObj {
auto begin = std::chrono::steady_clock::now(); auto begin = std::chrono::steady_clock::now();
while (!std::filesystem::exists(filePath)) { while (!std::filesystem::exists(filePath)) {
auto now = std::chrono::steady_clock::now(); auto now = std::chrono::steady_clock::now();
_IT_ASSERT_2(now < begin + std::chrono::seconds(100), _IT_ASSERT_2(now < begin + std::chrono::seconds(10),
"time limit (100s) exceeded."); "time limit (10s) exceeded.");
std::this_thread::sleep_for(std::chrono::milliseconds(100)); std::this_thread::sleep_for(std::chrono::milliseconds(100));
} }
std::ifstream ifs(filePath, std::ios::binary); std::ifstream ifs(filePath, std::ios::binary);

View File

@ -3,8 +3,7 @@
namespace infini { namespace infini {
/** /**
* @brief Fused Attention with KVCache input operator. All the input and output * @brief Fused Attention with KVCache input operator.
* tensors should have the same rank except for the position_id.
* *
*/ */
class AttentionKVCacheObj : public OperatorObj { class AttentionKVCacheObj : public OperatorObj {
@ -16,12 +15,19 @@ class AttentionKVCacheObj : public OperatorObj {
* *
* @param graph The computation graph that this operator belongs to. * @param graph The computation graph that this operator belongs to.
* @param input_k_cache The k_cache input tensor. * @param input_k_cache The k_cache input tensor.
* Shape: [batchsize, num_heads, k_cache_seq_length, head_dim]
* @param input_v_cache The v_cache input tensor. * @param input_v_cache The v_cache input tensor.
* Shape: [batchsize, num_heads, v_cache_seq_length, head_dim]
* @param input_q The query input tensor. * @param input_q The query input tensor.
* Shape: [batchsize, q_seq_length, model_dim]
* @param input_k The key input tensor. * @param input_k The key input tensor.
* Shape: [batchsize, q_seq_length, model_dim]
* @param input_v The value input tensor. * @param input_v The value input tensor.
* @param position_id The positon id of the query, * Shape: [batchsize, q_seq_length, model_dim]
* @param position_id The positon id of the query.
* Shape: [batchsize, q_seq_length]
* @param output_matmul The query output tensor. * @param output_matmul The query output tensor.
* Shape: [batchsize, q_seq_length, model_dim]
*/ */
AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache, AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
Tensor input_v_cache, Tensor input_q, Tensor input_k, Tensor input_v_cache, Tensor input_q, Tensor input_k,
@ -30,6 +36,10 @@ class AttentionKVCacheObj : public OperatorObj {
OP_CLONE(AttentionKVCacheObj); OP_CLONE(AttentionKVCacheObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
vector<DataType> inferDataType(const TensorVec &inputs) const override {
return {inputs[2]->getDType()};
};
DataType getDType() const { return getInputs(2)->getDType(); }
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 6; } int numInputs() const override { return 6; }

View File

@ -21,6 +21,10 @@ class RoPEObj : public OperatorObj {
int numOutputs() const override { return 1; } int numOutputs() const override { return 1; }
DataType getDType() const { return getInputs(1)->getDType(); } DataType getDType() const { return getInputs(1)->getDType(); }
vector<DataType> inferDataType(const TensorVec &inputs) const override {
return {inputs[1]->getDType()};
};
private: private:
vector<int> getWorkloadVector() const override; vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override; vector<int> getOpAttrVector() const override;

View File

@ -208,16 +208,39 @@ class OnnxStub:
op[1], op[1],
) )
elif node.op_type == "MatMul": elif node.op_type == "MatMul":
tensors[node.output[0]] = self.handler.matmul( if node.input[1] in data.keys() \
tensors[node.input[0]], # input and to_array(data[node.input[1]]).dtype == np.float32 \
tensors[node.input[1]], # weight and 'cuda_runtime' in dir(backend) \
tensors.get(node.output[0]), and tensors[node.input[0]].shape()[0] == 1 \
False, and tensors[node.input[0]].shape()[1] == 1 \
False, and len(tensors[node.input[1]].shape()) == 2 \
None, and node.input[1] in data.keys():
backend.ActType.Linear, data[node.input[1]] = from_array(
matmul_compute_type, np.transpose(to_array(data[node.input[1]])))
) tensors[node.input[1]] = self.handler.tensor(
[tensors[node.input[1]].shape()[1], tensors[node.input[1]].shape()[0]],
tensors[node.input[1]].dtype())
tensors[node.output[0]] = self.handler.matmul(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
False,
True,
None,
backend.ActType.Linear,
matmul_compute_type,
)
else:
tensors[node.output[0]] = self.handler.matmul(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
False,
False,
None,
backend.ActType.Linear,
matmul_compute_type,
)
elif node.op_type == "Gemm": elif node.op_type == "Gemm":
attributes = _parse_attribute( attributes = _parse_attribute(
node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0} node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
@ -967,7 +990,7 @@ class OnnxStub:
tensors[node.input[0]], tensors[node.input[0]],
tensors.get(node.output[0]), tensors.get(node.output[0]),
) )
elif node.op_type in ["Constant", "ConstantOfShape"]: elif node.op_type == "Constant":
output_name = node.output[0] output_name = node.output[0]
attributes = _parse_attribute(node) attributes = _parse_attribute(node)
tensor = attributes["value"] tensor = attributes["value"]

View File

@ -199,24 +199,6 @@ class CastCnnl : public BangKernelWithoutConfig {
dim.data())); dim.data()));
NlCastType = CNNL_CAST_UINT32_TO_INT64; NlCastType = CNNL_CAST_UINT32_TO_INT64;
break; break;
case CastType::Float162Float:
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_HALF, dim.size(),
dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, dim.size(),
dim.data()));
NlCastType = CNNL_CAST_HALF_TO_FLOAT;
break;
case CastType::Float2Float16:
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_FLOAT, dim.size(),
dim.data()));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
CNNL_DTYPE_HALF, dim.size(),
dim.data()));
NlCastType = CNNL_CAST_FLOAT_TO_HALF;
break;
default: default:
IT_TODO_HALT(); IT_TODO_HALT();
} }

View File

@ -19,16 +19,14 @@ class LayerNormCnnl : public BangKernelWithoutConfig {
void *const outputData = (op->getOutput()->getRawDataPtr<void *>()); void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
auto inDims = op->getInputs(0)->getDims(); auto inDims = op->getInputs(0)->getDims();
auto fiterDims = op->getInputs(1)->getDims();
auto outDims = op->getOutput()->getDims(); auto outDims = op->getOutput()->getDims();
auto fiterDims = op->getOutput(1)->getDims();
float eps = op->getEps(); float eps = op->getEps();
const int axis = op->getAxis(); const int axis = op->getAxis();
Shape outMeanDims(outDims); cnnlTensorDescriptor_t inDesc, fiterDesc, outDesc;
outMeanDims.erase(outMeanDims.begin() + axis);
cnnlTensorDescriptor_t inDesc, fiterDesc, outDesc, outMeanDesc;
checkCnnlError(cnnlCreateTensorDescriptor(&inDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(cnnlSetTensorDescriptor(
inDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), inDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
@ -41,23 +39,15 @@ class LayerNormCnnl : public BangKernelWithoutConfig {
checkCnnlError(cnnlSetTensorDescriptor( checkCnnlError(cnnlSetTensorDescriptor(
outDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()), outDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
outDims.size(), outDims.data())); outDims.size(), outDims.data()));
checkCnnlError(cnnlCreateTensorDescriptor(&outMeanDesc));
checkCnnlError(cnnlSetTensorDescriptor(
outMeanDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
outMeanDims.size(), outMeanDims.data()));
size_t wsSize; size_t wsSize;
cnnlGetLayerNormOpWorkspaceSize(context->cnnlHandle(), axis, inDesc, cnnlGetLayerNormOpWorkspaceSize(context->cnnlHandle(), axis, inDesc,
&wsSize); &wsSize);
BangPtr wsData = context->getWorkspace(wsSize); BangPtr wsData = context->getWorkspace(wsSize);
size_t meanSize =
cnnlGetTensorElementNum(outMeanDesc) * op->getDType().getSize();
BangPtr meanData = context->getWorkspace(meanSize);
BangPtr rstdData = context->getWorkspace(meanSize);
cnnlStatus_t stat = cnnlLayerNormForward( cnnlStatus_t stat = cnnlLayerNormForward(
context->cnnlHandle(), inDesc, inputData, axis, fiterDesc, context->cnnlHandle(), inDesc, inputData, axis, fiterDesc,
scaleData, biasData, eps, wsData, wsSize, outDesc, outputData, scaleData, biasData, eps, wsData, wsSize, outDesc, outputData,
outMeanDesc, meanData, rstdData); inDesc, NULL, NULL);
if (stat != CNNL_STATUS_SUCCESS) if (stat != CNNL_STATUS_SUCCESS)
return; return;

View File

@ -66,13 +66,6 @@ class MatmulCnnl : public BangKernelWithoutConfig {
cnnlSetMatMulDescAttr(bmm_desc, CNNL_MATMUL_DESC_TRANSB, &transB, cnnlSetMatMulDescAttr(bmm_desc, CNNL_MATMUL_DESC_TRANSB, &transB,
sizeof(int32_t)); sizeof(int32_t));
std::string computeTypeStr = op->getComputeType();
if (computeTypeStr == "tf32") {
int32_t tf32 = 1;
cnnlSetMatMulDescAttr(bmm_desc, CNNL_MATMUL_ALLOW_TF32, &tf32,
sizeof(int32_t));
}
cnnlMatMulAlgo_t bmm_algo; cnnlMatMulAlgo_t bmm_algo;
cnnlMatMulAlgoCreate(&bmm_algo); cnnlMatMulAlgoCreate(&bmm_algo);

View File

@ -7,33 +7,37 @@ namespace infini {
class AttentionKVCacheCompute { class AttentionKVCacheCompute {
void initAttentionKVCacheMetadata(AttentionKVCacheMetadata &metadata, void initAttentionKVCacheMetadata(AttentionKVCacheMetadata &metadata,
Tensor tensor) const { Tensor input_v_cache,
int nDims = tensor->getRank(); Tensor position_id) const {
auto strides = tensor->getStride(); int nDims = input_v_cache->getRank();
auto strides = input_v_cache->getStride();
IT_ASSERT(nDims == 4); IT_ASSERT(nDims == 4);
IT_ASSERT(strides.size() == (size_t)nDims); int dim_position_id = position_id->getRank();
for (int i = 0; i < nDims; ++i) { metadata.num_seqs = 1;
metadata.dimSize[i] = tensor->getDims().at(i); for (int i = 0; i < dim_position_id; i++) {
metadata.stride[i] = strides.at(i); metadata.num_seqs *= position_id->getDims().at(i);
} }
metadata.head_dim = input_v_cache->getDims().at(3);
metadata.num_heads = input_v_cache->getDims().at(1);
metadata.max_kv_seqlen = input_v_cache->getDims().at(2);
} }
public: public:
void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q, void do_compute(int dType, Tensor input_k_cache, Tensor input_v_cache,
Tensor input_k, Tensor input_v, Tensor position_id, Tensor input_q, Tensor input_k, Tensor input_v,
Tensor output_matmul, CudaPtr p_workspace) const { Tensor position_id, Tensor output_matmul,
CudaPtr p_workspace) const {
AttentionKVCacheMetadata metadata; AttentionKVCacheMetadata metadata;
initAttentionKVCacheMetadata(metadata, input_v_cache); initAttentionKVCacheMetadata(metadata, input_v_cache, position_id);
attention_kvcache_kernel(input_k_cache->getRawDataPtr<float *>(), attention_kvcache_kernel(
input_v_cache->getRawDataPtr<float *>(), dType, input_k_cache->getRawDataPtr<void *>(),
input_q->getRawDataPtr<float *>(), input_v_cache->getRawDataPtr<void *>(),
input_k->getRawDataPtr<float *>(), input_q->getRawDataPtr<void *>(), input_k->getRawDataPtr<void *>(),
input_v->getRawDataPtr<float *>(), input_v->getRawDataPtr<void *>(),
position_id->getRawDataPtr<int *>(), position_id->getRawDataPtr<int64_t *>(),
output_matmul->getRawDataPtr<float *>(), output_matmul->getRawDataPtr<void *>(), metadata,
metadata, (float *)p_workspace, (float *)p_workspace, (float *)(p_workspace + (1ll << 30)));
(float *)(p_workspace + (1ll << 30)));
} }
}; };
@ -41,15 +45,17 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute,
public CudaKernelWithoutConfig { public CudaKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
IT_ASSERT(_op->getDType() == DataType::Float32); auto op = as<AttentionKVCacheObj>(_op);
int dType = op->getDType().getIndex();
int position_idx_dtype = op->getInputs()[5]->getDTypeIndex();
IT_ASSERT(dType == 1 || dType == 10 || position_idx_dtype == 7);
size_t workspaceSize = 2ll << 30; size_t workspaceSize = 2ll << 30;
auto context = dynamic_cast<const CudaRuntimeObj *>(_context); auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
CudaPtr idxWsData = context->getWorkspace(workspaceSize); CudaPtr idxWsData = context->getWorkspace(workspaceSize);
do_compute(_op->getInputs()[0], _op->getInputs()[1], do_compute(dType, op->getInputs()[0], op->getInputs()[1],
_op->getInputs()[2], _op->getInputs()[3], op->getInputs()[2], op->getInputs()[3], op->getInputs()[4],
_op->getInputs()[4], _op->getInputs()[5], op->getInputs()[5], op->getOutputs()[0], idxWsData);
_op->getOutputs()[0], idxWsData);
} }
}; };

View File

@ -1,171 +1,236 @@
#include "cuda/cuda_common.h" #include "cuda/cuda_common.h"
#include "cuda/cuda_attention_kvcache.h" #include "cuda/cuda_attention_kvcache.h"
#define WARP_SIZE 32 #define WARP_SIZE 32
#define BLOCKSIZE WARP_SIZE
#define SEQ_UNIT 16 #define SEQ_UNIT 16
#define BLOCKSIZE_2 WARP_SIZE*4
#define MAX_PARTITION 1024
// ASSUME SEQ_LEN OF Q IS 1 template <class T>
__global__ void _attention_kvcache_kernel_128_1(float* input_k_cache, __global__ void _attention_kvcache_kernel_128_1(T* input_k_cache,
float* input_v_cache, T* input_v_cache,
float* input_q, T* input_q,
float* input_k, T* input_k,
float* input_v, T* input_v,
int* position_id, int64_t* position_id,
AttentionKVCacheMetadata compMeta, AttentionKVCacheMetadata compMeta,
float* output_O_temp, half* output_O_temp,
float* output_sum_temp) { float* output_sum_temp) {
int seq_length = position_id[0] + 1; int seq_length = position_id[blockIdx.y] + 1;
int stride = (seq_length + SEQ_UNIT - 1) / SEQ_UNIT; int stride = (seq_length + SEQ_UNIT - 1) / SEQ_UNIT;
if(blockIdx.y >= stride) if(blockIdx.z >= stride)
return; return;
int lane_id = threadIdx.x % WARP_SIZE; int lane_id_x2 = threadIdx.x % WARP_SIZE * 2;
int group_id = threadIdx.x / WARP_SIZE; int parallel_idx = blockIdx.x + blockIdx.y * gridDim.x;
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
int idx_seq = blockIdx.y * SEQ_UNIT;
if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1]) int idx_seq = blockIdx.z * SEQ_UNIT;
return;
float ptr_V[SEQ_UNIT*4]; // V half reg_V[4];
float ptr_K[SEQ_UNIT*4]; // K half reg_K[4];
float ptr_Q[4]; // Q half reg_Q[4];
float ptr_P[SEQ_UNIT] = {0}; float reg_P;
float ptr_O[4] = {0}; float reg_O[4] = {0};
float ptr_sum[1] = {0}; float reg_sum = 0;
float temp[4];
bool is_fp16 = sizeof(T) == 2 ? true : false;
int idx_qkv = lane_id_x2 + parallel_idx * compMeta.head_dim;
// readin Q // readin Q
(float4 &)ptr_Q[0] = (float4 &)input_q[(lane_id * 4) + (parallel_idx * 128)]; if(!is_fp16){
int common_idx = (lane_id * 4) + (parallel_idx * compMeta.stride[1]);
// Q*K
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
(float4 &)ptr_K[idx_SEQ_UNIT * 4]
= (float4 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
}
else{
(float4 &)ptr_K[idx_SEQ_UNIT * 4]
= (float4 &) input_k[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
(float4 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
(float4 &)ptr_K[idx_SEQ_UNIT * 4];
}
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i ++){ for(int i = 0; i < 4; i += 2){
ptr_K[idx_SEQ_UNIT * 4 + i] = ptr_Q[i] * ptr_K[idx_SEQ_UNIT * 4 + i]; (float2 &)temp[i] = (float2 &)input_q[idx_qkv + i*WARP_SIZE];
#pragma unroll *((half2*)(&reg_Q[i])) = __float22half2_rn(*((float2*)(&temp[i])));
for (int offset = 16; offset > 0; offset /= 2) {
ptr_K[idx_SEQ_UNIT * 4 + i] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 4 + i], offset);
}
ptr_P[idx_SEQ_UNIT] += ptr_K[idx_SEQ_UNIT * 4 + i];
} }
} }
else{
// div sqrt(d) #pragma unroll
#pragma unroll for(int i = 0; i < 4; i += 2){
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) { (half2 &)reg_Q[i] = (half2 &)input_q[idx_qkv + i*WARP_SIZE];
ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0); }
ptr_P[idx_SEQ_UNIT] /= sqrt(128.0);
} }
int common_idx = lane_id_x2 + (parallel_idx * compMeta.max_kv_seqlen * compMeta.head_dim);
// softmax
#pragma unroll #pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) { for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT]); reg_P = 0;
ptr_sum[0] += ptr_P[idx_SEQ_UNIT]; int idx_kvcache = common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.head_dim);
} // readin K & V
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
// * V #pragma unroll
#pragma unroll for(int i = 0; i < 4; i += 2){
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) { *((half2*)(&reg_K[i])) = *((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE]));
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){ *((half2*)(&reg_V[i])) = *((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE]));
(float4 &)ptr_V[idx_SEQ_UNIT * 4] }
= (float4 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
} }
else{ else{
(float4 &)ptr_V[idx_SEQ_UNIT * 4] if(!is_fp16){
= (float4 &) input_v[((lane_id * 4) + parallel_idx * compMeta.stride[2])]; #pragma unroll
(float4 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] for(int i = 0; i < 4; i += 2){
= (float4 &)ptr_V[idx_SEQ_UNIT * 4]; (float2 &)temp[i] = (float2 &) input_k[idx_qkv + i*WARP_SIZE];
*((half2*)(&reg_K[i])) = __float22half2_rn(*((float2*)(&temp[i])));
*((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&reg_K[i]));
(float2 &)temp[i] = (float2 &) input_v[idx_qkv + i*WARP_SIZE];
*((half2*)(&reg_V[i])) = __float22half2_rn(*((float2*)(&temp[i])));
*((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&reg_V[i]));
}
}
else{
#pragma unroll
for(int i = 0; i < 4; i += 2){
(half2 &)reg_K[i] = (half2 &)input_k[idx_qkv + i*WARP_SIZE];
*((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&reg_K[i]));
(half2 &)reg_V[i] = (half2 &)input_v[idx_qkv + i*WARP_SIZE];
*((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&reg_V[i]));
}
}
} }
// Q*K
#pragma unroll
for (int i = 0; i < 4; i += 2){
(half2 &)reg_K[i] = (half2 &)reg_Q[i] * (half2 &)reg_K[i];
#pragma unroll
for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) {
(half2 &)reg_K[i] += __shfl_xor_sync(0xffffffff, (half2 &)reg_K[i], offset);
}
(float2 &) temp[i] = __half22float2((half2 &)reg_K[i]);
reg_P += (temp[i] + temp[i+1]);
(float2 &) temp[i] = __half22float2((half2 &)reg_V[i]);
}
// div sqrt(d)
reg_P /= sqrt(128.0);
// softmax
reg_P = expf(reg_P);
reg_sum += reg_P;
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i ++) for (int i = 0; i < 4; i ++)
ptr_O[i] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4 + i)], ptr_O[i]); reg_O[i] = fmaf(reg_P, temp[i], reg_O[i]);
} }
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i ++) for (int i = 0; i < 4; i ++)
ptr_O[i] /= ptr_sum[0]; reg_O[i] /= reg_sum;
(float4 &)output_O_temp[(lane_id * 4) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = (float4 &)ptr_O[0]; #pragma unroll
if(lane_id == 0){ for(int i = 0; i < 4; i += 2)
output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0]; (half2 &)output_O_temp[(lane_id_x2 + i*WARP_SIZE) + (blockIdx.z * compMeta.head_dim) + (parallel_idx * compMeta.head_dim * stride)] = __float22half2_rn((float2 &)reg_O[i]);
if(lane_id_x2 == 0){
output_sum_temp[blockIdx.z + parallel_idx * stride] = reg_sum;
} }
} }
__global__ void _attention_kvcache_kernel_128_2(int* position_id,
float* output_matmul, template <class T>
__global__ void _attention_kvcache_kernel_128_2(int64_t* position_id,
T* output_matmul,
AttentionKVCacheMetadata compMeta, AttentionKVCacheMetadata compMeta,
float* output_O_temp, half* output_O_temp,
float* output_sum_temp) { float* output_sum_temp) {
int lane_id = threadIdx.x % WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE;
int group_id = threadIdx.x / WARP_SIZE; int parallel_idx = blockIdx.x;
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id; int offset = parallel_idx * compMeta.head_dim;
float ptr_O[4] = {0};
float ptr_O_sum[4] = {0};
float ptr_sum = 0;
float ptr_sum_temp;
int size = (position_id[0] + SEQ_UNIT) / SEQ_UNIT; int size = (position_id[0] + SEQ_UNIT) / SEQ_UNIT;
bool is_fp16 = sizeof(T) == 2 ? true : false;
if(size == 1){
if(!is_fp16){
#pragma unroll
for(int i = threadIdx.x; i < compMeta.head_dim; i += blockDim.x)
output_matmul[i + offset]
= __half2float(output_O_temp[i + offset]);
}
else{
#pragma unroll
for(int i = threadIdx.x; i < compMeta.head_dim; i += blockDim.x)
output_matmul[i + offset]
= output_O_temp[i + offset];
}
return;
}
__shared__ float shm_sum_temp[MAX_PARTITION];
__shared__ float shm_sum[WARP_SIZE];
float temp_sum = 0;
#pragma unroll #pragma unroll
for(int i = 0; i < size; i ++){ for(int i = threadIdx.x; i < size; i += blockDim.x){
(float4 &)ptr_O[0] shm_sum_temp[i] = output_sum_temp[i + parallel_idx * size];
= (float4 &)output_O_temp[(lane_id * 4) + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size]; temp_sum += shm_sum_temp[i];
ptr_sum_temp = output_sum_temp[i + parallel_idx * size];
#pragma unroll
for(int k = 0; k < 4; k ++)
ptr_O_sum[k] += ptr_O[k] * ptr_sum_temp;
ptr_sum += ptr_sum_temp;
} }
#pragma unroll #pragma unroll
for(int k = 0; k < 4; k ++) for(int offset = WARP_SIZE/2; offset > 0; offset /= 2)
ptr_O_sum[k] = ptr_O_sum[k] / ptr_sum; temp_sum += __shfl_down_sync(0xffffffff, temp_sum, offset);
if(lane_id == 0)
shm_sum[threadIdx.x/WARP_SIZE] = temp_sum;
__syncthreads();
temp_sum = lane_id < (size + WARP_SIZE - 1) / WARP_SIZE ? shm_sum[lane_id] : 0;
(float4 &)output_matmul[(lane_id * 4) + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O_sum[0]; #pragma unroll
for(int offset = WARP_SIZE/2; offset > 0; offset /= 2)
temp_sum += __shfl_xor_sync(0xffffffff, temp_sum, offset);
temp_sum = __fdividef(1.0f, temp_sum + 1e-6f);
#pragma unroll
for(int i = threadIdx.x; i < compMeta.head_dim; i += blockDim.x){
float acc = 0.0f;
for(int j = 0; j < size; j ++){
acc = fma(__half2float(output_O_temp[i + (j * compMeta.head_dim) + offset * size]) * shm_sum_temp[j], temp_sum, acc);
}
if(!is_fp16){
output_matmul[i + offset] = acc;
}
else{
output_matmul[i + offset] = __float2half(acc);
}
}
} }
namespace infini { namespace infini {
void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache, void attention_kvcache_kernel(int dType, void *input_k_cache, void *input_v_cache,
float *input_q, float *input_k, void *input_q, void *input_k,
float *input_v, int *position_id, float *output_matmul, void *input_v, int64_t *position_id, void *output_matmul,
const AttentionKVCacheMetadata &compMeta, const AttentionKVCacheMetadata &compMeta,
float *output_O_temp, float *output_sum_temp) { float *output_O_temp, float *output_sum_temp) {
IT_ASSERT(compMeta.dimSize[3] == 128); IT_ASSERT(dType == 1 || dType == 10);
int gridsize_y = (compMeta.dimSize[2] - 1 + SEQ_UNIT) / SEQ_UNIT; int gridsize_y = (compMeta.max_kv_seqlen - 1 + SEQ_UNIT) / SEQ_UNIT;
dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), gridsize_y); dim3 gridDim(compMeta.num_heads, compMeta.num_seqs, gridsize_y);
dim3 blockDim(BLOCKSIZE, 1); dim3 blockDim(WARP_SIZE, 1);
_attention_kvcache_kernel_128_1 if(dType == 1){
<<<gridDim, blockDim, 0, CUDAStream::getCurrentStream()>>> _attention_kvcache_kernel_128_1<float>
(input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, <<<gridDim, blockDim, 0, CUDAStream::getCurrentStream()>>>
compMeta, output_O_temp, output_sum_temp); ((float*)input_k_cache, (float*)input_v_cache, (float*)input_q, (float*)input_k, (float*)input_v,
position_id, compMeta, (half*)output_O_temp, output_sum_temp);
_attention_kvcache_kernel_128_2<float>
<<<compMeta.num_seqs*compMeta.num_heads, BLOCKSIZE_2,
0, CUDAStream::getCurrentStream()>>>
(position_id, (float*)output_matmul, compMeta, (half*)output_O_temp, output_sum_temp);
}
else{
_attention_kvcache_kernel_128_1<half>
<<<gridDim, blockDim, 0, CUDAStream::getCurrentStream()>>>
((half*)input_k_cache, (half*)input_v_cache, (half*)input_q, (half*)input_k, (half*)input_v,
position_id, compMeta, (half*)output_O_temp, output_sum_temp);
_attention_kvcache_kernel_128_2<half>
<<<compMeta.num_seqs*compMeta.num_heads, BLOCKSIZE_2,
0, CUDAStream::getCurrentStream()>>>
(position_id, (half*)output_matmul, compMeta, (half*)output_O_temp, output_sum_temp);
}
_attention_kvcache_kernel_128_2
<<<compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), WARP_SIZE,
0, CUDAStream::getCurrentStream()>>>
(position_id, output_matmul, compMeta, output_O_temp, output_sum_temp);
} }
} // namespace infini } // namespace infini

View File

@ -36,7 +36,7 @@ constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = {
cublasComputeType_t cuDataType2ComputeType(cudaDataType_t cuDataType) { cublasComputeType_t cuDataType2ComputeType(cudaDataType_t cuDataType) {
if (cuDataType == CUDA_R_16F) { if (cuDataType == CUDA_R_16F) {
return CUBLAS_COMPUTE_32F_FAST_16F; return CUBLAS_COMPUTE_16F;
} else if (cuDataType == CUDA_R_16BF) { } else if (cuDataType == CUDA_R_16BF) {
return CUBLAS_COMPUTE_32F_FAST_16BF; return CUBLAS_COMPUTE_32F_FAST_16BF;
} else if (cuDataType == CUDA_R_32F) { } else if (cuDataType == CUDA_R_32F) {

View File

@ -18,17 +18,18 @@ class RoPECuda : public CudaKernelWithoutConfig {
const auto &inputShape = input->getDims(); const auto &inputShape = input->getDims();
int nDims = input->getDims().size(); int nDims = input->getDims().size();
int size = input->size();
IT_ASSERT(nDims == 3 && pos->getDims().size() == 2); IT_ASSERT(nDims == 3 && pos->getDims().size() == 2);
IT_ASSERT(inputShape[1] == pos->getDims()[1]); IT_ASSERT(inputShape[0] == pos->getDims()[0] &&
inputShape[1] == pos->getDims()[1]);
int position_idx_dtype = op->getInputs()[0]->getDTypeIndex();
int dim_model = inputShape[2]; int dim_model = inputShape[2];
int dim_head = 128; int dim_head = 128; // TODO: get dim_head from the framework
int hidden_stride = dim_model * inputShape[1];
int pos_stride = inputShape[1]; int pos_stride = inputShape[1];
int batchsize = inputShape[0];
const int dType = op->getDType().getIndex(); const int dType = op->getDType().getIndex();
rope_kernel(dType, pos->getRawDataPtr<int *>(), inputData, outputData, rope_kernel(dType, pos->getRawDataPtr<int64_t *>(), inputData,
size, dim_model, dim_head, hidden_stride, pos_stride); outputData, dim_model, dim_head, batchsize, pos_stride);
} }
}; };

View File

@ -4,13 +4,15 @@
#include "utils/small_array.h" #include "utils/small_array.h"
template <class T> template <class T>
__global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model, __global__ void _rope_kernel(int64_t* pos, void *in, void *out, int dim_model,
int dim_head, int hidden_stride, int pos_stride) { int dim_head, int batchsize, int pos_stride) {
int batch_id = blockIdx.x; int batch_id = blockIdx.x;
int target_pos = pos[batch_id * pos_stride + blockIdx.y]; int target_pos = pos[batch_id * pos_stride + blockIdx.y];
int ith = blockIdx.z * blockDim.x + threadIdx.x; int ith = blockIdx.z * blockDim.x + threadIdx.x;
int col = ith % dim_head; int col = ith % dim_head;
int offset = batch_id * hidden_stride + blockIdx.y * dim_model; int batch_stride = pos_stride * dim_model;
int offset = batch_id * batch_stride + blockIdx.y * dim_model;
if (ith >= dim_model) if (ith >= dim_model)
return; return;
@ -34,7 +36,7 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
#define CASE(T) \ #define CASE(T) \
_rope_kernel<DT_CUDA<T>::t> \ _rope_kernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \ <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
(pos, input, output, size, dim_model, dim_head, hidden_stride, pos_stride); (pos, input, output, dim_model, dim_head, batchsize, pos_stride);
#define SWITCH_DTYPE(DTYPE) \ #define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \ switch (DTYPE) { \
@ -79,10 +81,10 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
} }
namespace infini { namespace infini {
void rope_kernel(int dType, int * pos, void *input, void *output, int size, void rope_kernel(int dType, int64_t * pos, void *input, void *output,
int dim_model, int dim_head, int hidden_stride, int pos_stride) { int dim_model, int dim_head, int batchsize, int pos_stride) {
dim3 blocksize = dim3(32,1,1); dim3 blocksize = dim3(32,1,1);
dim3 gridsize = dim3(1, 1, dim_model/32); dim3 gridsize = dim3(batchsize, pos_stride, dim_model/32);
SWITCH_DTYPE(dType) SWITCH_DTYPE(dType)
} }

View File

@ -97,14 +97,11 @@ class DivXdnn : public KUNLUNKernelWithoutConfig {
auto aDim = op->getInputs(0)->getDims(); auto aDim = op->getInputs(0)->getDims();
auto bSize = op->getInputs(1)->size(); auto bSize = op->getInputs(1)->size();
auto bDim = op->getInputs(1)->getDims(); auto bDim = op->getInputs(1)->getDims();
auto dtype = op->getDType();
// op input a, b is scalar while aDim and b Dim is empty
if (bDim.size() == 0) { if (bDim.size() == 0) {
bDim.push_back(1); bDim.push_back(1);
} }
if (aDim.size() == 0) {
aDim.push_back(1);
}
if (aSize == bSize) { if (aSize == bSize) {
// Do ElementWise Sub with no broadcast // Do ElementWise Sub with no broadcast
@ -112,9 +109,23 @@ class DivXdnn : public KUNLUNKernelWithoutConfig {
(float *)aData, (float *)bData, (float *)aData, (float *)bData,
(float *)cData, aSize)); (float *)cData, aSize));
} else { } else {
checkKUNLUNError(xdnn::broadcast_div<float>( // Do broadcast div
context->KUNLUNHandle(), (float *)aData, (float *)bData, Shape aligned = infer_broadcast(aDim, bDim);
(float *)cData, aDim, bDim)); if (aligned == aDim) {
// BData need to be broadcasted
checkKUNLUNError(xdnn::broadcast_div<float>(
context->KUNLUNHandle(), (float *)aData, (float *)bData,
(float *)cData, aDim, bDim));
} else {
// Use workspace to broadcast aData
KUNLUNPtr wks = context->getWorkspace(bSize * dtype.getSize());
checkKUNLUNError(xdnn::broadcast<float>(
context->KUNLUNHandle(), (float *)aData, (float *)wks, aDim,
bDim));
checkKUNLUNError(xdnn::div<float>(context->KUNLUNHandle(),
(float *)wks, (float *)bData,
(float *)cData, bSize));
}
} }
return; return;
} }

View File

@ -570,7 +570,6 @@ REGISTER_KERNEL(Device::KUNLUN, OpType::Reciprocal, ReciprocalXdnn,
REGISTER_KERNEL(Device::KUNLUN, OpType::Reshape, CopyXdnn, "Reshape_xdnn"); REGISTER_KERNEL(Device::KUNLUN, OpType::Reshape, CopyXdnn, "Reshape_xdnn");
REGISTER_KERNEL(Device::KUNLUN, OpType::Flatten, CopyXdnn, "Flatten_xdnn"); REGISTER_KERNEL(Device::KUNLUN, OpType::Flatten, CopyXdnn, "Flatten_xdnn");
REGISTER_KERNEL(Device::KUNLUN, OpType::Identity, CopyXdnn, "Identity_xdnn"); REGISTER_KERNEL(Device::KUNLUN, OpType::Identity, CopyXdnn, "Identity_xdnn");
REGISTER_KERNEL(Device::KUNLUN, OpType::Squeeze, CopyXdnn, "Squeeze_xdnn");
REGISTER_KERNEL(Device::KUNLUN, OpType::Abs, AbsXdnn, "Abs_xdnn"); REGISTER_KERNEL(Device::KUNLUN, OpType::Abs, AbsXdnn, "Abs_xdnn");
REGISTER_KERNEL(Device::KUNLUN, OpType::Atan, ATanXdnn, "Atan_xdnn"); REGISTER_KERNEL(Device::KUNLUN, OpType::Atan, ATanXdnn, "Atan_xdnn");
REGISTER_KERNEL(Device::KUNLUN, OpType::Log, LogXdnn, "Log_xdnn"); REGISTER_KERNEL(Device::KUNLUN, OpType::Log, LogXdnn, "Log_xdnn");

View File

@ -26,11 +26,12 @@ TEST(RoPE, Cuda) {
cudaRuntime->run(gCuda); cudaRuntime->run(gCuda);
auto oCpu = gCpu->cloneTensor(op->getOutputs()[0]); auto oCpu = gCpu->cloneTensor(op->getOutputs()[0]);
oCpu->printData();
EXPECT_TRUE(oCpu->equalData(vector<float>{ EXPECT_TRUE(oCpu->equalData(vector<float>{
1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 0.540302, 0.647906, 0.731761, 0.796458, 0.846009, 0.883756, 0.912396,
1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 0.934062, 0.950415, 0.962739, 0.972014, 0.978989, 0.98423, 0.988167,
1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 0.991122, 0.99334, 0.995004, 0.996253, 0.99719, 0.997892, 0.998419,
1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 0.998815, 0.999111, 0.999333, 0.9995, 0.999625, 0.999719, 0.999789,
1.381773, 1.381773, 1.381773, 1.381773})); 0.999842, 0.999881, 0.999911, 0.999933}));
} }
} // namespace infini } // namespace infini