forked from jiuyuan/InfiniTensor

Compare commits: master ... kvcache_ba (19 commits)

Author | SHA1 | Date |
---|---|---|
xiaonans | b0d030d0de | |
xiaonans | d000f9750c | |
xiaonans | 4a5b9572bb | |
xiaonans | 159642d6ae | |
xiaonans | c01e64db50 | |
xiaonans | eb3a2d123d | |
xiaonans | 4bdd33522b | |
xiaonans | 0740d26f43 | |
xiaonans | fc3d38f80e | |
xiaonans | d43364ac60 | |
xiaonans | db053e32a4 | |
xiaonans | 1e797d4ffe | |
xiaonans | 80412ae162 | |
xiaonans | 83be7fa373 | |
xiaonans | 0f1c04d864 | |
xiaonans | 936797b960 | |
xiaonans | 17bd98d453 | |
xiaonans | 8cc6af0a83 | |
xiaonans | c04910f118 | |
@ -13,6 +13,3 @@
 [submodule "example"]
 	path = examples/NNmodel
 	url = git@github.com:wanghailu0717/NNmodel.git
-[submodule "examples/distributed/onnxsim_large_model"]
-	path = examples/distributed/onnxsim_large_model
-	url = git@github.com:luchangli03/onnxsim_large_model.git
@ -1 +1 @@
-Subproject commit 51d3105277f3774ed31c02ed4cd11fa92925af77
+Subproject commit b896cec2dba5b8522b141ac4f89eb43074ee1b98
@ -1,7 +1,5 @@
 # Distributed scripts
 
-## Running on the NVIDIA platform
-
 #### 1. Run the PyTorch model, generate inputs and reference outputs, and optionally export ONNX
 
 Use `--export_onnx` to set the directory the ONNX model is exported to (default: the current path `./`); without this flag the script only runs the computation and generates the inputs and outputs.
@ -17,23 +15,3 @@ python run_pytorch.py --model gpt2 --batch_size 1 --length 1 --export_onnx ./
 ```bash
 python cuda_launch.py --model "/XXX/XXX.onnx" --nproc_per_node 4
 ```
-
-## Running on the Cambricon platform
-
-**The scripts `run_pytorch.py` and `cuda_launch.py` above have been adapted for the Cambricon platform; see `run_pytorch_mlu.py` and `bang_launch.py` for details.**
-
-#### 1. Run the PyTorch model, generate inputs and reference outputs, and optionally export ONNX
-
-Use `--export_onnx` to set the directory the ONNX model is exported to (default: the current path `./`); without this flag the script only runs the computation and generates the inputs and outputs.
-
-```bash
-python run_pytorch_mlu.py --model gpt2 --batch_size 1 --length 1 --export_onnx ./
-```
-
-This generates the input and output files `test_inputs.npy` and `test_results.npy` in the current directory; only a single input and a single output are currently supported.
-
-#### 2. Run the InfiniTensor distributed script
-
-```bash
-python bang_launch.py --model "/XXX/XXX.onnx" --nproc_per_node 4
-```
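As a quick illustration of the workflow above, a minimal sketch of checking a later run against the reference output saved in step 1 (the file name and tolerances here are illustrative assumptions, not fixed by the scripts):

```python
import numpy as np

# Reference output saved by step 1 (file name assumed for illustration).
reference = np.load("test_results.npy")

# `candidate` stands in for the output of the distributed InfiniTensor run;
# here it is just a copy so the example is self-contained.
candidate = reference.copy()

print("max abs diff:", np.abs(candidate - reference).max())
np.testing.assert_allclose(candidate, reference, rtol=1e-3, atol=1e-3)
```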
@ -1,249 +0,0 @@
import argparse
import torch
import torch_mlu
from transformers import BertModel, BertConfig
from transformers import GPT2Model, GPT2Config
from transformers import OPTModel, OPTConfig
from transformers import AlbertModel, AlbertConfig
from transformers import LlamaModel, LlamaConfig
import time
import numpy as np
import onnx
import sys
import os
from onnx.external_data_helper import convert_model_to_external_data
from onnxsim import simplify


def parse_args():
    parser = argparse.ArgumentParser(description="Run pytorch gpt2/bert/opt and optionally export onnx.")
    parser.add_argument(
        "--model", type=str, choices=["gpt2", "bert", "opt", "llama", "albert"], required=True, help="model type"
    )
    parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
    parser.add_argument("--length", type=int, default=1, help="sequence length.")
    parser.add_argument(
        "--export_onnx",
        type=str,
        nargs="?",
        default=None,
        const="./",
        help="whether and where to export onnx file",
    )
    parser.add_argument(
        "--type", type=str, choices=["fp32", "fp16", "tf32"], required=True, help="model data type"
    )
    args = parser.parse_args()
    print("arg setting: ", args)
    return (
        args.model,
        args.batch_size,
        args.length,
        args.export_onnx,
        args.type
    )


def get_model(modelname):
    match modelname:
        case "albert":
            model = AlbertModel.from_pretrained("albert/albert-base-v2")
            voc_size = AlbertConfig().vocab_size
        case "bert":
            model = BertModel.from_pretrained("bert-base-uncased", add_pooling_layer=False, hidden_act="gelu_new")  # erf is not impl by infini
            voc_size = BertConfig().vocab_size
        case "gpt2":
            model = GPT2Model.from_pretrained("GPT2")
            voc_size = GPT2Config().vocab_size
        case "opt":
            model = OPTModel.from_pretrained("facebook/opt-125m")
            voc_size = OPTConfig().vocab_size
        case "llama":
            model = LlamaModel.from_pretrained("meta-llama/Llama-2-7b-hf")
            voc_size = LlamaConfig().vocab_size
        case _:
            raise KeyError(modelname)

    model = model.eval()
    return model, voc_size


def run_pytorch(torch_model, voc_size, batchsize, len, dtype="fp32"):
    data = np.random.randint(0, voc_size, (batchsize, len), dtype=np.int32)
    os.makedirs(os.path.dirname("./data/"), exist_ok=True)
    np.save("./data/input_0", data)
    inputs = torch.from_numpy(data).to("mlu")
    torch_model = torch_model.to("mlu")
    if dtype == "fp16":
        torch_model = torch_model.half()

    n_iter = 20
    with torch.no_grad():
        for _ in range(10):
            outputs = torch_model(inputs)
    torch.mlu.synchronize()
    begin = time.time()
    with torch.no_grad():
        for _ in range(n_iter):
            torch.mlu.synchronize()
            outputs = torch_model(inputs)
            torch.mlu.synchronize()
    torch.mlu.synchronize()
    end = time.time()

    avg_time = (end - begin) / n_iter
    outputs = outputs.last_hidden_state.to("cpu")
    print("outputs abs mean:", abs(np.array(outputs)).mean())
    print(f"average time: {avg_time}")
    # torch.mlu.memory.empty_cache()
    np.save("./data/output", np.array(outputs))
    print("Save input & output into ./data.")


def export_onnx(modelname, model, data, path, extern=False, dtype="fp32"):
    data = data.to("mlu")
    model = model.to("mlu")
    if dtype == "fp16":
        model = model.half()
    torch.onnx.export(model, data, path, verbose=False, do_constant_folding=True)
    if modelname != "llama":
        # use onnxsim to simplify
        onnx_model = onnx.load(path)
        onnx_model, check = simplify(onnx_model, skipped_optimizers=['eliminate_duplicate_initializer'])
        # onnx_model, check = simplify(onnx_model, skipped_optimizers=['fuse_qkv', 'eliminate_duplicate_initializer'])
        assert check
        add_value_info_for_constants(onnx_model)
        onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
        if extern:
            extern_path = path.replace('.onnx', '.pb')
            if os.path.exists(extern_path):
                os.remove(extern_path)
            extern_path = extern_path.split("/")[-1]
            convert_model_to_external_data(
                onnx_model,
                all_tensors_to_one_file=True,
                location=extern_path,
                size_threshold=1024,
                convert_attribute=False,
            )
        onnx.save(onnx_model, path)
    else:
        # use third party tool to simplify llama
        # reference: https://github.com/luchangli03/onnxsim_large_model/
        sys.path.append("onnxsim_large_model")
        from onnx_utils import set_onnx_input_shape
        from compress_model import SIZE_1MB, compress_onnx_model, uncompress_onnx_model

        in_model_path = path
        out_model_path = path
        if not out_model_path:
            out_model_path = in_model_path[:-5] + ".sim.onnx"
        if os.path.isdir(out_model_path):
            out_model_path = os.path.join(out_model_path, os.path.basename(in_model_path))

        onnx_model = onnx.load(in_model_path)
        print(f"load model from {in_model_path} success")

        size_th_bytes = 1024 * 1024

        onnx_model, removed_inits = compress_onnx_model(onnx_model, size_th_bytes=size_th_bytes)
        print(f"compress model success")

        onnx_model = set_onnx_input_shape(onnx_model, "")

        tensor_size_threshold = f"1024KB"
        skipped_optimizers = []
        skipped_optimizers.append("eliminate_duplicate_initializer")
        onnx_model, check = simplify(onnx_model, skipped_optimizers=skipped_optimizers,
                                     tensor_size_threshold=tensor_size_threshold)
        if not check:
            raise ValueError(f"simplify compressed model {in_model_path} failed")

        print(f"simplify model success")

        onnx_model = uncompress_onnx_model(onnx_model, removed_inits)
        print(f"uncompress model success")

        add_value_info_for_constants(onnx_model)

        onnx.save(onnx_model, out_model_path, save_as_external_data=True)


def add_value_info_for_constants(model : onnx.ModelProto):
    """
    Currently onnx.shape_inference doesn't use the shape of initializers, so add
    that info explicitly as ValueInfoProtos.
    Mutates the model.
    Args:
        model: The ModelProto to update.
    """
    # All (top-level) constants will have ValueInfos before IRv4 as they are all inputs
    if model.ir_version < 4:
        return

    def add_const_value_infos_to_graph(graph : onnx.GraphProto):
        inputs = {i.name for i in graph.input}
        existing_info = {vi.name: vi for vi in graph.value_info}
        for init in graph.initializer:
            # Check it really is a constant, not an input
            if init.name in inputs:
                continue

            # The details we want to add
            elem_type = init.data_type
            shape = init.dims

            # Get existing or create new value info for this constant
            vi = existing_info.get(init.name)
            if vi is None:
                vi = graph.value_info.add()
                vi.name = init.name

            # Even though it would be weird, we will not overwrite info even if it doesn't match
            tt = vi.type.tensor_type
            if tt.elem_type == onnx.TensorProto.UNDEFINED:
                tt.elem_type = elem_type
            if not tt.HasField("shape"):
                # Ensure we set an empty list if the const is scalar (zero dims)
                tt.shape.dim.extend([])
                for dim in shape:
                    tt.shape.dim.add().dim_value = dim

        # Handle subgraphs
        for node in graph.node:
            for attr in node.attribute:
                # Ref attrs refer to other attrs, so we don't need to do anything
                if attr.ref_attr_name != "":
                    continue

                if attr.type == onnx.AttributeProto.GRAPH:
                    add_const_value_infos_to_graph(attr.g)
                if attr.type == onnx.AttributeProto.GRAPHS:
                    for g in attr.graphs:
                        add_const_value_infos_to_graph(g)

    return add_const_value_infos_to_graph(model.graph)


def main():
    torch.backends.mlu.matmul.allow_tf32 = False
    torch.backends.cnnl.allow_tf32 = False
    modelname, batchsize, seqlen, export_path, dtype = parse_args()
    if dtype == "tf32":
        torch.backends.mlu.matmul.allow_tf32 = True
    else:
        os.environ["CAMBRICON_TF32_OVERRIDE"] = "0"

    model, voc_size = get_model(modelname)
    if export_path is not None:
        filename = "{}_{}_{}_{}.onnx".format(modelname, batchsize, seqlen, dtype)
        path = os.path.join(export_path, filename)
        if not os.path.exists(path):
            param = torch.zeros((batchsize, seqlen), dtype=torch.int)
            export_onnx(modelname, model, param, path, True, dtype)
        else:
            print("Onnx path exists, skipping export.")

    run_pytorch(model, voc_size, batchsize, seqlen, dtype)

if __name__ == "__main__":
    main()
@ -1,39 +1,35 @@
-import sys
-sys.path.append('../')
-
 import argparse
 import os
 import time
 import multiprocessing as mp
 from pyinfinitensor.onnx import OnnxStub, backend
 import onnx
-from onnx.external_data_helper import convert_model_to_external_data
 from onnx.shape_inference import infer_shapes_path
 import numpy as np
 from parallel_opt import parallel_model
 
 
 def parse_args():
     parser = argparse.ArgumentParser(description="launch distributed infinitensor")
     parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
     parser.add_argument(
-        "--nproc_per_node", type=int, default=1, help="number of processes per node"
+        "--nproc_per_node", type=int, default=2, help="number of processes per node"
     )
     parser.add_argument(
         "--name", type=str, default="test", help="name of this instance."
     )
     parser.add_argument(
-        "--model", type=str, required=True, help="path to the ONNX model file."
+        "--model", type=str, default="/data/onnx_models/llama2/llama_bs1_seq1024.onnx",
+        help="path to the ONNX model file."
     )
     parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
     parser.add_argument("--length", type=int, default=1, help="sequence length.")
     parser.add_argument(
         "--gen_std",
+        default=False,
         action="store_true",
         help="whether to generate the standard results.",
     )
-    parser.add_argument(
-        "--type", type=str, choices=["fp32", "fp16", "tf32"], default="fp32", help="data type"
-    )
     args = parser.parse_args()
     print("arg setting: ", args)
     return (
@ -44,46 +40,39 @@ def parse_args():
         args.batch_size,
         args.length,
         args.gen_std,
-        args.type,
     )
 
 
-def run_model(model, runtime, world_size=1, rank=0, n=10, data_type="default"):
-    stub = OnnxStub(model, runtime, matmul_compute_type=data_type)
+def run_model(model, runtime, world_size=1, rank=0, n=10):
+    stub = OnnxStub(model, runtime)
     load_inputs(stub, world_size, rank)
     # stub.tune()
     stub.run()
     # get outputs
+    time.sleep(0.01)
     outputs = next(stub.outputs.values().__iter__()).copyout_numpy()
 
     # bench
+    begin = time.time()
     for _ in range(n):
         stub.run()
-    begin = time.time()
-    for _ in range(n * 2):
-        stub.run()
     end = time.time()
-    avg_time = (end - begin) / (n * 2)
+    avg_time = (end - begin) / n
     print(f"average time: {avg_time}")
     return outputs
 
-def load_inputs(stub, world_size=1, rank=0):
-    for i, (name, tensor) in enumerate(stub.inputs.items()):
-        input = np.load(f"./data/input_{i}.npy")
-        if all(x == y for x,y in zip(input.shape,tensor.shape())):
-            tensor.copyin_numpy(input)
-        else:
-            tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
 
-def run_and_compare(name, model, runtime, world_size=1, rank=0, data_type="default"):
+def run_and_compare(name, model, runtime, world_size=1, rank = 0):
     results = np.load(f"./data/output.npy")
-    outputs = run_model(model, runtime, world_size, rank, data_type=data_type)
-    print("outputs abs mean:", abs(outputs).mean())
-    print("max abs diff:", abs(outputs - results).max())
+    outputs = run_model(model, runtime, world_size, rank)
+    print("answer argmax:", np.argmax(results))
+    print("output argmax:", np.argmax(outputs))
+    #np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-3)
+    getDiff(results, outputs)
 
 
 def start_worker(
-    name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto, data_type: str
+    name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
 ):
     dist_name = name + "_dist"
     model = parallel_model(model, world_size, rank)
@ -96,7 +85,7 @@ def start_worker(
         save_as_external_data=True,
         location=extern_path,
     )
-    #infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
+    infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
     runtime = backend.BangRuntime(local_rank)
     # print("init comm")
     runtime.init_comm(
@ -104,12 +93,13 @@ def start_worker(
         world_size,
         rank,
     )
-    run_and_compare(name, model, runtime, world_size, rank, data_type)
+    run_and_compare(name, model, runtime, world_size, rank)
 
 
-def start_single(name, model, data_type):
+def start_single(name, model):
     runtime = backend.BangRuntime(0)
-    run_and_compare(name, model, runtime, data_type=data_type)
+    run_and_compare(name, model, runtime)
 
 
 def generate_input_output(model):
     os.makedirs(os.path.dirname("./data/"), exist_ok=True)
@ -142,36 +132,55 @@ def generate_input_output(model):
     np.save(f"./data/output", output)
 
 
+def load_inputs(stub, world_size=1, rank=0):
+    for i, (name, tensor) in enumerate(stub.inputs.items()):
+        input = np.load(f"./data/input_{i}.npy")
+        if all(x == y for x,y in zip(input.shape,tensor.shape())):
+            tensor.copyin_numpy(input)
+        else:
+            tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
+
+
+def getDiff(base, test):
+    absolute_diff = np.abs(np.subtract(base, test))
+    max_absolute_diff = np.max(absolute_diff)
+
+    baseCopy = base.astype(np.float64).ravel()
+    testCopy = test.astype(np.float64).ravel()
+    upValue = np.sum(np.abs(baseCopy - testCopy))
+    downValue = np.sum(np.abs(baseCopy)) + np.float64(1e-9)
+    max_relative_diff = upValue / downValue
+    print(f"Max absolute difference: {max_absolute_diff}\n"
+          f"Max relative difference: {max_relative_diff}")
+    return max_absolute_diff, max_relative_diff
 
 
 def main():
-    nnodes, nproc_per_node, name, model_path, bs, length, gen_std, data_type = parse_args()
-    data_type = "default" if data_type == "fp32" else data_type
+    nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args()
 
     model = onnx.load(model_path)
 
     # generate standart output
     if gen_std:
-        print(f"generate standard data for {name}.")
-        # a small vocabulary size to fit all LLM.
-        generate_input_output(model)
+        print("Generate inputs and outputs.")
+        p = mp.Process(target=generate_input_output, args=[model])
+        p.start()
+        p.join()
         return
 
-    if nproc_per_node == 1:
-        # run single process.
-        # use standalone process to isolate bang.
-        print("run model by single MLU.")
-        # p = mp.Process(target=start_single, args=(name, model, data_type))
-        # p.start()
-        # p.join()
-        start_single(name, model, data_type)
-        return
+    # run single process.
+    # use standalone process to isolate cuda.
+    print("run model by single MLU.")
+    p = mp.Process(target=start_single, args=(name, model))
+    p.start()
+    p.join()
 
     # run distributed parallel.
     world_size = nnodes * nproc_per_node
-    print(f"run model by {world_size} MLU in parallel.")
+    print(f"run model by {world_size} MLUs in parallel.")
     workers = [
         mp.Process(
             target=start_worker,
-            args=(name, world_size, rank, rank % nproc_per_node, model, data_type),
+            args=(name, world_size, rank, rank % nproc_per_node, model),
         )
         for rank in range(world_size)
     ]
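The `getDiff` helper added in this file reports a maximum absolute difference and a summed relative difference between the reference and the test output; a standalone sketch of the same arithmetic on made-up values:

```python
import numpy as np

base = np.array([1.0, 2.0, 3.0])   # reference output (e.g. ./data/output.npy)
test = np.array([1.0, 2.0, 4.0])   # output under test

max_absolute_diff = np.max(np.abs(base - test))                                   # 1.0
max_relative_diff = np.sum(np.abs(base - test)) / (np.sum(np.abs(base)) + 1e-9)   # ~0.1667
print(max_absolute_diff, max_relative_diff)
```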
@ -1,14 +0,0 @@
export HF_ENDPOINT=https://hf-mirror.com

models=("bert" "gpt2" "llama")
batch_size=(1 32)
seq_len=(100 500)
nproc=(1 2 4)

for model in "${models[@]}"; do
    for bs in "${batch_size[@]}"; do
        for len in "${seq_len[@]}"; do
            python run_pytorch.py --model "$model" --batch_size "$bs" --length "$len" --export_onnx ../models/"$model" --export_only
        done
    done
done
@ -1,280 +0,0 @@
import sys
sys.path.append('../')

import argparse
import os
import time
import multiprocessing as mp
from pyinfinitensor.onnx import OnnxStub, backend
import onnx
from onnx.external_data_helper import convert_model_to_external_data
from onnx.shape_inference import infer_shapes_path
import numpy as np
from parallel_opt import parallel_model
from functools import wraps


def parse_args():
    parser = argparse.ArgumentParser(description="launch distributed infinitensor")
    parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
    parser.add_argument(
        "--nproc_per_node", type=int, default=2, help="number of processes per node"
    )
    parser.add_argument(
        "--name", type=str, choices=["gpt2", "bert", "llama"], help="name of model."
    )
    parser.add_argument(
        "--model", type=str, default="", help="path to the ONNX model file."
    )
    parser.add_argument(
        "--gen_std",
        default=False,
        action="store_true",
        help="whether to generate the standard results.",
    )
    parser.add_argument(
        "--run_single",
        default=False,
        action="store_true",
        help="whether run model with single process with standard inputs"
    )
    parser.add_argument(
        "--input_dir",
        default="./",
        help="path to save model input data"
    )
    parser.add_argument(
        "--result_dir",
        default="./",
        help="path to save model standard output"
    )
    parser.add_argument(
        "--internal_model_dir",
        default="./",
        help="path to save internal onnx model for parallel run"
    )
    args = parser.parse_args()

    # check path, mkdir if not exist
    check_exists(args.input_dir)
    check_exists(args.result_dir)
    check_exists(args.internal_model_dir)

    print("arg setting: ", args)
    return (
        args.num_nodes,
        args.nproc_per_node,
        args.name,
        args.model,
        args.gen_std,
        args.run_single,
        args.input_dir,
        args.result_dir,
        args.internal_model_dir
    )


"""
utils function for this scripts
"""
def check_exists(path: str):
    if not os.path.exists(path):
        os.makedirs(path)

def np_assert(base, test, rtol=1e-2, atol=1e-1):
    # np.testing.assert_allclose(test, base, rtol, atol)
    print("max abs diff:", abs(base - test).max())


"""
Perf wrapper, run function n times
then average
"""
def perf_it(n):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # warmup
            for _ in range(n):
                func(*args, **kwargs)

            t_total = 0
            for _ in range(n):
                t0 = time.time()
                func(*args, **kwargs)
                t1 = time.time()
                t_total += t1 - t0
            avg_time = (t_total) / n
            print(f"Avg runtime of {n} time is {avg_time:.6f} seconds")
            return avg_time
        return wrapper
    return decorator


"""
Run InfiniTensor model with Standard input
check=True: check with standard output gen by pytorch
perf=True: run n times to get avg time
"""
def run_model(task_name,
              model,
              runtime,
              world_size=1,
              rank=0,
              n=10,
              check=True,
              perf=True):

    stub = OnnxStub(model, runtime,
                    use_naive_allocator=True \
                    if task_name == "llama" else False)

    # load in Onnx model inputs
    def load_inputs(stub: OnnxStub):
        # check exists
        inputs = []
        for i, (name, tensor) in enumerate(stub.inputs.items()):
            input_path = os.path.join(input_dir, \
                                      f"{task_name}_input_{i}.npy")
            print(input_path)
            if os.path.exists(input_path):
                input = np.load(input_path)
            else :
                raise KeyError(f"{i} th input of model not exists")
            # check shape
            if all(x == y for x,y in zip(input.shape, tensor.shape())):
                tensor.copyin_numpy(input)
            else:
                tensor.copyin_numpy(np.hsplit(input, world_size)[rank])

    load_inputs(stub)
    # stub.tune()
    stub.run()
    time.sleep(0.01)
    output = next(stub.outputs.values().__iter__()).copyout_numpy()

    # check output results with standard output
    if check:
        st_output_path = os.path.join(result_dir, \
                                      f"{task_name}_output.npy")
        assert os.path.exists(st_output_path) , \
            "standard output not exists"
        st_output = np.load(st_output_path)
        if np.isnan(output).any():
            print("Nan in output")
            exit()
        np_assert(st_output, output)

    # perf
    if perf:
        @perf_it(n)
        def perf_infinitensor(stub: OnnxStub):
            stub.run()
        perf_infinitensor(stub)

    return output


"""
Start a worker in Parallel
"""
def start_worker(name: str,
                 world_size: int,
                 rank: int,
                 local_rank: int,
                 model: onnx.ModelProto):

    dist_name = name + "_dist"
    # partial a onnx model to world_size part
    model = parallel_model(model, world_size, rank)
    onnx.save(model, os.path.join(internal_model_dir, \
              f"{dist_name}_rank{rank}.onnx"), save_as_external_data=True)
    runtime = backend.KUNLUNRuntime(local_rank)
    # print("init comm")
    runtime.init_comm(
        dist_name,
        world_size,
        rank,
    )
    run_model(name, model, runtime, world_size, rank)


"""
generate standard input/output with
sigle card run
"""
def gen_standard(task_name: str, model: onnx.ModelProto):
    runtime = backend.KUNLUNRuntime(0)
    stub = OnnxStub(model, runtime)
    position_id = 0
    # generate random input for model
    for i, (name, tensor) in enumerate(stub.inputs.items()):
        input = tensor.copyout_numpy()
        if np.issubdtype(input.dtype, np.integer):
            if input.size == 1:
                input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
            else:
                input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
        elif input.dtype == np.bool_:
            input = np.random.randint(0,2,size=input.shape) > 0
        else:
            if i == 0:
                input = np.ones(input.shape).astype(input.dtype)
                position_id = input.shape[-1] - 1
            else:
                input = np.random.rand(*input.shape).astype(input.dtype)
        tensor.copyin_numpy(input)
        np.save(os.path.join(input_dir, \
                f"{task_name}_input_{i}.npy"), input)
    stub.run()
    # print(stub.outputs)
    output = next(stub.outputs.values().__iter__()).copyout_numpy()
    if np.isnan(output).any():
        print("Nan in output")
        exit()
    np.save(os.path.join(result_dir, f"{task_name}_output.npy"), output)


def main():

    global input_dir, result_dir, internal_model_dir

    nnodes, nproc_per_node, task_name, \
        model_path, gen_std, run_single, \
        input_dir, result_dir, internal_model_dir = parse_args()

    # load input onnx model
    model = onnx.load(model_path)

    # generate standart output
    if gen_std:
        print("Generate inputs and outputs.")
        gen_standard(task_name, model)
        return

    if run_single:
        print("Run model by one GPU card.")
        runtime = backend.KUNLUNRuntime(0)
        run_model(task_name, model, runtime)
        return

    # run distributed parallel.
    world_size = nnodes * nproc_per_node
    print(f"Run model by {world_size} GPU in parallel.")
    workers = [
        mp.Process(
            target=start_worker,
            args=(task_name, world_size, rank, rank % nproc_per_node, model),
        )
        for rank in range(world_size)
    ]

    for w in workers:
        w.start()

    for w in workers:
        w.join()


if __name__ == "__main__":
    main()
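The `perf_it` decorator in the file above warms the wrapped callable up `n` times and then averages the wall time of `n` more calls; a self-contained sketch of the same pattern on a toy function (the workload and `n` are illustrative):

```python
import time
from functools import wraps

def perf_it(n):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # warm-up calls are not timed
            for _ in range(n):
                func(*args, **kwargs)
            # timed calls
            t_total = 0.0
            for _ in range(n):
                t0 = time.time()
                func(*args, **kwargs)
                t_total += time.time() - t0
            avg_time = t_total / n
            print(f"Avg runtime of {n} time is {avg_time:.6f} seconds")
            return avg_time
        return wrapper
    return decorator

@perf_it(5)
def toy_workload():
    sum(i * i for i in range(100_000))

toy_workload()  # prints the averaged runtime of the timed calls
```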
@ -1,36 +0,0 @@
export HF_ENDPOINT=https://hf-mirror.com

# models=("bert" "gpt2" "llama")
models=("bert" "gpt2")
batch_size=(1 32)
seq_len=(100 500)
nproc=(1 2 4)

results_dir="results"

if [ -d "$results_dir" ]; then
    echo "directory ./$results_dir exists"
else
    mkdir -p "$results_dir"
    echo "mkdir $results_dir, logs saved there"
fi


for model in "${models[@]}"; do
    for bs in "${batch_size[@]}"; do
        for len in "${seq_len[@]}"; do
            # run pytorch model
            echo "Run pytorch $model with batch_size=$bs length=$len ."
            python run_pytorch.py --model "$model" --batch_size "$bs" --length "$len" #> results/"$model"_"$bs"_"$len"_pytorch
            for n in "${nproc[@]}"; do
                # run infinitensor
                echo "Run $n parallel infinitensor "$model" with batch_size=$bs and length=$len ."
                python kunlun_launch.py --name "$model" --model ../models/"$model"/"$model"_"$bs"_"$len".onnx --nproc_per_node=$n # >> results/"$model"_"$bs"_"$len"_infini
                # delete internal files
                find ./ -type f -name "*.onnx" -delete
                find ./ -type f -name "*.pb" -delete
            done
            find ./ -type f -name "*.npy" -delete
        done
    done
done
@ -1,35 +0,0 @@
export HF_ENDPOINT=https://hf-mirror.com

# models=("bert" "gpt2" "llama")
models=("llama")
batch_size=(1 )
seq_len=(100 500)
nproc=(1 2 4)

results_dir="results"

if [ -d "$results_dir" ]; then
    echo "directory ./$results_dir exists"
else
    mkdir -p "$results_dir"
    echo "mkdir $results_dir, logs saved there"
fi


for model in "${models[@]}"; do
    for bs in "${batch_size[@]}"; do
        for len in "${seq_len[@]}"; do
            echo "Run pytorch llama with batch_size="$bs" and length="$len""
            python run_pytorch.py --model "$model" --batch_size "$bs" --length "$len"
            for n in "${nproc[@]}"; do
                # run pytorch model
                echo "Run infinitensor llama with batch_size="$bs" and length="$len" and nproc="$n"."
                python kunlun_launch.py --name llama --model ../models/llama/llama_"$bs"_"$len"_fp32.onnx --nproc_per_node=$n
                # delete internal files
                find ./ -type f -name "*.onnx" -delete
                find ./ -type f -name "*0c" -delete
            done
            find ./ -type f -name "*.npy" -delete
        done
    done
done
@ -1,245 +0,0 @@
import argparse
import torch
from transformers import BertModel, BertConfig
from transformers import GPT2Model, GPT2Config
from transformers import OPTModel, OPTConfig
from transformers import LlamaModel, LlamaConfig
import time
import numpy as np
import onnx
import os
import sys
from onnx.external_data_helper import convert_model_to_external_data
from onnxsim import simplify

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

def parse_args():
    parser = argparse.ArgumentParser(description="Run pytorch gpt2/bert/opt and optionally export onnx.")
    parser.add_argument(
        "--model", type=str, choices=["gpt2", "bert", "opt", "llama"], required=True, help="model type"
    )
    parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
    parser.add_argument("--length", type=int, default=1, help="sequence length.")
    parser.add_argument(
        "--export_onnx",
        type=str,
        nargs="?",
        default=None,
        const="./",
        help="whether and where to export onnx file",
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        default="./",
        help="path to save pytorch model input data"
    )
    parser.add_argument(
        "--result_dir",
        type=str,
        default="./",
        help="path to save pytorch model output data"
    )
    parser.add_argument(
        "--export_only",
        action="store_true"
    )
    args = parser.parse_args()
    print("arg setting: ", args)
    return (
        args.model,
        args.batch_size,
        args.length,
        args.export_onnx,
        args.input_dir,
        args.result_dir,
        args.export_only
    )


def get_model(modelname):
    if modelname == "bert":
        model = BertModel.from_pretrained("bert-base-uncased", add_pooling_layer=False, hidden_act="gelu_new")  # erf is not impl by infini
        voc_size = BertConfig().vocab_size
    elif modelname == "gpt2":
        model = GPT2Model.from_pretrained("gpt2")
        voc_size = GPT2Config().vocab_size
    elif modelname == "opt":
        model = OPTModel.from_pretrained("./opt-125m")
        voc_size = OPTConfig().vocab_size
    elif modelname == "llama":
        model = LlamaModel.from_pretrained("meta-llama/Llama-2-7b-hf")
        voc_size = LlamaConfig().vocab_size
    else :
        raise KeyError(modelname)

    model = model.eval()
    return model, voc_size


def run_pytorch(torch_model, voc_size, batchsize, len, model_name):
    data = np.random.randint(0, voc_size, (batchsize, len), dtype=np.int32)
    np.save(os.path.join(input_dir, f"{model_name}_input_0.npy"), data)
    inputs = torch.from_numpy(data).to("cuda")
    torch_model = torch_model.to("cuda")

    n_iter = 10
    with torch.no_grad():
        for _ in range(10):
            outputs = torch_model(inputs)
    torch.cuda.synchronize()
    begin = time.time()
    with torch.no_grad():
        for _ in range(n_iter):
            torch.cuda.synchronize()
            outputs = torch_model(inputs)
            #
            torch.cuda.synchronize()
    torch.cuda.synchronize()
    end = time.time()

    avg_time = (end - begin) / n_iter
    outputs = outputs.last_hidden_state.to("cpu")
    print("outputs abs mean:", abs(np.array(outputs)).mean())
    print(f"average time: {avg_time}")
    torch.cuda.memory.empty_cache()
    np.save(os.path.join(result_dir, f"{model_name}_output.npy"), \
            np.array(outputs))
    print(f"Save input & output as {model_name}_input_0.npy and {model_name}_output.npy")


def export_onnx(model_name, model, data, path, extern=False):
    # torch.onnx.export(model, data, path, verbose=False, do_constant_folding=True)

    if model_name != "llama":
        onnx_model = onnx.load(path)
        onnx_model, check = simplify(onnx_model,
                                     skipped_optimizers=['fuse_qkv', 'eliminate_duplicate_initializer'])
                                     # skipped_optimizers=['fuse_qkv'])
        assert check
        add_value_info_for_constants(onnx_model)
        onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
        if extern:
            extern_path = path.replace('.onnx', '.pb')
            if os.path.exists(extern_path):
                os.remove(extern_path)
            convert_model_to_external_data(
                onnx_model,
                all_tensors_to_one_file=True,
                location=extern_path.split("/")[-1],
                size_threshold=1024,
                convert_attribute=False,
            )
        onnx.save(onnx_model, path)
    else:
        sys.path.append("onnxsim_large_model")
        from onnx_utils import set_onnx_input_shape
        from compress_model import SIZE_1MB, compress_onnx_model, uncompress_onnx_model

        in_model_path = path
        out_model_path = in_model_path[:-5] + ".sim.onnx"

        onnx_model = onnx.load(in_model_path)
        print(f"load model from {in_model_path} success")

        size_th_bytes = 1024 * 1024
        onnx_model, removed_inits = compress_onnx_model(onnx_model, size_th_bytes=size_th_bytes)
        print("compress model success")

        onnx_model = set_onnx_input_shape(onnx_model, "")
        tensor_size_threshold = f"1024KB"
        skipped_optimizers = []
        skipped_optimizers.append("eliminate_duplicate_initializer")
        onnx_model, check = simplify(onnx_model, skipped_optimizers=skipped_optimizers,
                                     tensor_size_threshold=tensor_size_threshold)
        if not check:
            raise ValueError(f"simplify compressed model {in_model_path} failed")

        print(f"simplify model success")

        onnx_model = uncompress_onnx_model(onnx_model, removed_inits)
        print(f"uncompress model success")

        add_value_info_for_constants(onnx_model)

        onnx.save(onnx_model, out_model_path, save_as_external_data=True)


def add_value_info_for_constants(model : onnx.ModelProto):
    """
    Currently onnx.shape_inference doesn't use the shape of initializers, so add
    that info explicitly as ValueInfoProtos.
    Mutates the model.
    Args:
        model: The ModelProto to update.
    """
    # All (top-level) constants will have ValueInfos before IRv4 as they are all inputs
    if model.ir_version < 4:
        return

    def add_const_value_infos_to_graph(graph : onnx.GraphProto):
        inputs = {i.name for i in graph.input}
        existing_info = {vi.name: vi for vi in graph.value_info}
        for init in graph.initializer:
            # Check it really is a constant, not an input
            if init.name in inputs:
                continue

            # The details we want to add
            elem_type = init.data_type
            shape = init.dims

            # Get existing or create new value info for this constant
            vi = existing_info.get(init.name)
            if vi is None:
                vi = graph.value_info.add()
                vi.name = init.name

            # Even though it would be weird, we will not overwrite info even if it doesn't match
            tt = vi.type.tensor_type
            if tt.elem_type == onnx.TensorProto.UNDEFINED:
                tt.elem_type = elem_type
            if not tt.HasField("shape"):
                # Ensure we set an empty list if the const is scalar (zero dims)
                tt.shape.dim.extend([])
                for dim in shape:
                    tt.shape.dim.add().dim_value = dim

        # Handle subgraphs
        for node in graph.node:
            for attr in node.attribute:
                # Ref attrs refer to other attrs, so we don't need to do anything
                if attr.ref_attr_name != "":
                    continue

                if attr.type == onnx.AttributeProto.GRAPH:
                    add_const_value_infos_to_graph(attr.g)
                if attr.type == onnx.AttributeProto.GRAPHS:
                    for g in attr.graphs:
                        add_const_value_infos_to_graph(g)

    return add_const_value_infos_to_graph(model.graph)


def main():
    global input_dir, result_dir

    modelname, batchsize, seqlen, \
        export_path, input_dir, result_dir, export_only = parse_args()

    model, voc_size = get_model(modelname)  # pytorch model

    if export_path is not None:
        os.makedirs(export_path, exist_ok=True)
        filename = "{}_{}_{}.onnx".format(modelname, batchsize, seqlen)
        path = os.path.join(export_path, filename)
        param = torch.zeros((batchsize, seqlen), dtype=torch.int)
        export_onnx(modelname, model, param, path, True)  # export pytorch model to onnx model
        if export_only:
            return

    run_pytorch(model, voc_size, batchsize, seqlen, modelname)

if __name__ == "__main__":
    main()
@ -0,0 +1,213 @@
import argparse
import os
import time
import multiprocessing as mp
from pyinfinitensor.onnx import OnnxStub, backend
import onnx
from onnx.external_data_helper import convert_model_to_external_data
from onnx.shape_inference import infer_shapes_path
import numpy as np
from parallel_opt import parallel_model

st_input_dir = "standard/inputs/"
st_output_dir = "standard/outputs/"

def parse_args():
    parser = argparse.ArgumentParser(description="launch distributed infinitensor")
    parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
    parser.add_argument(
        "--nproc_per_node", type=int, default=2, help="number of processes per node"
    )
    parser.add_argument(
        "--name", type=str, default="test", help="name of this instance."
    )
    parser.add_argument(
        "--model", type=str, default="/data1/shared/panzezhong/llama/fp32/my_llama_fp32.sim.onnx", help="path to the ONNX model file."
    )
    parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
    parser.add_argument("--length", type=int, default=1, help="sequence length.")
    parser.add_argument(
        "--gen_std",
        default=False,
        action="store_true",
        help="whether to generate the standard results.",
    )
    parser.add_argument(
        "--run_single",
        default=False,
        action="store_true",
        help="whether run model with single process with standard inputs"
    )
    args = parser.parse_args()
    print("arg setting: ", args)
    return (
        args.num_nodes,
        args.nproc_per_node,
        args.name,
        args.model,
        args.batch_size,
        args.length,
        args.gen_std,
        args.run_single
    )


def run_model(model, runtime, world_size=1, rank=0, n=10):
    stub = OnnxStub(model, runtime)
    load_inputs(stub, world_size, rank)
    # stub.tune()
    stub.run()
    # get outputs
    time.sleep(0.01)
    outputs = next(stub.outputs.values().__iter__()).copyout_numpy()

    # bench
    begin = time.time()
    for _ in range(n):
        stub.run()
    end = time.time()
    avg_time = (end - begin) / n
    print(f"average time: {avg_time}")
    return outputs


def run_and_compare(name, model, runtime, world_size=1, rank = 0):
    results = np.load(os.path.join(st_output_dir,f"output.npy"))
    outputs = run_model(model, runtime, world_size, rank)
    print(outputs[:100])
    if np.isnan(outputs).any():
        print("Nan in output")
    print("answer argmax:", np.argmax(results))
    print("output argmax:", np.argmax(outputs))
    #np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-3)
    getDiff(results, outputs)


def start_worker(
    name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
):
    dist_name = name + "_dist"
    model = parallel_model(model, world_size, rank)
    extern_path = f"./{dist_name}_rank{rank}.pb"
    if os.path.exists(extern_path):
        os.remove(extern_path)
    onnx.save_model(
        model,
        f"./{dist_name}_rank{rank}.onnx",
        save_as_external_data=True,
        location=extern_path,
    )
    infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
    runtime = backend.KUNLUNRuntime(local_rank)
    # print("init comm")
    runtime.init_comm(
        dist_name,
        world_size,
        rank,
    )
    run_and_compare(name, model, runtime, world_size, rank)


def start_single(name, model):
    runtime = backend.KUNLUNRuntime(0)
    run_and_compare(name, model, runtime)


def generate_input_output(model):
    runtime = backend.KUNLUNRuntime(0)
    stub = OnnxStub(model, runtime)
    position_id = 0
    for i, (name, tensor) in enumerate(stub.inputs.items()):
        input = tensor.copyout_numpy()
        if np.issubdtype(input.dtype, np.integer):
            if input.size == 1:
                # input = np.array([position_id])
                input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
            else:
                input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
        elif input.dtype == np.bool_:
            input = np.random.randint(0,2,size=input.shape) > 0
        else:
            if i == 0:
                input = np.ones(input.shape).astype(input.dtype)
                position_id = input.shape[-1] - 1
            else:
                input = np.random.rand(*input.shape).astype(input.dtype)
        tensor.copyin_numpy(input)
        np.save(os.path.join(st_input_dir, f"input_{i}"), input)
    stub.run()
    # print(stub.outputs)
    time.sleep(0.01)
    output = next(stub.outputs.values().__iter__()).copyout_numpy()
    print(output[:100])
    if np.isnan(output).any():
        print("Nan in output")
    np.save(os.path.join(st_output_dir, f"output"), output)


def load_inputs(stub, world_size=1, rank=0):
    for i, (name, tensor) in enumerate(stub.inputs.items()):
        input = np.load(os.path.join(st_input_dir, f"input_{i}.npy"))
        if all(x == y for x,y in zip(input.shape,tensor.shape())):
            tensor.copyin_numpy(input)
        else:
            tensor.copyin_numpy(np.hsplit(input, world_size)[rank])


def getDiff(base, test):
    absolute_diff = np.abs(np.subtract(base, test))
    max_absolute_diff = np.max(absolute_diff)

    baseCopy = base.astype(np.float64).ravel()
    testCopy = test.astype(np.float64).ravel()
    upValue = np.sum(np.abs(baseCopy - testCopy))
    downValue = np.sum(np.abs(baseCopy)) + np.float64(1e-9)
    max_relative_diff = upValue / downValue
    print(f"Max absolute difference: {max_absolute_diff}\nMax relative difference: {max_relative_diff}")

    return max_absolute_diff, max_relative_diff


def main():
    nnodes, nproc_per_node, name, model_path, bs, length, gen_std, run_single = parse_args()

    model = onnx.load(model_path)

    # generate standart output
    if gen_std:
        print("Generate inputs and outputs.")
        p = mp.Process(target=generate_input_output, args=[model])
        p.start()
        p.join()
        return

    # # run single process.
    # # use standalone process to isolate cuda.
    if run_single:
        print("run model by single GPU.")
        p = mp.Process(target=start_single, args=(name, model))
        p.start()
        p.join()
        return

    # run distributed parallel.
    world_size = nnodes * nproc_per_node
    print(f"run model by {world_size} GPU in parallel.")
    workers = [
        mp.Process(
            target=start_worker,
            args=(name, world_size, rank, rank % nproc_per_node, model),
        )
        for rank in range(world_size)
    ]

    for w in workers:
        w.start()

    for w in workers:
        w.join()


if __name__ == "__main__":
    main()
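`load_inputs` in the file above copies the saved input in whole when its shape matches the stub's input, and otherwise column-splits it with `np.hsplit` so each rank receives its shard; a small sketch of that sharding on an illustrative shape:

```python
import numpy as np

world_size = 4
full_input = np.arange(2 * 8).reshape(2, 8)  # e.g. a (batch, seq) array saved by generate_input_output

# One column block per rank, mirroring tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
for rank in range(world_size):
    shard = np.hsplit(full_input, world_size)[rank]
    print(rank, shard.shape)  # every shard is (2, 2)
```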
@ -1 +0,0 @@
-Subproject commit cbcf3fbf985a00494b0f136c92eaccd42031bf65
@ -110,6 +110,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
             s_dim = 0
         elif in_plc.dim == 2:
             s_dim = 1
+
         assert s_dim != -1
         assert out_dims[s_dim] % tp_world_size == 0, out_dims
         out_dims[s_dim] //= tp_world_size
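The hunk above picks which output dimension to shard from the input placement and then divides it by the tensor-parallel world size; a toy sketch of that bookkeeping (the condition of the first branch is not visible in the hunk, so the mapping below is an assumption, and the shapes are made up):

```python
def shard_out_dims(out_dims, in_plc_dim, tp_world_size):
    # assumed mapping: placement dim 0 -> shard output dim 0, placement dim 2 -> shard output dim 1
    s_dim = -1
    if in_plc_dim == 0:
        s_dim = 0
    elif in_plc_dim == 2:
        s_dim = 1

    assert s_dim != -1
    assert out_dims[s_dim] % tp_world_size == 0, out_dims
    out_dims = list(out_dims)
    out_dims[s_dim] //= tp_world_size
    return out_dims

print(shard_out_dims([1, 32, 4096], in_plc_dim=2, tp_world_size=4))  # [1, 8, 4096]
```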
@ -0,0 +1,512 @@
import os
from pyinfinitensor.onnx import OnnxStub, backend
import numpy as np
import onnx
import torch
from tqdm import tqdm
import onnx_graphsurgeon as gs
import time
import nvtx
import argparse
from mpi4py import MPI
from pytrie import StringTrie
import io
import json
import re
from typing import (
    Dict,
    List,
    IO,
)

parser = argparse.ArgumentParser(description='')
parser.add_argument('--batchsize', dest='batchsize', type=int, default=1)
parser.add_argument('--layer', dest='n_layers', type=int, default=48)
parser.add_argument("--num_nodes", dest='num_nodes',
                    type=int, default=1, help="number of nodes")
parser.add_argument("--world_size", dest="world_size",
                    type=int, default=1, help="")
parser.add_argument("--nproc_per_node", dest="nproc_per_node",
                    type=int, default=1, help="number of processes per node")
parser.add_argument("--n_max_length", dest="n_max_length",
                    type=int, default=1024, help="maximum sequence length")
parser.add_argument("--vocab_size", dest="vocab_size",
                    type=int, default=119696, help="vocabulary size")
parser.add_argument("--hidden_size", dest="hidden_size",
                    type=int, default=4096, help="hidden size")
parser.add_argument('--rank', dest='rank', type=int, default=0)
parser.add_argument('--speedup', action='store_true')
parser.add_argument('--no_cudagraph', action='store_true')
parser.add_argument('--fp16', action='store_true')
args = parser.parse_args()
comm = MPI.COMM_WORLD
args.rank = comm.Get_rank()
args.nproc_per_node = comm.Get_size()
args.world_size = args.num_nodes * args.nproc_per_node

ONNX_MODEL_PATH = "/data3/shared/xnsong/9G/dist/9g_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.onnx".format(
    args.batchsize, args.n_layers, 16 if args.fp16 else 32, args.world_size, args.rank)

weight_path = "9g_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.pb".format(
    args.batchsize, args.n_layers, 16 if args.fp16 else 32, args.world_size, args.rank)

model_dir = "/data1/shared/9G-Infer/models/11B-Chat-QY-epoch-8/cpm9g-11b-sft.pt"


@gs.Graph.register()
def RMSNorm(self, a, b):
    return self.layer(op="RMSNorm", inputs=a, outputs=b)


@gs.Graph.register()
def RoPE(self, a, b):
    return self.layer(op="RoPE", inputs=a, outputs=b)


@gs.Graph.register()
def AttentionKVCache(self, a, b):
    return self.layer(op="AttentionKVCache", inputs=a, outputs=b)


def to_numpy(array):
    # Cast a loaded weight to the dtype selected on the command line.
    ret = array
    if args.fp16:
        ret = np.float16(ret)
    else:
        ret = np.float32(ret)
    return ret


def parallel(array, split='replicate'):
    # Return this rank's shard of a weight: column split, row split, or a full copy.
    if args.world_size > 1 and split == 'partial_column':
        return np.hsplit(array, args.world_size)[args.rank]
    elif args.world_size > 1 and split == 'partial_row':
        return np.vsplit(array, args.world_size)[args.rank]
    return array
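The `parallel` helper above only shards the weights; the matching cross-rank sum is added later in the graph as a `ReduceSum` node carrying a `communicator` attribute. As a sanity check, here is a minimal numpy sketch (not part of the script, two ranks assumed) of why a 'partial_column' split on one projection and a 'partial_row' split on the next, followed by a sum of the per-rank partial outputs, reproduces the unsplit result:

```python
import numpy as np

x = np.random.rand(1, 8).astype(np.float32)    # one token's hidden state
w0 = np.random.rand(8, 16).astype(np.float32)  # first projection (already transposed)
w1 = np.random.rand(16, 8).astype(np.float32)  # output projection (already transposed)
world_size = 2

partials = []
for rank in range(world_size):
    w0_shard = np.hsplit(w0, world_size)[rank]  # 'partial_column'
    w1_shard = np.vsplit(w1, world_size)[rank]  # 'partial_row'
    partials.append((x @ w0_shard) @ w1_shard)  # this rank's partial output

# The explicit sum stands in for the ReduceSum/communicator op in the graph.
np.testing.assert_allclose(sum(partials), (x @ w0) @ w1, rtol=1e-5)
```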

def generate_onnx(ONNX_MODEL_PATH):
    state_dict = torch.load(f'{model_dir}', map_location='cpu')
    new_state_dict = {name: param.cpu().numpy()
                      for name, param in state_dict.items()
                      }

    operators = []
    graph = gs.Graph(nodes=operators)
    gather_input = gs.Variable(name="gather_input.0", dtype=np.int64, shape=(1,1))
    pos_input = gs.Variable(name="pos_input.0", dtype=np.int64, shape=(1,1))

    embedding_weight = gs.Constant(name="embedding.weight", values=to_numpy(new_state_dict["input_embedding.weight"]))
    gather_output = gs.Variable(name="gather_output.0")
    gather = gs.Node(op="Gather", inputs=[embedding_weight, gather_input], outputs=[gather_output])
    operators.append(gather)
    input = gather_output

    graph.inputs=[gather_input, pos_input]
    graph.outputs=[]

    for i in tqdm(range(args.n_layers)):
        # global input
        attn_kcache_input = gs.Variable(name="/layers." + str(i) + "/attn/kcache_input", dtype=np.float32, shape=(1,32,1023,128))
        attn_vcache_input = gs.Variable(name="/layers." + str(i) + "/attn/vcache_input", dtype=np.float32, shape=(1,32,1023,128))
        graph.inputs.append(attn_kcache_input)
        graph.inputs.append(attn_vcache_input)

        # weight
        layernorm_0_mul_weight = gs.Constant(name="/layers." + str(i) + "/layernorm.0/mul_weight",
            values=to_numpy(new_state_dict["encoder.layers." + str(i) + ".self_att.layernorm_before_attention.weight"]))
        attn_qproj_weight = gs.Constant(name="/layers." + str(i) + "/attn/qproj_weight",
            values=parallel(
                np.transpose(
                    to_numpy(
                        new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.project_q.weight"]))
                , 'partial_column'))
        attn_kproj_weight = gs.Constant(name="/layers." + str(i) + "/attn/kproj_weight",
            values=parallel(
                np.transpose(
                    to_numpy(
                        new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.project_k.weight"]))
                , 'partial_column'))
        attn_vproj_weight = gs.Constant(name="/layers." + str(i) + "/attn/vproj_weight",
            values=parallel(
                np.transpose(
                    to_numpy(
                        new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.project_v.weight"]))
                , 'partial_column'))
        attn_outmatmul_input = gs.Constant(name="/layers." + str(i) + "/attn/outmatmul_weight",
            values=parallel(
                np.transpose(
                    to_numpy(
                        new_state_dict["encoder.layers." + str(i) + ".self_att.self_attention.attention_out.weight"]))
                , 'partial_row'))

        layernorm_1_mul_weight = gs.Constant(name="/layers." + str(i) + "/layernorm.1/mul_weight",
            values=to_numpy(new_state_dict["encoder.layers." + str(i) + ".ffn.layernorm_before_ffn.weight"]))
        ffn_matmul_0_input = gs.Constant(name="/layers." + str(i) + "/ffn/matmul_0_weight",
            values=parallel(
                np.transpose(
                    to_numpy(
                        new_state_dict["encoder.layers." + str(i) + ".ffn.ffn.w_in.w_0.weight"]))
                , 'partial_column'))
        ffn_matmul_1_input = gs.Constant(name="/layers." + str(i) + "/ffn/matmul_1_weight",
            values=parallel(
                np.transpose(
                    to_numpy(
                        new_state_dict["encoder.layers." + str(i) + ".ffn.ffn.w_in.w_1.weight"]))
                , 'partial_column'))
        ffn_matmul_out_input = gs.Constant(name="/layers." + str(i) + "/ffn/matmul_out_weight",
            values=parallel(
                np.transpose(
                    to_numpy(
                        new_state_dict["encoder.layers." + str(i) + ".ffn.ffn.w_out.weight"]))
                , 'partial_row'))

        attn_qrope_output = gs.Variable(name="/layers." + str(i) + "/attn/qrope_output")
        attn_krope_output = gs.Variable(name="/layers." + str(i) + "/attn/krope_output")
        attn_kvcache_output = gs.Variable(name="/layers." + str(i) + "/attn/kvcache_output")
        layernorm_0_mul_output_1 = gs.Variable(name="/layers." + str(i) + "/layernorm.0/mul_output_1")
        layernorm_1_mul_output_1 = gs.Variable(name="/layers." + str(i) + "/layernorm.1/mul_output_1")
        attn_qproj_output = gs.Variable(name="/layers." + str(i) + "/attn/qproj_output")
        attn_kproj_output = gs.Variable(name="/layers." + str(i) + "/attn/kproj_output")
        attn_vproj_output = gs.Variable(name="/layers." + str(i) + "/attn/vproj_output")
        attn_outmatmul_output = gs.Variable(name="/layers." + str(i) + "/attn/outmatmul_output")
        attn_outadd_output = gs.Variable(name="/layers." + str(i) + "/attn/outadd_output")
        ffn_matmul_0_output = gs.Variable(name="/layers." + str(i) + "/ffn/matmul_0_output")
        ffn_silu_output = gs.Variable(name="/layers." + str(i) + "/ffn/silu_output")
        ffn_matmul_1_output = gs.Variable(name="/layers." + str(i) + "/ffn/matmul_1_output")
        ffn_mul_output = gs.Variable(name="/layers." + str(i) + "/ffn/mul_output")
        ffn_matmul_out_output = gs.Variable(name="/layers." + str(i) + "/ffn/matmul_out_output")
        ffn_add_output = gs.Variable(name="/layers." + str(i) + "/ffn/add_output")

        graph.RMSNorm([input, layernorm_0_mul_weight], [layernorm_0_mul_output_1])
        attn_qproj = gs.Node(op="MatMul", inputs=[layernorm_0_mul_output_1, attn_qproj_weight], outputs=[attn_qproj_output])
        operators.append(attn_qproj)
        attn_kproj = gs.Node(op="MatMul", inputs=[layernorm_0_mul_output_1, attn_kproj_weight], outputs=[attn_kproj_output])
        operators.append(attn_kproj)
        attn_vproj = gs.Node(op="MatMul", inputs=[layernorm_0_mul_output_1, attn_vproj_weight], outputs=[attn_vproj_output])
        operators.append(attn_vproj)
        graph.RoPE([pos_input, attn_qproj_output], [attn_qrope_output])
        graph.RoPE([pos_input, attn_kproj_output], [attn_krope_output])
        graph.AttentionKVCache([attn_kcache_input, attn_vcache_input, attn_qrope_output, attn_krope_output, attn_vproj_output, pos_input],[attn_kvcache_output])
        attn_outproj = gs.Node(op="MatMul", inputs=[attn_kvcache_output, attn_outmatmul_input], outputs=[attn_outmatmul_output])
        operators.append(attn_outproj)

        attn_reduce_sum_output = gs.Variable(name="/layers." + str(i) + "/attn/reducesum_output")
        if args.world_size > 1:
            reduce_sum = gs.Node(op="ReduceSum", name="/layers." + str(i) + "/attn/reducesum",
                inputs=[attn_outmatmul_output], outputs=[attn_reduce_sum_output],
                attrs={"noop_with_empty_axes":1, "communicator":0})
            graph.nodes.append(reduce_sum)

        attn_outadd = gs.Node(op="Add", inputs=[input, attn_outmatmul_output if args.world_size == 1 else attn_reduce_sum_output], outputs=[attn_outadd_output])
        operators.append(attn_outadd)

        graph.RMSNorm([attn_outadd_output, layernorm_1_mul_weight], [layernorm_1_mul_output_1])

        ffn_matmul_0 = gs.Node(op="MatMul", inputs=[layernorm_1_mul_output_1, ffn_matmul_0_input], outputs=[ffn_matmul_0_output])
        operators.append(ffn_matmul_0)
        ffn_silu = gs.Node(op="Silu", inputs=[ffn_matmul_0_output], outputs=[ffn_silu_output])
        operators.append(ffn_silu)
        ffn_matmul_1 = gs.Node(op="MatMul", inputs=[layernorm_1_mul_output_1, ffn_matmul_1_input], outputs=[ffn_matmul_1_output])
        operators.append(ffn_matmul_1)
        ffn_mul = gs.Node(op="Mul", inputs=[ffn_silu_output, ffn_matmul_1_output], outputs=[ffn_mul_output])
        operators.append(ffn_mul)
        ffn_matmul_out = gs.Node(op="MatMul", inputs=[ffn_mul_output, ffn_matmul_out_input], outputs=[ffn_matmul_out_output])
        operators.append(ffn_matmul_out)

        ffn_reduce_sum_output = gs.Variable(name="/layers." + str(i) + "/ffn/reducesum_output")
        if args.world_size > 1:
            reduce_sum = gs.Node(op="ReduceSum", name="/layers." + str(i) + "/ffn/reducesum",
                inputs=[ffn_matmul_out_output], outputs=[ffn_reduce_sum_output],
                attrs={"noop_with_empty_axes":1, "communicator":0})
            graph.nodes.append(reduce_sum)

        ffn_add = gs.Node(op="Add", inputs=[attn_outadd_output, ffn_matmul_out_output if args.world_size == 1 else ffn_reduce_sum_output], outputs=[ffn_add_output])
        operators.append(ffn_add)
        input = ffn_add_output

    layernorm_mul_weight = gs.Constant(name="/output/layernorm/mul_weight", values=to_numpy(new_state_dict["encoder.output_layernorm.weight"]))
    layernorm_mul_output_1 = gs.Variable(name="/output/layernorm/mul_output_1")

    graph.RMSNorm([input, layernorm_mul_weight], [layernorm_mul_output_1])

    lm_head_weight = gs.Constant(name="/output/lm_head/weight", values=np.transpose(to_numpy(new_state_dict["lm_head.weight"])))
    lm_head_output = gs.Variable(name="/output/lm_head/output")
    lm_head = gs.Node(op="MatMul", inputs=[layernorm_mul_output_1, lm_head_weight], outputs=[lm_head_output])
    operators.append(lm_head)

    if args.fp16:
        final_cast_output = gs.Variable(name="/output/cast/output", dtype=np.float32, shape=(1,1,args.vocab_size))
        final_cast = gs.Node(op="Cast", inputs=[lm_head_output], outputs=[final_cast_output])
        final_cast.attrs["to"] = np.float32
        operators.append(final_cast)
        graph.outputs.append(final_cast_output)
    else:
        lm_head_output.dtype=np.float32
        lm_head_output.shape=(1,1,args.vocab_size)
        graph.outputs.append(lm_head_output)

    onnx.save(gs.export_onnx(graph), ONNX_MODEL_PATH, save_as_external_data=True, location=weight_path)
    return

def load_vocab(fp: IO[bytes]) -> Dict[str, int]:
    """Loads a vocabulary file into a dictionary."""
    vocab: Dict[str, int] = {}

    reader = io.TextIOWrapper(fp, encoding="utf-8")
    for token in reader.readlines():
        token = token.strip()
        if len(token) == 0:
            continue
        token = json.loads(token)
        vocab[token] = len(vocab)
    return vocab


class CPM9GTokenizer(object):
    def __init__(self, path):
        self.unk_token = "<unk>"
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        self.byte_list = ["<0x0{}>".format(hex(i).upper()[2:]) for i in range(0x10)] + [
            "<0x{}>".format(hex(i).upper()[2:]) for i in range(0x10, 0x100)
        ]

        self._special_token_set = set([self.unk_token, self.bos_token, self.eos_token] + self.byte_list)

        all_tokens = load_vocab(io.FileIO(path, "rb"))

        self.encoder: Dict[str, int] = {}
        self._special_encoder: Dict[str, int] = {}
        for token, token_id in all_tokens.items():
            if token in self._special_token_set:
                self._special_encoder[token] = token_id
            else:
                self.encoder[token] = token_id

        self.decoder = {v: k for k, v in self.encoder.items()}
        self._byte_decoder = {self._special_encoder[token]: i for i, token in enumerate(self.byte_list)}

        self._max_word_len = max([len(x) for x in self.encoder.keys()])

        self._len_word_first = {}
        for x in self.encoder.keys():
            if not x[0] in self._len_word_first:
                self._len_word_first[x[0]] = 1
            if len(x) > self._len_word_first[x[0]]:
                self._len_word_first[x[0]] = len(x)
        self.tencoder = StringTrie(self.encoder)

    def get_piece(self, text: str) -> str:
        if text[0] in self._len_word_first:
            text = text[: self._len_word_first[text[0]]]
            len_text = len(text)
            for i in range(len(text)):
                sub = text[: len_text - i]
                if sub in self.encoder:
                    return sub
        return text[0]

    @property
    def vocab_size(self):
        return len(self)

    @property
    def eos_id(self):
        return self._special_encoder[self.eos_token]

    @property
    def bos_id(self):
        return self._special_encoder[self.bos_token]

    @property
    def unk_id(self):
        return self._special_encoder[self.unk_token]

    def __len__(self):
        return len(self.encoder) + len(self._special_encoder)

    def tokenize(self, text: str) -> List[str]:
        output_tokens: List[str] = []
        st = 0
        while st < len(text):
            piece = self.get_piece(text[st:])
            output_tokens.append(piece)
            st += len(piece)
        return output_tokens

    @staticmethod
    def escape(text: str) -> str:
        return text

    @staticmethod
    def unescape(text: str) -> str:
        return text

    def encode(self, text: str, with_bos = True) -> List[int]:
        ret = []
        if with_bos:
            ret.append(self.bos_id)
        for x in self.tokenize(text):
            if x in self.encoder:
                ret.append(self.encoder[x])
            else:
                ret.extend(self._encode_unicode(x))
        return ret

    def decode(self, tokens: List[int]):
        """Decode ids into a string."""
        ret = []
        st = 0
        while st < len(tokens):
            if tokens[st] in self.decoder:
                ret.append(self.decoder[tokens[st]])
                st += 1
            elif tokens[st] in self._byte_decoder:
                first = self._byte_decoder[tokens[st]]
                length = 1 if first < 128 else len(re.search('^1+0', bin(first)[2:])[0])-1
                code = 0
                try:
                    for j in range(length):
                        code = code << 8 | self._byte_decoder[tokens[st + j]]
                    code = int.to_bytes(code, length, "big").decode("utf-8")
                    ret.append(code)
                except:
                    pass
                st = st + length
            elif tokens[st] == self.eos_id:
                ret.append(self.eos_token)
                st += 1
            elif tokens[st] == self.bos_id:
                ret.append(self.bos_token)
                st += 1
            else:
                ret.append(self.unk_token)
                st += 1
        return "".join(ret)

    def _encode_unicode(self, token):
        # wrap unicode encoding into a helper function
        ids = []
        utf8_id = token.encode("utf-8")
        for _id in utf8_id:
            ids.append(self._special_encoder[self.byte_list[_id]])
        return ids

    def next_token(self, text):
        # fast next token matching
        token, token_id = self.tencoder.longest_prefix_item(text, (None, None))
        if token is None:
            token = text[0]
            token_ids = self._encode_unicode(token)
        else:
            token_ids = [token_id]
        return token, token_ids

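A small usage sketch of the tokenizer defined above (illustrative only; the vocabulary path is a placeholder, and the file is expected to hold one JSON-encoded token per line). Characters with no whole-word entry fall back to the `<0xYY>` byte tokens, and `decode` reassembles those bytes into UTF-8:

```python
# Hypothetical usage; "vocabs.txt" is a placeholder path.
tok = CPM9GTokenizer("vocabs.txt")
ids = tok.encode("def gcd(a, b):")   # bos id first, then word/byte-fallback ids
assert ids[0] == tok.bos_id          # encode() prepends <s> by default
print(tok.decode(ids[1:]))           # round-trips back to "def gcd(a, b):"
```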

def start_worker(
    world_size: int, rank: int, local_rank: int, model: onnx.ModelProto, query
):
    model = onnx.load(ONNX_MODEL_PATH)
    runtime = backend.CudaRuntime(local_rank)
    if args.nproc_per_node > 1:
        runtime.init_comm(
            "9g",
            world_size,
            rank,
        )
        print("[{}] comm init.".format(rank))

    stub = OnnxStub(model, runtime)
    print("[{}] stub init.".format(rank))

    for i in range(10):
        if args.no_cudagraph:
            stub.run()
        else:
            stub.run_with_cudagraph()
    print("[{}] stub warmup.".format(rank))

    tokenizer = CPM9GTokenizer("/data1/shared/9G-Infer/models/11B-Chat-QY-epoch-8/vocabs.txt")
    query = tokenizer.encode(query)

    output_tokens = []
    for i in range(len(query)):
        q = np.array(query[i])
        (list(stub.inputs.items()))[0][1].copyin_int64(q.reshape(-1).tolist())
        pos = i * np.ones((args.batchsize, 1), dtype=np.int64)
        (list(stub.inputs.items()))[1][1].copyin_int64(pos.reshape(-1).tolist())

        if args.no_cudagraph:
            stub.run()
        else:
            stub.run_with_cudagraph()

        if i == len(query) - 1:
            output = np.array((list(stub.outputs.items()))[-1][1].copyout_float16()) if False \
                else np.array((list(stub.outputs.items()))[-1][1].copyout_float())
            q = np.argmax(output)
            output_tokens.append(q)

    avg_time = 0
    count = 0
    while i < 1000:
        count = count + 1
        torch.cuda.synchronize()
        with nvtx.annotate("gen {}-th token".format(i), color="red"):
            i = i + 1
            (list(stub.inputs.items()))[0][1].copyin_int64(q.reshape(-1).tolist())
            pos = i * np.ones((args.batchsize, 1), dtype=np.int64)
            (list(stub.inputs.items()))[1][1].copyin_int64(pos.reshape(-1).tolist())

            t0 = time.time()
            if args.no_cudagraph:
                stub.run()
            else:
                stub.run_with_cudagraph()
            t1 = time.time()
            avg_time += t1 - t0

            output = np.array((list(stub.outputs.items()))[-1][1].copyout_float16()) if False \
                else np.array((list(stub.outputs.items()))[-1][1].copyout_float())

            # print(output)

        with nvtx.annotate("argmax".format(i), color="green"):
            q = np.argmax(output)
        if q == 2:
            break

        output_tokens.append(q)
    avg_time = avg_time / count
    print("avg_time_cost =", avg_time*1000, "ms")
    text = tokenizer.decode(output_tokens)
    return text


if __name__ == "__main__":
    comm = MPI.COMM_WORLD
    args.rank = comm.Get_rank()
    args.nproc_per_node = comm.Get_size()
    world_size = args.num_nodes * args.nproc_per_node

    if not os.path.exists(ONNX_MODEL_PATH):
        print("exporting onnx graph")
        generate_onnx(ONNX_MODEL_PATH)
    else:
        print("will use existing onnx graph")
    onnx_model = onnx.load(ONNX_MODEL_PATH)
    print("data loaded")

    #query = '''Beijing is the captial'''
    #query = '''什么是PTX?'''
    #query = '''生病了怎么办?'''
    #query = '''Happy'''
    query = '''def gcd(a, b):'''

    ####################################
    # infinitensor dist
    ####################################
    # run distributed parallel.
    pred = start_worker(world_size, args.rank, args.rank %
                        args.nproc_per_node, onnx_model, query)
    if args.rank == 0:
        print("输入:\n\n", query, "\n")
        print("输出:", pred)
@ -0,0 +1,491 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
from tqdm import tqdm
import argparse
import torch
import onnx
import onnx_graphsurgeon as gs
import os
import numpy as np
from pyinfinitensor.onnx import OnnxStub, backend
import time
import nvtx
from mpi4py import MPI

parser = argparse.ArgumentParser(description='')
parser.add_argument('--batchsize', dest='batchsize', type=int, default=1)
parser.add_argument('--layer', dest='n_layers', type=int, default=32)
parser.add_argument("--num_nodes", dest='num_nodes',
                    type=int, default=1, help="number of nodes")
parser.add_argument("--nproc_per_node", dest="nproc_per_node",
                    type=int, default=1, help="number of processes per node")
parser.add_argument("--world_size", dest="world_size",
                    type=int, default=1, help="")
parser.add_argument("--n_max_length", dest="n_max_length",
                    type=int, default=1024, help="")
parser.add_argument("--vocab_size", dest="vocab_size",
                    type=int, default=32000, help="vocabulary size")
parser.add_argument("--hidden_size", dest="hidden_size",
                    type=int, default=4096)
parser.add_argument("--head_size", dest="head_size",
                    type=int, default=32)
parser.add_argument("--head_dim", dest="head_dim",
                    type=int, default=128)
parser.add_argument('--rank', dest='rank', type=int, default=0)
parser.add_argument('--no_cudagraph', action='store_true')
parser.add_argument('--fp16', action='store_true')
parser.add_argument('--is_1st_graph', action='store_true')
parser.add_argument('--speedup', action='store_true')
args = parser.parse_args()

comm = MPI.COMM_WORLD
args.rank = comm.Get_rank()
args.nproc_per_node = comm.Get_size()
args.world_size = args.num_nodes * args.nproc_per_node

PRETRAINED_LLAMA_PATH = "/data0/shared/data/public/opensource_models/meta-llama/Llama-2-7b-hf/"
ONNX_MODEL_PATH = "/data3/shared/xnsong/llama2/" + ("1st" if args.is_1st_graph else "2nd")
ONNX_MODEL_ORIGIN_PATH = ONNX_MODEL_PATH + "/origin/llama2_origin_bs{}_layer{}.onnx".format(
    args.batchsize, args.n_layers)
ONNX_MODEL_SIM_PATH = ONNX_MODEL_PATH + "/sim/llama2_sim_bs{}_layer{}.onnx".format(
    args.batchsize, args.n_layers)
ONNX_MODEL_FUSION_PATH = ONNX_MODEL_PATH + "/fusion/llama2_fusion_bs{}_layer{}.onnx".format(
    args.batchsize, args.n_layers)
ONNX_MODEL_SPECIAL_PATH = ONNX_MODEL_PATH + "/special/llama2_special_bs{}_layer{}.onnx".format(
    args.batchsize, args.n_layers)
ONNX_MODEL_FP16_PATH = ONNX_MODEL_PATH + "/fp16/llama2_fp16_bs{}_layer{}.onnx".format(
    args.batchsize, args.n_layers)
ONNX_MODEL_DIST_PATH = ONNX_MODEL_PATH + "/dist/llama2_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.onnx".format(
    args.batchsize, args.n_layers, 16 if args.fp16 else 32, args.world_size, args.rank)

def parallel_model(onnx_model, world_size, rank):
    graph = gs.import_onnx(onnx_model)
    tmap = graph.tensors()

    for i in range(args.n_layers):
        tmap[graph.inputs[2+i*2].name].shape[1] = tmap[graph.inputs[2+i*2].name].shape[1]//world_size
        tmap[graph.inputs[3+i*2].name].shape[1] = tmap[graph.inputs[3+i*2].name].shape[1]//world_size
        for node in graph.nodes:
            if node.name == "/model/layers." + str(i) + "/self_attn/q_proj/MatMul":
                node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
            elif node.name == "/model/layers." + str(i) + "/self_attn/k_proj/MatMul":
                node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
            elif node.name == "/model/layers." + str(i) + "/self_attn/v_proj/MatMul":
                node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
            elif node.name == "/model/layers." + str(i) + "/self_attn/o_proj/MatMul":
                node.inputs[1].values = np.vsplit(node.inputs[1].values, world_size)[rank]
                reduce_sum_output = gs.Variable("reduce_sum_output_" + str(i) + "_0",
                    dtype=np.float32)
                reduce_sum = gs.Node(op="ReduceSum", name="reduce_sum_"+str(i)+"_0",
                    inputs=node.outputs, outputs=[reduce_sum_output],
                    attrs={"noop_with_empty_axes":1, "communicator":0})
                graph.nodes.append(reduce_sum)
                next_node = node.outputs[0].outputs[0]
                next_node.inputs[1] = reduce_sum_output
            elif node.name == "/model/layers." + str(i) + "/self_attn/Reshape_0" or \
                 node.name == "/model/layers." + str(i) + "/self_attn/Reshape_1":
                node.inputs[1].values = np.array(
                    [1, 1,
                     args.head_size//world_size,
                     args.hidden_size//args.head_size])
            elif node.name == "/model/layers." + str(i) + "/self_attn/Reshape_2":
                node.inputs[1] = gs.Constant(name="/model/layers."+str(i)+"/self_attn/vreshape_input",
                    values=np.array(
                        [1, 1,
                         args.head_size//world_size,
                         args.hidden_size//args.head_size]))
            elif node.name == "/model/layers." + str(i) + "/self_attn/Reshape_3":
                node.inputs[1] = gs.Constant(name="/model/layers." + str(i) + "/self_attn/Reshape_3_shape",
                    values=np.array(
                        [1, 1, args.hidden_size//world_size]))

            elif node.name == "/model/layers." + str(i) + "/mlp/up_proj/MatMul":
                node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
            elif node.name == "/model/layers." + str(i) + "/mlp/gate_proj/MatMul":
                node.inputs[1].values = np.hsplit(node.inputs[1].values, world_size)[rank]
            elif node.name == "/model/layers." + str(i) + "/mlp/down_proj/MatMul":
                node.inputs[1].values = np.vsplit(node.inputs[1].values, world_size)[rank]
                reduce_sum_output_1 = gs.Variable("reduce_sum_output_" + str(i) + "_1",
                    dtype=np.float32)
                reduce_sum_1 = gs.Node(op="ReduceSum", inputs=node.outputs, outputs=[reduce_sum_output_1],
                    attrs={"noop_with_empty_axes":1, "communicator":0})
                graph.nodes.append(reduce_sum_1)
                next_node = node.outputs[0].outputs[0]
                next_node.inputs[1] = reduce_sum_output_1

    # new_out_1 = tmap["/model/layers.0/mlp/down_proj/MatMul_output_0"] #reduce_sum_output
    # new_out_1.dtype = np.float32
    # new_out_1.shape = [1,1,4096]
    # graph.outputs.append(new_out_1)
    graph.cleanup(True).toposort()
    return gs.export_onnx(graph)


def simplify(onnx_model):
    graph = gs.import_onnx(onnx_model)
    for node in graph.nodes:
        if node.op == "Cast":
            inp_node = node.i()
            inp_node.outputs = node.outputs
            node.outputs.clear()

    for i in range(args.n_layers):
        nodename = "/model/layers." + str(i) + "/self_attn/Add_2"
        node = [node for node in graph.nodes if node.name == nodename][0]
        inp_node = node.i()
        inp_node.outputs = node.outputs
        node.outputs.clear()

    graph.cleanup().toposort()
    return gs.export_onnx(graph)

@gs.Graph.register()
def replace_with_RMSNorm(self, inputs, outputs):
    inputs[0].outputs.pop(0)
    inputs[0].outputs.pop(0)

    for out in outputs:
        out.inputs.clear()
    return self.layer(op="RMSNorm", inputs=inputs, outputs=outputs, name="rmsnorm")


@gs.Graph.register()
def replace_with_silu(self, inputs, outputs):
    for inp in inputs:
        inp.outputs.clear()
    for out in outputs:
        out.inputs.clear()
    return self.layer(op="Silu", inputs=inputs, outputs=outputs, name="silu")


@gs.Graph.register()
def replace_with_RoPE(self, a, b):
    return self.layer(op="RoPE", inputs=a, outputs=b, name="rope")


@gs.Graph.register()
def replace_with_attention(self, inputs, outputs, inputs_added, outputs_removed):
    for inp in inputs:
        inp.outputs.clear()
    for out in outputs:
        out.inputs.clear()
    for inp in inputs_added:
        inputs.append(inp)
    for out in outputs_removed:
        out.inputs.clear()
    return self.layer(op="AttentionKVCache", inputs=inputs, outputs=outputs, name="attention")


def fusion(model):
    graph = gs.import_onnx(model)
    tmap = graph.tensors()

    tmap["onnx::Reshape_1"].outputs.clear()

    inputs = [tmap["/model/layers.0/input_layernorm/Cast_output_0"], tmap["model.layers.0.input_layernorm.weight"]]
    rmsnorm_outputs = [tmap["/model/layers.0/input_layernorm/Mul_1_output_0"]]
    graph.replace_with_RMSNorm(inputs, rmsnorm_outputs)

    for i in range(args.n_layers):
        # rotary embedding op
        tmap["/model/layers." + str(i) + "/self_attn/Add_output_0"].inputs.clear()
        tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"].inputs.clear()
        attn_qreshape_input = gs.Constant(name="/model/layers." + str(i) + "/self_attn/qreshape_input",
            values=np.array([1,1,args.head_size,args.hidden_size//args.head_size]))
        attn_kreshape_input = gs.Constant(name="/model/layers." + str(i) + "/self_attn/kreshape_input",
            values=np.array([1,1,args.head_size,args.hidden_size//args.head_size]))
        attn_qrope_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/qrope_output")
        attn_krope_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/krope_output")
        attn_qreshape_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/qreshape_output")
        attn_kreshape_output = gs.Variable(name="/model/layers." + str(i) + "/self_attn/kreshape_output")

        attn_qreshape = gs.Node(op="Reshape", name = "/model/layers." + str(i) + "/self_attn/Reshape_0", inputs=[attn_qrope_output, attn_qreshape_input], outputs=[attn_qreshape_output])
        attn_kreshape = gs.Node(op="Reshape", name = "/model/layers." + str(i) + "/self_attn/Reshape_1", inputs=[attn_krope_output, attn_kreshape_input], outputs=[attn_kreshape_output])
        attn_qtrans = gs.Node(op="Transpose", attrs={"perm":np.array([0,2,1,3])}, inputs=[attn_qreshape_output],
            outputs=[tmap["/model/layers." + str(i) + "/self_attn/Add_output_0"]])
        attn_ktrans = gs.Node(op="Transpose", attrs={"perm":np.array([0,2,1,3])}, inputs=[attn_kreshape_output],
            outputs=[tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"]])

        graph.nodes.append(attn_qreshape)
        graph.nodes.append(attn_kreshape)
        graph.nodes.append(attn_qtrans)
        graph.nodes.append(attn_ktrans)
        inputs = [tmap["onnx::Reshape_1"], tmap["/model/layers." + str(i) + "/self_attn/q_proj/MatMul_output_0"]]
        graph.replace_with_RoPE(inputs, [attn_qrope_output])
        inputs = [tmap["onnx::Reshape_1"], tmap["/model/layers." + str(i) + "/self_attn/k_proj/MatMul_output_0"]]
        graph.replace_with_RoPE(inputs, [attn_krope_output])

        # rms-norm op
        inputs = [tmap["/model/layers." + str(i) + "/post_attention_layernorm/Cast_output_0"], \
            tmap["model.layers." + str(i) + ".post_attention_layernorm.weight"]]
        outputs = [tmap["/model/layers." + str(i) + "/post_attention_layernorm/Mul_1_output_0"]]
        graph.replace_with_RMSNorm(inputs, outputs)
        inputs = [tmap["/model/layers." + str(i+1) + "/input_layernorm/Cast_output_0"] if i != args.n_layers-1 else \
            tmap["/model/norm/Cast_output_0"], \
            tmap["model.layers." + str(i+1) + ".input_layernorm.weight"] if i != args.n_layers-1 else \
            tmap["model.norm.weight"]]
        outputs = [tmap["/model/layers."+ str(i+1) + "/input_layernorm/Mul_1_output_0"]] if i != args.n_layers-1 else \
            [tmap["/model/norm/Mul_1_output_0"]]
        graph.replace_with_RMSNorm(inputs, outputs)

        # silu op
        inputs = [tmap["/model/layers." + str(i) + "/mlp/gate_proj/MatMul_output_0"]]
        outputs = [tmap["/model/layers." + str(i) + "/mlp/act_fn/Mul_output_0"]]
        graph.replace_with_silu(inputs, outputs)

        inputs = [
            tmap["onnx::Concat_" + str((i+1)*2)],
            tmap["onnx::Concat_" + str((i+1)*2+1)],
            tmap["/model/layers." + str(i) + "/self_attn/Add_output_0"],
            tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"],
            tmap["/model/layers." + str(i) + "/self_attn/Transpose_2_output_0"]]
        outputs = [
            tmap["/model/layers." + str(i) + "/self_attn/MatMul_1_output_0"],]

        inputs_added = [graph.inputs[1]]
        outputs_removed = []
        graph.replace_with_attention(
            inputs, outputs, inputs_added, outputs_removed)

    graph.outputs = [tmap[graph.outputs[0].name]]
    graph.cleanup(True).toposort()

    return gs.export_onnx(graph)

def special_pass(model):
    graph = gs.import_onnx(model)
    tmap = graph.tensors()
    for node in graph.nodes:
        if node.op == "Transpose" or node.op == "Reshape":
            inp_node = node.i()
            inp_node.outputs = node.outputs
            node.outputs.clear()
    graph.cleanup(True).toposort()
    return gs.export_onnx(graph)


def convert_to_fp16(model):
    graph = gs.import_onnx(model)

    for node in graph.nodes:
        if node.op == "Gather" and node.name == "/model/embed_tokens/Gather":
            node.inputs[0].values = np.float16(node.inputs[0].values)

        if node.op == "RMSNorm":
            node.inputs[1].values = np.float16(node.inputs[1].values)

        if node.op == "MatMul":
            node.inputs[1].values = np.float16(node.inputs[1].values)
            if node.name == "/lm_head/MatMul":
                cast_1_out = gs.Variable(node.name+"_cast_out_output_0", dtype=np.float32, shape=node.outputs[0].shape)
                cast_1 = gs.Node(op="Cast", inputs=[node.outputs[0]], outputs=[cast_1_out])
                cast_1.attrs["to"] = np.float32
                cast_1.name = node.name+"_cast_out_0"
                graph.nodes.append(cast_1)
                graph.outputs[0] = cast_1_out
                node.outputs[0].dtype = np.float16

    graph.cleanup(True).toposort()
    return gs.export_onnx(graph)

def export_onnx(model: AutoModelForCausalLM):
    if not os.path.exists(ONNX_MODEL_ORIGIN_PATH):
        print("exporting origin onnx model...")
        with torch.no_grad():
            param = torch.zeros(
                (args.batchsize, model.config.max_position_embeddings-1), dtype=torch.long)
            logits = model(param, past_key_values=None)

            if not args.is_1st_graph:
                param_kvcache = torch.zeros((args.batchsize, 1), dtype=torch.long)
                torch.onnx.export(model, (param_kvcache, {"past_key_values": logits.past_key_values,
                    "position_ids": param_kvcache}), \
                    ONNX_MODEL_ORIGIN_PATH, verbose=False,
                    do_constant_folding=True,)
            else:
                position_ids = torch.tile(torch.arange(0, model.config.max_position_embeddings-1), (args.batchsize, 1))
                attention_mask = torch.ones((args.batchsize, model.config.max_position_embeddings-1), dtype=torch.bool)
                torch.onnx.export(model, (param, {"attention_mask": attention_mask,
                    "position_ids": position_ids}),\
                    ONNX_MODEL_ORIGIN_PATH, verbose=False,
                    do_constant_folding=True,)
        print("export origin onnx finished.")

    if not args.is_1st_graph and not os.path.exists(ONNX_MODEL_SIM_PATH):
        print("exporting sim onnx model...")
        onnx_model = onnx.load(ONNX_MODEL_ORIGIN_PATH)
        onnx_model = simplify(onnx_model)
        onnx.save(onnx_model, ONNX_MODEL_SIM_PATH, save_as_external_data=True, \
            location="llama2_sim_bs{}_layer{}.pb".format(args.batchsize, args.n_layers))
        print("exporting sim onnx model finished.")

    if not args.is_1st_graph and not os.path.exists(ONNX_MODEL_FUSION_PATH):
        print("exporting fusion onnx model...")
        onnx_model = onnx.load(ONNX_MODEL_SIM_PATH)
        onnx_model = fusion(onnx_model)
        onnx.save(onnx_model, ONNX_MODEL_FUSION_PATH, save_as_external_data=True, \
            location="llama2_fusion_bs{}_layer{}.pb".format(args.batchsize, args.n_layers))
        print("exporting fusion onnx model finished.")

    if not args.is_1st_graph and not os.path.exists(ONNX_MODEL_SPECIAL_PATH):
        print("exporting special onnx model...")
        onnx_model = onnx.load(ONNX_MODEL_FUSION_PATH)
        onnx_model = special_pass(onnx_model)
        onnx.save(onnx_model, ONNX_MODEL_SPECIAL_PATH, save_as_external_data=True, \
            location="llama2_special_bs{}_layer{}.pb".format(args.batchsize, args.n_layers))
        print("exporting special onnx model finished.")

    if not args.is_1st_graph and args.fp16 and not os.path.exists(ONNX_MODEL_FP16_PATH):
        print("exporting fp16 onnx model...")
        onnx_model = onnx.load(ONNX_MODEL_SPECIAL_PATH)
        onnx_model = convert_to_fp16(onnx_model)
        onnx.save(onnx_model, ONNX_MODEL_FP16_PATH, save_as_external_data=True, \
            location="llama2_fp16_bs{}_layer{}.pb".format(args.batchsize, args.n_layers))
        print("exporting fp16 onnx model finished.")

    print("world_size =", args.world_size)
    if not args.is_1st_graph and args.world_size > 1 and not os.path.exists(ONNX_MODEL_DIST_PATH):
        print("exporting dist onnx model...")
        onnx_model = onnx.load(ONNX_MODEL_FP16_PATH) if args.fp16 else onnx.load(ONNX_MODEL_SPECIAL_PATH)
        onnx_model = parallel_model(onnx_model, args.world_size, args.rank)
        onnx.save(onnx_model, ONNX_MODEL_DIST_PATH, save_as_external_data=True, \
            location="llama2_dist_bs{}_layer{}_fp{}_worldsize{}_rank{}.pb".format(
                args.batchsize, args.n_layers,
                16 if args.fp16 else 32, args.world_size, args.rank))
        print("exporting dist onnx model finished.")

def get_it_logit(onnx_model, input_ids):
    # initialization
    runtime = backend.CudaRuntime(args.rank)
    runtime.init_comm(
        "dist",
        args.world_size,
        args.rank,
    )
    print("[{}] comm init.".format(args.rank))
    stub = OnnxStub(onnx_model, runtime)
    print("[{}] stub init.".format(args.rank))

    # warm up
    for i in range(10):
        if args.no_cudagraph:
            stub.run()
        else:
            stub.run_with_cudagraph()
    print("[{}] stub warmup.".format(args.rank))

    logits = np.zeros((args.batchsize, args.n_max_length, args.vocab_size), dtype=np.float32)
    output_ids = np.zeros((args.batchsize, args.n_max_length), dtype=np.int64)
    avg_inference_time = 0
    t0 = time.time()
    for i in tqdm(range(0, args.n_max_length)):
        with nvtx.annotate("seq_length = {}".format(i), color="red"):
            assert input_ids.shape[0] == args.batchsize
            input_id = input_ids[:, i] if i < input_ids.shape[1] else output_ids[:, i-1]
            position_id = i*np.ones((args.batchsize, 1), dtype=np.int32)

            # copyin input
            with nvtx.annotate("[it] copyin", color="blue"):
                (list(stub.inputs.items()))[0][1].copyin_int64(
                    input_id.reshape(-1).tolist())
                (list(stub.inputs.items()))[1][1].copyin_int64(
                    position_id.reshape(-1).tolist())

            # run
            t10 = time.time()
            with nvtx.annotate("[it] run", color="green"):
                if args.no_cudagraph:
                    stub.run()
                else:
                    stub.run_with_cudagraph()
            t11 = time.time()
            avg_inference_time += (t11 - t10)

            # copyout output
            if not args.speedup:
                with nvtx.annotate("[it] copyout", color="blue"):
                    logits[:,i, :] = np.array((list(stub.outputs.items()))[0][1].copyout_float()).reshape(args.batchsize, -1)
                    output_ids[:, i] = np.argmax(logits[:, i, :], -1).astype(np.int64)

    t1 = time.time()
    if args.rank == 0:
        result = "[it] e2e: {} gpus, {} layers, e2e time: {:.2f}s, average inference time: {:.2f}ms"\
            .format(args.num_nodes * args.nproc_per_node, args.n_layers, t1-t0, \
                avg_inference_time*1000/args.n_max_length)
        print(result)
    del stub
    return output_ids

if __name__ == "__main__":
    torch_model = LlamaForCausalLM.from_pretrained(
        PRETRAINED_LLAMA_PATH, num_hidden_layers=int(args.n_layers)).eval()
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_LLAMA_PATH)
    #prompt = "Hey, are you conscious? Can you talk to me?"
    #prompt = "What is PTX?"
    #prompt = "Tell me a joke."
    #prompt = "What are the key principles of smart investing?"
    prompt = "What is DeepSpeed?"
    prompts=[prompt]*args.batchsize
    inputs = tokenizer(prompts, return_tensors="pt")

    input_ids = inputs.input_ids
    print("prompt ids =", input_ids)

    ##########################################################
    # inference with InfiniTensor
    ##########################################################
    print("exporting onnx...")
    export_onnx(torch_model)
    print("exporting onnx finished.")

    onnx_to_run_path = ONNX_MODEL_DIST_PATH if args.world_size > 1 else \
        (ONNX_MODEL_FP16_PATH if args.fp16 else ONNX_MODEL_SPECIAL_PATH)
    print("loading onnx", onnx_to_run_path, "...")
    onnx_model = onnx.load(onnx_to_run_path)
    print("loading onnx finished.")
    output_ids_it = get_it_logit(onnx_model, input_ids)
    it_output_text = tokenizer.batch_decode(output_ids_it[:, input_ids.shape[-1]:output_ids_it.shape[-1]])
    if args.rank == 0:
        for i in range(args.batchsize):
            print("prompt: ", prompts[i])
            print("answer: [it]", it_output_text[i])

    ##########################################################
    # validation with pytorch
    ##########################################################
    """
    generate_ids = torch_model.generate(inputs.input_ids, max_length=args.n_max_length)#, num_beams=4, do_sample=True)
    outputs = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    """
    if not args.speedup and not args.is_1st_graph:
        kvcache_torch = None
        output_ids_pt = torch.zeros(args.batchsize, args.n_max_length).int() # + input_ids.shape[-1] - 1).int()
        if args.fp16:
            torch_model = torch_model.half()

        torch_model = torch_model.cuda()
        # print(torch.cuda.memory_summary())

        avg_inference_time = 0
        with torch.no_grad():
            t0 = time.time()
            for i in range(args.n_max_length):
                input_id = input_ids[:,i] if i < input_ids.shape[1] else out_token
                input_id = input_id.view(args.batchsize,1).cuda()
                t00 = time.time()
                outputs = torch_model(input_id, past_key_values=kvcache_torch)
                t01 = time.time()
                avg_inference_time += (t01-t00)

                logits = outputs['logits']
                kvcache_torch = outputs['past_key_values']
                out_token = torch.argmax(logits, dim=-1)
                output_ids_pt[:, i:i+1] = out_token
            t1 = time.time()
            avg_inference_time /= args.n_max_length
            result = "[pt] e2e time: {:.2f}s, average inference time: {:.2f}ms"\
                .format(t1-t0, avg_inference_time*1000)

        if args.rank == 0:
            print(result)
            pt_output_text = tokenizer.batch_decode(output_ids_pt[:,input_ids.shape[-1]:args.n_max_length])
            for i in range(args.batchsize):
                print("[pt]", args.rank, pt_output_text[i])

        if not args.is_1st_graph:
            assert(output_ids_it.shape[-1] == args.n_max_length)
            np.testing.assert_equal(output_ids_pt[:, input_ids.shape[-1]:args.n_max_length], output_ids_it[:,input_ids.shape[-1]:args.n_max_length])
@ -3,14 +3,17 @@
 #include <cstdio>

 struct AttentionKVCacheMetadata {
-    int dimSize[4];
-    int stride[4];
+    int head_dim;
+    int num_heads;
+    int num_seqs;
+    int max_kv_seqlen;
 };

 namespace infini {
-void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
-                              float *input_q, float *input_k, float *input_v,
-                              int *position_id, float *output_matmul,
+void attention_kvcache_kernel(int dType, void *input_k_cache,
+                              void *input_v_cache, void *input_q, void *input_k,
+                              void *input_v, int64_t *position_id,
+                              void *output_matmul,
                               const AttentionKVCacheMetadata &compMeta,
                               float *output_O_temp, float *output_sum_temp);
@ -5,8 +5,7 @@
 namespace infini {

-void rope_kernel(int dType, int *pos, void *input, void *output, int size,
-                 int dim_model, int dim_head, int hidden_stride,
-                 int pos_stride);
+void rope_kernel(int dType, int64_t *pos, void *input, void *output,
+                 int dim_model, int dim_head, int batchsize, int pos_stride);

 }; // namespace infini
@ -21,7 +21,7 @@ class KUNLUNRuntimeObj : public RuntimeObj {
         ctx = xdnn::create_context();
         // 10GB for Longformer
         // size_t longformerNum = 3lu * (1 << 30);
-        size_t workspaceSize = 2llu << 30; // 2 GB
+        size_t workspaceSize = 3llu << 30; // 3 GB
         KUNLUNPtr wkspacePtr = alloc(workspaceSize);
         workspace =
             make_ref<WorkspaceObj<KUNLUNPtr>>(wkspacePtr, workspaceSize);
@ -42,7 +42,7 @@ class KUNLUNRuntimeObj : public RuntimeObj {
     KUNLUNPtr alloc(size_t size) override {
         void *ptr;
         checkKUNLUNError(
-            xpu_malloc((void **)&ptr, size, XPUMemoryKind::XPU_MEM_HBM));
+            xpu_malloc_ex((void **)&ptr, size, XPUMemoryKind::XPU_MEM_MAIN));
         return ptr;
     }
     void dealloc(void *ptr) override { xpu_free(ptr); }
@ -34,8 +34,8 @@ class XcclCommunicatorObj final : public CommunicatorObj {
         auto begin = std::chrono::steady_clock::now();
         while (!std::filesystem::exists(filePath)) {
             auto now = std::chrono::steady_clock::now();
-            _IT_ASSERT_2(now < begin + std::chrono::seconds(100),
-                         "time limit (100s) exceeded.");
+            _IT_ASSERT_2(now < begin + std::chrono::seconds(10),
+                         "time limit (10s) exceeded.");
             std::this_thread::sleep_for(std::chrono::milliseconds(100));
         }
         std::ifstream ifs(filePath, std::ios::binary);
@ -3,8 +3,7 @@
 namespace infini {
 /**
- * @brief Fused Attention with KVCache input operator. All the input and output
- * tensors should have the same rank except for the position_id.
+ * @brief Fused Attention with KVCache input operator.
  *
  */
 class AttentionKVCacheObj : public OperatorObj {
@ -16,12 +15,19 @@ class AttentionKVCacheObj : public OperatorObj {
      *
      * @param graph The computation graph that this operator belongs to.
      * @param input_k_cache The k_cache input tensor.
+     * Shape: [batchsize, num_heads, k_cache_seq_length, head_dim]
      * @param input_v_cache The v_cache input tensor.
+     * Shape: [batchsize, num_heads, v_cache_seq_length, head_dim]
      * @param input_q The query input tensor.
+     * Shape: [batchsize, q_seq_length, model_dim]
      * @param input_k The key input tensor.
+     * Shape: [batchsize, q_seq_length, model_dim]
      * @param input_v The value input tensor.
-     * @param position_id The positon id of the query,
+     * Shape: [batchsize, q_seq_length, model_dim]
+     * @param position_id The positon id of the query.
+     * Shape: [batchsize, q_seq_length]
      * @param output_matmul The query output tensor.
+     * Shape: [batchsize, q_seq_length, model_dim]
      */
     AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
                         Tensor input_v_cache, Tensor input_q, Tensor input_k,
@ -30,6 +36,10 @@ class AttentionKVCacheObj : public OperatorObj {
     OP_CLONE(AttentionKVCacheObj);

     optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
+    vector<DataType> inferDataType(const TensorVec &inputs) const override {
+        return {inputs[2]->getDType()};
+    };
+    DataType getDType() const { return getInputs(2)->getDType(); }

     std::string toString() const override;
     int numInputs() const override { return 6; }
@ -21,6 +21,10 @@ class RoPEObj : public OperatorObj {
     int numOutputs() const override { return 1; }
     DataType getDType() const { return getInputs(1)->getDType(); }

+    vector<DataType> inferDataType(const TensorVec &inputs) const override {
+        return {inputs[1]->getDType()};
+    };
+
   private:
     vector<int> getWorkloadVector() const override;
     vector<int> getOpAttrVector() const override;
@ -208,16 +208,39 @@ class OnnxStub:
                     op[1],
                 )
             elif node.op_type == "MatMul":
-                tensors[node.output[0]] = self.handler.matmul(
-                    tensors[node.input[0]],  # input
-                    tensors[node.input[1]],  # weight
-                    tensors.get(node.output[0]),
-                    False,
-                    False,
-                    None,
-                    backend.ActType.Linear,
-                    matmul_compute_type,
-                )
+                if node.input[1] in data.keys() \
+                        and to_array(data[node.input[1]]).dtype == np.float32 \
+                        and 'cuda_runtime' in dir(backend) \
+                        and tensors[node.input[0]].shape()[0] == 1 \
+                        and tensors[node.input[0]].shape()[1] == 1 \
+                        and len(tensors[node.input[1]].shape()) == 2 \
+                        and node.input[1] in data.keys():
+                    data[node.input[1]] = from_array(
+                        np.transpose(to_array(data[node.input[1]])))
+                    tensors[node.input[1]] = self.handler.tensor(
+                        [tensors[node.input[1]].shape()[1], tensors[node.input[1]].shape()[0]],
+                        tensors[node.input[1]].dtype())
+                    tensors[node.output[0]] = self.handler.matmul(
+                        tensors[node.input[0]],
+                        tensors[node.input[1]],
+                        tensors.get(node.output[0]),
+                        False,
+                        True,
+                        None,
+                        backend.ActType.Linear,
+                        matmul_compute_type,
+                    )
+                else:
+                    tensors[node.output[0]] = self.handler.matmul(
+                        tensors[node.input[0]],
+                        tensors[node.input[1]],
+                        tensors.get(node.output[0]),
+                        False,
+                        False,
+                        None,
+                        backend.ActType.Linear,
+                        matmul_compute_type,
+                    )
             elif node.op_type == "Gemm":
                 attributes = _parse_attribute(
                     node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors.get(node.output[0]),
|
tensors.get(node.output[0]),
|
||||||
)
|
)
|
||||||
elif node.op_type in ["Constant", "ConstantOfShape"]:
|
elif node.op_type == "Constant":
|
||||||
output_name = node.output[0]
|
output_name = node.output[0]
|
||||||
attributes = _parse_attribute(node)
|
attributes = _parse_attribute(node)
|
||||||
tensor = attributes["value"]
|
tensor = attributes["value"]
|
||||||
|
|
|
@ -199,24 +199,6 @@ class CastCnnl : public BangKernelWithoutConfig {
                 dim.data()));
             NlCastType = CNNL_CAST_UINT32_TO_INT64;
             break;
-        case CastType::Float162Float:
-            checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
-                                                   CNNL_DTYPE_HALF, dim.size(),
-                                                   dim.data()));
-            checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
-                                                   CNNL_DTYPE_FLOAT, dim.size(),
-                                                   dim.data()));
-            NlCastType = CNNL_CAST_HALF_TO_FLOAT;
-            break;
-        case CastType::Float2Float16:
-            checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
-                                                   CNNL_DTYPE_FLOAT, dim.size(),
-                                                   dim.data()));
-            checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
-                                                   CNNL_DTYPE_HALF, dim.size(),
-                                                   dim.data()));
-            NlCastType = CNNL_CAST_FLOAT_TO_HALF;
-            break;
         default:
             IT_TODO_HALT();
         }
@@ -19,16 +19,14 @@ class LayerNormCnnl : public BangKernelWithoutConfig {
         void *const outputData = (op->getOutput()->getRawDataPtr<void *>());

         auto inDims = op->getInputs(0)->getDims();
-        auto fiterDims = op->getInputs(1)->getDims();
         auto outDims = op->getOutput()->getDims();
+        auto fiterDims = op->getOutput(1)->getDims();

         float eps = op->getEps();
         const int axis = op->getAxis();

-        Shape outMeanDims(outDims);
-        outMeanDims.erase(outMeanDims.begin() + axis);
-
-        cnnlTensorDescriptor_t inDesc, fiterDesc, outDesc, outMeanDesc;
+        cnnlTensorDescriptor_t inDesc, fiterDesc, outDesc;
         checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
         checkCnnlError(cnnlSetTensorDescriptor(
             inDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
@@ -41,23 +39,15 @@ class LayerNormCnnl : public BangKernelWithoutConfig {
         checkCnnlError(cnnlSetTensorDescriptor(
             outDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
             outDims.size(), outDims.data()));
-        checkCnnlError(cnnlCreateTensorDescriptor(&outMeanDesc));
-        checkCnnlError(cnnlSetTensorDescriptor(
-            outMeanDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
-            outMeanDims.size(), outMeanDims.data()));
         size_t wsSize;
         cnnlGetLayerNormOpWorkspaceSize(context->cnnlHandle(), axis, inDesc,
                                         &wsSize);
         BangPtr wsData = context->getWorkspace(wsSize);
-        size_t meanSize =
-            cnnlGetTensorElementNum(outMeanDesc) * op->getDType().getSize();
-        BangPtr meanData = context->getWorkspace(meanSize);
-        BangPtr rstdData = context->getWorkspace(meanSize);

         cnnlStatus_t stat = cnnlLayerNormForward(
             context->cnnlHandle(), inDesc, inputData, axis, fiterDesc,
             scaleData, biasData, eps, wsData, wsSize, outDesc, outputData,
-            outMeanDesc, meanData, rstdData);
+            inDesc, NULL, NULL);

         if (stat != CNNL_STATUS_SUCCESS)
             return;
@@ -66,13 +66,6 @@ class MatmulCnnl : public BangKernelWithoutConfig {
         cnnlSetMatMulDescAttr(bmm_desc, CNNL_MATMUL_DESC_TRANSB, &transB,
                               sizeof(int32_t));

-        std::string computeTypeStr = op->getComputeType();
-        if (computeTypeStr == "tf32") {
-            int32_t tf32 = 1;
-            cnnlSetMatMulDescAttr(bmm_desc, CNNL_MATMUL_ALLOW_TF32, &tf32,
-                                  sizeof(int32_t));
-        }
-
         cnnlMatMulAlgo_t bmm_algo;
         cnnlMatMulAlgoCreate(&bmm_algo);
@@ -7,33 +7,37 @@ namespace infini {

 class AttentionKVCacheCompute {
     void initAttentionKVCacheMetadata(AttentionKVCacheMetadata &metadata,
-                                      Tensor tensor) const {
-        int nDims = tensor->getRank();
-        auto strides = tensor->getStride();
+                                      Tensor input_v_cache,
+                                      Tensor position_id) const {
+        int nDims = input_v_cache->getRank();
+        auto strides = input_v_cache->getStride();
         IT_ASSERT(nDims == 4);
-        IT_ASSERT(strides.size() == (size_t)nDims);
-        for (int i = 0; i < nDims; ++i) {
-            metadata.dimSize[i] = tensor->getDims().at(i);
-            metadata.stride[i] = strides.at(i);
+        int dim_position_id = position_id->getRank();
+        metadata.num_seqs = 1;
+        for (int i = 0; i < dim_position_id; i++) {
+            metadata.num_seqs *= position_id->getDims().at(i);
         }
+        metadata.head_dim = input_v_cache->getDims().at(3);
+        metadata.num_heads = input_v_cache->getDims().at(1);
+        metadata.max_kv_seqlen = input_v_cache->getDims().at(2);
     }

   public:
-    void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
-                    Tensor input_k, Tensor input_v, Tensor position_id,
-                    Tensor output_matmul, CudaPtr p_workspace) const {
+    void do_compute(int dType, Tensor input_k_cache, Tensor input_v_cache,
+                    Tensor input_q, Tensor input_k, Tensor input_v,
+                    Tensor position_id, Tensor output_matmul,
+                    CudaPtr p_workspace) const {
         AttentionKVCacheMetadata metadata;
-        initAttentionKVCacheMetadata(metadata, input_v_cache);
+        initAttentionKVCacheMetadata(metadata, input_v_cache, position_id);

-        attention_kvcache_kernel(input_k_cache->getRawDataPtr<float *>(),
-                                 input_v_cache->getRawDataPtr<float *>(),
-                                 input_q->getRawDataPtr<float *>(),
-                                 input_k->getRawDataPtr<float *>(),
-                                 input_v->getRawDataPtr<float *>(),
-                                 position_id->getRawDataPtr<int *>(),
-                                 output_matmul->getRawDataPtr<float *>(),
-                                 metadata, (float *)p_workspace,
-                                 (float *)(p_workspace + (1ll << 30)));
+        attention_kvcache_kernel(
+            dType, input_k_cache->getRawDataPtr<void *>(),
+            input_v_cache->getRawDataPtr<void *>(),
+            input_q->getRawDataPtr<void *>(), input_k->getRawDataPtr<void *>(),
+            input_v->getRawDataPtr<void *>(),
+            position_id->getRawDataPtr<int64_t *>(),
+            output_matmul->getRawDataPtr<void *>(), metadata,
+            (float *)p_workspace, (float *)(p_workspace + (1ll << 30)));
     }
 };
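The reworked metadata above is derived from the 4-D V-cache layout (batch, num_heads, max_kv_seqlen, head_dim), with the number of decoded sequences taken from the flattened position_id tensor. A rough Python rendering of the same bookkeeping (illustrative only; the dataclass name and helper are not from the repository):

```python
import numpy as np
from dataclasses import dataclass

@dataclass
class KVCacheMetadata:                       # mirrors the fields the CUDA kernel reads
    num_seqs: int = 1
    num_heads: int = 0
    max_kv_seqlen: int = 0
    head_dim: int = 0

def init_metadata(input_v_cache: np.ndarray, position_id: np.ndarray) -> KVCacheMetadata:
    assert input_v_cache.ndim == 4           # (num_seqs, num_heads, max_kv_seqlen, head_dim)
    md = KVCacheMetadata()
    md.num_seqs = int(np.prod(position_id.shape))   # one position per decoded sequence
    md.num_heads = input_v_cache.shape[1]
    md.max_kv_seqlen = input_v_cache.shape[2]
    md.head_dim = input_v_cache.shape[3]
    return md
```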
@@ -41,15 +45,17 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute,
                              public CudaKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
-        IT_ASSERT(_op->getDType() == DataType::Float32);
+        auto op = as<AttentionKVCacheObj>(_op);
+        int dType = op->getDType().getIndex();
+        int position_idx_dtype = op->getInputs()[5]->getDTypeIndex();
+        IT_ASSERT(dType == 1 || dType == 10 || position_idx_dtype == 7);

         size_t workspaceSize = 2ll << 30;
         auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
         CudaPtr idxWsData = context->getWorkspace(workspaceSize);
-        do_compute(_op->getInputs()[0], _op->getInputs()[1],
-                   _op->getInputs()[2], _op->getInputs()[3],
-                   _op->getInputs()[4], _op->getInputs()[5],
-                   _op->getOutputs()[0], idxWsData);
+        do_compute(dType, op->getInputs()[0], op->getInputs()[1],
+                   op->getInputs()[2], op->getInputs()[3], op->getInputs()[4],
+                   op->getInputs()[5], op->getOutputs()[0], idxWsData);
     }
 };
@@ -1,171 +1,236 @@
 #include "cuda/cuda_common.h"
 #include "cuda/cuda_attention_kvcache.h"
 #define WARP_SIZE 32
-#define BLOCKSIZE WARP_SIZE
 #define SEQ_UNIT 16
+#define BLOCKSIZE_2 WARP_SIZE*4
+#define MAX_PARTITION 1024

-// ASSUME SEQ_LEN OF Q IS 1
-__global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
-                                              float* input_v_cache,
-                                              float* input_q,
-                                              float* input_k,
-                                              float* input_v,
-                                              int* position_id,
-                                              AttentionKVCacheMetadata compMeta,
-                                              float* output_O_temp,
-                                              float* output_sum_temp) {
-    int seq_length = position_id[0] + 1;
-    int stride = (seq_length + SEQ_UNIT - 1) / SEQ_UNIT;
-    if(blockIdx.y >= stride)
-        return;
-
-    int lane_id = threadIdx.x % WARP_SIZE;
-    int group_id = threadIdx.x / WARP_SIZE;
-    int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
-    int idx_seq = blockIdx.y * SEQ_UNIT;
-
-    if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
-        return;
-
-    float ptr_V[SEQ_UNIT*4]; // V
-    float ptr_K[SEQ_UNIT*4]; // K
-    float ptr_Q[4]; // Q
-    float ptr_P[SEQ_UNIT] = {0};
-
-    float ptr_O[4] = {0};
-    float ptr_sum[1] = {0};
-
-    // readin Q
-    (float4 &)ptr_Q[0] = (float4 &)input_q[(lane_id * 4) + (parallel_idx * 128)];
-    int common_idx = (lane_id * 4) + (parallel_idx * compMeta.stride[1]);
-
-    // Q*K
-    #pragma unroll
-    for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
-        if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
-            (float4 &)ptr_K[idx_SEQ_UNIT * 4]
-                = (float4 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
-        }
-        else{
-            (float4 &)ptr_K[idx_SEQ_UNIT * 4]
-                = (float4 &) input_k[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
-            (float4 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
-                (float4 &)ptr_K[idx_SEQ_UNIT * 4];
-        }
-
-        #pragma unroll
-        for (int i = 0; i < 4; i ++){
-            ptr_K[idx_SEQ_UNIT * 4 + i] = ptr_Q[i] * ptr_K[idx_SEQ_UNIT * 4 + i];
-            #pragma unroll
-            for (int offset = 16; offset > 0; offset /= 2) {
-                ptr_K[idx_SEQ_UNIT * 4 + i] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 4 + i], offset);
-            }
-            ptr_P[idx_SEQ_UNIT] += ptr_K[idx_SEQ_UNIT * 4 + i];
-        }
-    }
-
-    // div sqrt(d)
-    #pragma unroll
-    for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
-        ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
-        ptr_P[idx_SEQ_UNIT] /= sqrt(128.0);
-    }
-
-    // softmax
-    #pragma unroll
-    for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
-        ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT]);
-        ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
-    }
-
-    // * V
-    #pragma unroll
-    for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
-        if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
-            (float4 &)ptr_V[idx_SEQ_UNIT * 4]
-                = (float4 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
-        }
-        else{
-            (float4 &)ptr_V[idx_SEQ_UNIT * 4]
-                = (float4 &) input_v[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
-            (float4 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]
-                = (float4 &)ptr_V[idx_SEQ_UNIT * 4];
-        }
-
-        #pragma unroll
-        for (int i = 0; i < 4; i ++)
-            ptr_O[i] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4 + i)], ptr_O[i]);
-    }
-
-    #pragma unroll
-    for (int i = 0; i < 4; i ++)
-        ptr_O[i] /= ptr_sum[0];
-
-    (float4 &)output_O_temp[(lane_id * 4) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = (float4 &)ptr_O[0];
-    if(lane_id == 0){
-        output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0];
-    }
-}
+template <class T>
+__global__ void _attention_kvcache_kernel_128_1(T* input_k_cache,
+                                              T* input_v_cache,
+                                              T* input_q,
+                                              T* input_k,
+                                              T* input_v,
+                                              int64_t* position_id,
+                                              AttentionKVCacheMetadata compMeta,
+                                              half* output_O_temp,
+                                              float* output_sum_temp) {
+    int seq_length = position_id[blockIdx.y] + 1;
+    int stride = (seq_length + SEQ_UNIT - 1) / SEQ_UNIT;
+    if(blockIdx.z >= stride)
+        return;
+
+    int lane_id_x2 = threadIdx.x % WARP_SIZE * 2;
+    int parallel_idx = blockIdx.x + blockIdx.y * gridDim.x;
+
+    int idx_seq = blockIdx.z * SEQ_UNIT;
+
+    half reg_V[4];
+    half reg_K[4];
+    half reg_Q[4];
+    float reg_P;
+
+    float reg_O[4] = {0};
+    float reg_sum = 0;
+    float temp[4];
+    bool is_fp16 = sizeof(T) == 2 ? true : false;
+
+    int idx_qkv = lane_id_x2 + parallel_idx * compMeta.head_dim;
+
+    // readin Q
+    if(!is_fp16){
+        #pragma unroll
+        for(int i = 0; i < 4; i += 2){
+            (float2 &)temp[i] = (float2 &)input_q[idx_qkv + i*WARP_SIZE];
+            *((half2*)(&reg_Q[i])) = __float22half2_rn(*((float2*)(&temp[i])));
+        }
+    }
+    else{
+        #pragma unroll
+        for(int i = 0; i < 4; i += 2){
+            (half2 &)reg_Q[i] = (half2 &)input_q[idx_qkv + i*WARP_SIZE];
+        }
+    }
+    int common_idx = lane_id_x2 + (parallel_idx * compMeta.max_kv_seqlen * compMeta.head_dim);
+
+    #pragma unroll
+    for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
+        reg_P = 0;
+        int idx_kvcache = common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.head_dim);
+        // readin K & V
+        if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
+            #pragma unroll
+            for(int i = 0; i < 4; i += 2){
+                *((half2*)(&reg_K[i])) = *((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE]));
+                *((half2*)(&reg_V[i])) = *((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE]));
+            }
+        }
+        else{
+            if(!is_fp16){
+                #pragma unroll
+                for(int i = 0; i < 4; i += 2){
+                    (float2 &)temp[i] = (float2 &) input_k[idx_qkv + i*WARP_SIZE];
+                    *((half2*)(&reg_K[i])) = __float22half2_rn(*((float2*)(&temp[i])));
+                    *((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&reg_K[i]));
+                    (float2 &)temp[i] = (float2 &) input_v[idx_qkv + i*WARP_SIZE];
+                    *((half2*)(&reg_V[i])) = __float22half2_rn(*((float2*)(&temp[i])));
+                    *((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&reg_V[i]));
+                }
+            }
+            else{
+                #pragma unroll
+                for(int i = 0; i < 4; i += 2){
+                    (half2 &)reg_K[i] = (half2 &)input_k[idx_qkv + i*WARP_SIZE];
+                    *((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&reg_K[i]));
+                    (half2 &)reg_V[i] = (half2 &)input_v[idx_qkv + i*WARP_SIZE];
+                    *((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&reg_V[i]));
+                }
+            }
+        }
+
+        // Q*K
+        #pragma unroll
+        for (int i = 0; i < 4; i += 2){
+            (half2 &)reg_K[i] = (half2 &)reg_Q[i] * (half2 &)reg_K[i];
+            #pragma unroll
+            for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) {
+                (half2 &)reg_K[i] += __shfl_xor_sync(0xffffffff, (half2 &)reg_K[i], offset);
+            }
+            (float2 &) temp[i] = __half22float2((half2 &)reg_K[i]);
+            reg_P += (temp[i] + temp[i+1]);
+            (float2 &) temp[i] = __half22float2((half2 &)reg_V[i]);
+        }
+
+        // div sqrt(d)
+        reg_P /= sqrt(128.0);
+
+        // softmax
+        reg_P = expf(reg_P);
+        reg_sum += reg_P;
+
+        #pragma unroll
+        for (int i = 0; i < 4; i ++)
+            reg_O[i] = fmaf(reg_P, temp[i], reg_O[i]);
+    }
+
+    #pragma unroll
+    for (int i = 0; i < 4; i ++)
+        reg_O[i] /= reg_sum;
+
+    #pragma unroll
+    for(int i = 0; i < 4; i += 2)
+        (half2 &)output_O_temp[(lane_id_x2 + i*WARP_SIZE) + (blockIdx.z * compMeta.head_dim) + (parallel_idx * compMeta.head_dim * stride)] = __float22half2_rn((float2 &)reg_O[i]);
+    if(lane_id_x2 == 0){
+        output_sum_temp[blockIdx.z + parallel_idx * stride] = reg_sum;
+    }
+}

-__global__ void _attention_kvcache_kernel_128_2(int* position_id,
-                                              float* output_matmul,
-                                              AttentionKVCacheMetadata compMeta,
-                                              float* output_O_temp,
-                                              float* output_sum_temp) {
-    int lane_id = threadIdx.x % WARP_SIZE;
-    int group_id = threadIdx.x / WARP_SIZE;
-    int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
-
-    float ptr_O[4] = {0};
-    float ptr_O_sum[4] = {0};
-    float ptr_sum = 0;
-    float ptr_sum_temp;
-    int size = (position_id[0] + SEQ_UNIT) / SEQ_UNIT;
-
-    #pragma unroll
-    for(int i = 0; i < size; i ++){
-        (float4 &)ptr_O[0]
-            = (float4 &)output_O_temp[(lane_id * 4) + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size];
-        ptr_sum_temp = output_sum_temp[i + parallel_idx * size];
-
-        #pragma unroll
-        for(int k = 0; k < 4; k ++)
-            ptr_O_sum[k] += ptr_O[k] * ptr_sum_temp;
-        ptr_sum += ptr_sum_temp;
-    }
-
-    #pragma unroll
-    for(int k = 0; k < 4; k ++)
-        ptr_O_sum[k] = ptr_O_sum[k] / ptr_sum;
-
-    (float4 &)output_matmul[(lane_id * 4) + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O_sum[0];
-}
+template <class T>
+__global__ void _attention_kvcache_kernel_128_2(int64_t* position_id,
+                                              T* output_matmul,
+                                              AttentionKVCacheMetadata compMeta,
+                                              half* output_O_temp,
+                                              float* output_sum_temp) {
+    int lane_id = threadIdx.x % WARP_SIZE;
+    int parallel_idx = blockIdx.x;
+    int offset = parallel_idx * compMeta.head_dim;
+
+    int size = (position_id[0] + SEQ_UNIT) / SEQ_UNIT;
+    bool is_fp16 = sizeof(T) == 2 ? true : false;
+
+    if(size == 1){
+        if(!is_fp16){
+            #pragma unroll
+            for(int i = threadIdx.x; i < compMeta.head_dim; i += blockDim.x)
+                output_matmul[i + offset]
+                    = __half2float(output_O_temp[i + offset]);
+        }
+        else{
+            #pragma unroll
+            for(int i = threadIdx.x; i < compMeta.head_dim; i += blockDim.x)
+                output_matmul[i + offset]
+                    = output_O_temp[i + offset];
+        }
+        return;
+    }
+
+    __shared__ float shm_sum_temp[MAX_PARTITION];
+    __shared__ float shm_sum[WARP_SIZE];
+    float temp_sum = 0;
+
+    #pragma unroll
+    for(int i = threadIdx.x; i < size; i += blockDim.x){
+        shm_sum_temp[i] = output_sum_temp[i + parallel_idx * size];
+        temp_sum += shm_sum_temp[i];
+    }
+
+    #pragma unroll
+    for(int offset = WARP_SIZE/2; offset > 0; offset /= 2)
+        temp_sum += __shfl_down_sync(0xffffffff, temp_sum, offset);
+    if(lane_id == 0)
+        shm_sum[threadIdx.x/WARP_SIZE] = temp_sum;
+    __syncthreads();
+    temp_sum = lane_id < (size + WARP_SIZE - 1) / WARP_SIZE ? shm_sum[lane_id] : 0;
+
+    #pragma unroll
+    for(int offset = WARP_SIZE/2; offset > 0; offset /= 2)
+        temp_sum += __shfl_xor_sync(0xffffffff, temp_sum, offset);
+    temp_sum = __fdividef(1.0f, temp_sum + 1e-6f);
+
+    #pragma unroll
+    for(int i = threadIdx.x; i < compMeta.head_dim; i += blockDim.x){
+        float acc = 0.0f;
+        for(int j = 0; j < size; j ++){
+            acc = fma(__half2float(output_O_temp[i + (j * compMeta.head_dim) + offset * size]) * shm_sum_temp[j], temp_sum, acc);
+        }
+
+        if(!is_fp16){
+            output_matmul[i + offset] = acc;
+        }
+        else{
+            output_matmul[i + offset] = __float2half(acc);
+        }
+    }
+}

 namespace infini {
-void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
-                              float *input_q, float *input_k,
-                              float *input_v, int *position_id, float *output_matmul,
-                              const AttentionKVCacheMetadata &compMeta,
-                              float *output_O_temp, float *output_sum_temp) {
-    IT_ASSERT(compMeta.dimSize[3] == 128);
-
-    int gridsize_y = (compMeta.dimSize[2] - 1 + SEQ_UNIT) / SEQ_UNIT;
-    dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), gridsize_y);
-    dim3 blockDim(BLOCKSIZE, 1);
-
-    _attention_kvcache_kernel_128_1
-        <<<gridDim, blockDim, 0, CUDAStream::getCurrentStream()>>>
-        (input_k_cache, input_v_cache, input_q, input_k, input_v, position_id,
-         compMeta, output_O_temp, output_sum_temp);
-
-    _attention_kvcache_kernel_128_2
-        <<<compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), WARP_SIZE,
-        0, CUDAStream::getCurrentStream()>>>
-        (position_id, output_matmul, compMeta, output_O_temp, output_sum_temp);
-}
+void attention_kvcache_kernel(int dType, void *input_k_cache, void *input_v_cache,
+                              void *input_q, void *input_k,
+                              void *input_v, int64_t *position_id, void *output_matmul,
+                              const AttentionKVCacheMetadata &compMeta,
+                              float *output_O_temp, float *output_sum_temp) {
+    IT_ASSERT(dType == 1 || dType == 10);
+
+    int gridsize_y = (compMeta.max_kv_seqlen - 1 + SEQ_UNIT) / SEQ_UNIT;
+    dim3 gridDim(compMeta.num_heads, compMeta.num_seqs, gridsize_y);
+    dim3 blockDim(WARP_SIZE, 1);
+
+    if(dType == 1){
+        _attention_kvcache_kernel_128_1<float>
+            <<<gridDim, blockDim, 0, CUDAStream::getCurrentStream()>>>
+            ((float*)input_k_cache, (float*)input_v_cache, (float*)input_q, (float*)input_k, (float*)input_v,
+             position_id, compMeta, (half*)output_O_temp, output_sum_temp);
+
+        _attention_kvcache_kernel_128_2<float>
+            <<<compMeta.num_seqs*compMeta.num_heads, BLOCKSIZE_2,
+            0, CUDAStream::getCurrentStream()>>>
+            (position_id, (float*)output_matmul, compMeta, (half*)output_O_temp, output_sum_temp);
+    }
+    else{
+        _attention_kvcache_kernel_128_1<half>
+            <<<gridDim, blockDim, 0, CUDAStream::getCurrentStream()>>>
+            ((half*)input_k_cache, (half*)input_v_cache, (half*)input_q, (half*)input_k, (half*)input_v,
+             position_id, compMeta, (half*)output_O_temp, output_sum_temp);
+
+        _attention_kvcache_kernel_128_2<half>
+            <<<compMeta.num_seqs*compMeta.num_heads, BLOCKSIZE_2,
+            0, CUDAStream::getCurrentStream()>>>
+            (position_id, (half*)output_matmul, compMeta, (half*)output_O_temp, output_sum_temp);
+    }
+}

 } // namespace infini
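Taken together, the two kernels implement a partitioned softmax for a single query token: _attention_kvcache_kernel_128_1 emits, per SEQ_UNIT partition, a partially normalized output and that partition's sum of exponentials; _attention_kvcache_kernel_128_2 re-weights each partition by its sum and divides by the global total. A small numpy sketch of that merge rule (illustrative only; it keeps the arithmetic and omits the half-precision packing and vectorized loads):

```python
import numpy as np

def attention_partitioned(q, k, v, seq_unit=16):
    """Single-query attention computed per SEQ_UNIT partition, then merged."""
    d = q.shape[-1]
    weights = np.exp(k @ q / np.sqrt(d))     # no running-max trick, matching the kernel

    partial_out, partial_sum = [], []
    for start in range(0, len(weights), seq_unit):          # pass 1: one partition each
        w = weights[start:start + seq_unit]
        partial_sum.append(w.sum())
        partial_out.append((w[:, None] * v[start:start + seq_unit]).sum(0) / w.sum())

    partial_out = np.stack(partial_out)      # corresponds to output_O_temp
    partial_sum = np.asarray(partial_sum)    # corresponds to output_sum_temp

    # pass 2: re-weight each partition by its exp-sum and normalize globally
    return (partial_out * partial_sum[:, None]).sum(0) / partial_sum.sum()

rng = np.random.default_rng(0)
q = rng.random(128, dtype=np.float32)
k = rng.random((40, 128), dtype=np.float32)
v = rng.random((40, 128), dtype=np.float32)

w = np.exp(k @ q / np.sqrt(128))
reference = (w[:, None] * v).sum(0) / w.sum()
assert np.allclose(attention_partitioned(q, k, v), reference, atol=1e-4)
```

Because each partition's output is re-scaled by its own exp-sum before the global division, the merged result equals the softmax-weighted sum taken over the whole sequence at once.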
@@ -36,7 +36,7 @@ constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = {

 cublasComputeType_t cuDataType2ComputeType(cudaDataType_t cuDataType) {
     if (cuDataType == CUDA_R_16F) {
-        return CUBLAS_COMPUTE_32F_FAST_16F;
+        return CUBLAS_COMPUTE_16F;
     } else if (cuDataType == CUDA_R_16BF) {
         return CUBLAS_COMPUTE_32F_FAST_16BF;
     } else if (cuDataType == CUDA_R_32F) {
@@ -18,17 +18,18 @@ class RoPECuda : public CudaKernelWithoutConfig {
         const auto &inputShape = input->getDims();
         int nDims = input->getDims().size();

-        int size = input->size();
         IT_ASSERT(nDims == 3 && pos->getDims().size() == 2);
-        IT_ASSERT(inputShape[1] == pos->getDims()[1]);
+        IT_ASSERT(inputShape[0] == pos->getDims()[0] &&
+                  inputShape[1] == pos->getDims()[1]);
+        int position_idx_dtype = op->getInputs()[0]->getDTypeIndex();
         int dim_model = inputShape[2];
-        int dim_head = 128;
-        int hidden_stride = dim_model * inputShape[1];
+        int dim_head = 128; // TODO: get dim_head from the framework
         int pos_stride = inputShape[1];
+        int batchsize = inputShape[0];

         const int dType = op->getDType().getIndex();
-        rope_kernel(dType, pos->getRawDataPtr<int *>(), inputData, outputData,
-                    size, dim_model, dim_head, hidden_stride, pos_stride);
+        rope_kernel(dType, pos->getRawDataPtr<int64_t *>(), inputData,
+                    outputData, dim_model, dim_head, batchsize, pos_stride);
     }
 };
@@ -4,13 +4,15 @@
 #include "utils/small_array.h"

 template <class T>
-__global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model,
-                             int dim_head, int hidden_stride, int pos_stride) {
+__global__ void _rope_kernel(int64_t* pos, void *in, void *out, int dim_model,
+                             int dim_head, int batchsize, int pos_stride) {
     int batch_id = blockIdx.x;
     int target_pos = pos[batch_id * pos_stride + blockIdx.y];

     int ith = blockIdx.z * blockDim.x + threadIdx.x;
     int col = ith % dim_head;
-    int offset = batch_id * hidden_stride + blockIdx.y * dim_model;
+    int batch_stride = pos_stride * dim_model;
+    int offset = batch_id * batch_stride + blockIdx.y * dim_model;

     if (ith >= dim_model)
         return;

@@ -34,7 +36,7 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
 #define CASE(T) \
     _rope_kernel<DT_CUDA<T>::t> \
         <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
-        (pos, input, output, size, dim_model, dim_head, hidden_stride, pos_stride);
+        (pos, input, output, dim_model, dim_head, batchsize, pos_stride);

 #define SWITCH_DTYPE(DTYPE) \
     switch (DTYPE) { \

@@ -79,10 +81,10 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
 }

 namespace infini {
-void rope_kernel(int dType, int * pos, void *input, void *output, int size,
-                 int dim_model, int dim_head, int hidden_stride, int pos_stride) {
+void rope_kernel(int dType, int64_t * pos, void *input, void *output,
+                 int dim_model, int dim_head, int batchsize, int pos_stride) {
     dim3 blocksize = dim3(32,1,1);
-    dim3 gridsize = dim3(1, 1, dim_model/32);
+    dim3 gridsize = dim3(batchsize, pos_stride, dim_model/32);
     SWITCH_DTYPE(dType)
 }
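The RoPE kernel rotates each pair of head dimensions by an angle that depends on the token's position and the pair index; the change above only reworks how batch and position are mapped onto the launch grid (one block per batch entry and position). A rough numpy sketch of the underlying rotation (assumed to be the standard rotary embedding; the frequency base of 10000 is an assumption, since the constant itself is not visible in this diff):

```python
import numpy as np

def rope_rotate(x, pos, base=10000.0):
    """Apply a rotary embedding to one token vector x (even length) at position pos."""
    d = x.shape[-1]
    half = d // 2
    freqs = base ** (-np.arange(half) * 2.0 / d)     # one frequency per rotated pair
    angle = pos * freqs
    x1, x2 = x[:half], x[half:]
    return np.concatenate([x1 * np.cos(angle) - x2 * np.sin(angle),
                           x1 * np.sin(angle) + x2 * np.cos(angle)])

token = np.random.rand(128).astype(np.float32)   # dim_head = 128, as hard-coded upstream
rotated = rope_rotate(token, pos=5)
assert rotated.shape == token.shape
```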
@@ -97,14 +97,11 @@ class DivXdnn : public KUNLUNKernelWithoutConfig {
         auto aDim = op->getInputs(0)->getDims();
         auto bSize = op->getInputs(1)->size();
         auto bDim = op->getInputs(1)->getDims();
+        auto dtype = op->getDType();

-        // op input a, b is scalar while aDim and b Dim is empty
         if (bDim.size() == 0) {
             bDim.push_back(1);
         }
-        if (aDim.size() == 0) {
-            aDim.push_back(1);
-        }

         if (aSize == bSize) {
             // Do ElementWise Sub with no broadcast
@@ -112,9 +109,23 @@ class DivXdnn : public KUNLUNKernelWithoutConfig {
                 (float *)aData, (float *)bData,
                 (float *)cData, aSize));
         } else {
-            checkKUNLUNError(xdnn::broadcast_div<float>(
-                context->KUNLUNHandle(), (float *)aData, (float *)bData,
-                (float *)cData, aDim, bDim));
+            // Do broadcast div
+            Shape aligned = infer_broadcast(aDim, bDim);
+            if (aligned == aDim) {
+                // BData need to be broadcasted
+                checkKUNLUNError(xdnn::broadcast_div<float>(
+                    context->KUNLUNHandle(), (float *)aData, (float *)bData,
+                    (float *)cData, aDim, bDim));
+            } else {
+                // Use workspace to broadcast aData
+                KUNLUNPtr wks = context->getWorkspace(bSize * dtype.getSize());
+                checkKUNLUNError(xdnn::broadcast<float>(
+                    context->KUNLUNHandle(), (float *)aData, (float *)wks, aDim,
+                    bDim));
+                checkKUNLUNError(xdnn::div<float>(context->KUNLUNHandle(),
+                                                  (float *)wks, (float *)bData,
+                                                  (float *)cData, bSize));
+            }
         }
         return;
     }
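The new KUNLUN Div path distinguishes two broadcast directions: when B must be broadcast up to A's shape, the library's broadcast_div is used directly; when A must be broadcast up to B's shape, A is first materialized into a workspace of B's size and a plain element-wise div follows. A small numpy sketch of why the workspace route gives the same result (illustrative only):

```python
import numpy as np

a = np.random.rand(1, 4).astype(np.float32)        # A must be broadcast up to B's shape
b = np.random.rand(3, 4).astype(np.float32) + 1.0  # keep B away from zero

reference = a / b                                   # full numpy broadcasting

a_workspace = np.broadcast_to(a, b.shape).copy()    # "workspace" copy of A at B's shape
workspace_path = a_workspace / b                    # element-wise div, no broadcasting left

assert np.allclose(reference, workspace_path)
```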
@@ -570,7 +570,6 @@ REGISTER_KERNEL(Device::KUNLUN, OpType::Reciprocal, ReciprocalXdnn,
 REGISTER_KERNEL(Device::KUNLUN, OpType::Reshape, CopyXdnn, "Reshape_xdnn");
 REGISTER_KERNEL(Device::KUNLUN, OpType::Flatten, CopyXdnn, "Flatten_xdnn");
 REGISTER_KERNEL(Device::KUNLUN, OpType::Identity, CopyXdnn, "Identity_xdnn");
-REGISTER_KERNEL(Device::KUNLUN, OpType::Squeeze, CopyXdnn, "Squeeze_xdnn");
 REGISTER_KERNEL(Device::KUNLUN, OpType::Abs, AbsXdnn, "Abs_xdnn");
 REGISTER_KERNEL(Device::KUNLUN, OpType::Atan, ATanXdnn, "Atan_xdnn");
 REGISTER_KERNEL(Device::KUNLUN, OpType::Log, LogXdnn, "Log_xdnn");
@@ -26,11 +26,12 @@ TEST(RoPE, Cuda) {
     cudaRuntime->run(gCuda);

     auto oCpu = gCpu->cloneTensor(op->getOutputs()[0]);
+    oCpu->printData();
     EXPECT_TRUE(oCpu->equalData(vector<float>{
-        1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
-        1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
-        1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
-        1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
-        1.381773, 1.381773, 1.381773, 1.381773}));
+        0.540302, 0.647906, 0.731761, 0.796458, 0.846009, 0.883756, 0.912396,
+        0.934062, 0.950415, 0.962739, 0.972014, 0.978989, 0.98423, 0.988167,
+        0.991122, 0.99334, 0.995004, 0.996253, 0.99719, 0.997892, 0.998419,
+        0.998815, 0.999111, 0.999333, 0.9995, 0.999625, 0.999719, 0.999789,
+        0.999842, 0.999881, 0.999911, 0.999933}));
 }
 } // namespace infini