forked from jiuyuan/InfiniTensor

Compare commits: dist_bench...master (8 commits)

| SHA1 |
|---|
| 5559536470 |
| fac28c25f6 |
| 985d0dee5f |
| d1de3ab5c2 |
| eafbff6cf9 |
| 7f6aec6c17 |
| a98573990b |
| 54a35772fb |
@@ -13,3 +13,6 @@
 [submodule "example"]
 	path = examples/NNmodel
 	url = git@github.com:wanghailu0717/NNmodel.git
+[submodule "examples/distributed/onnxsim_large_model"]
+	path = examples/distributed/onnxsim_large_model
+	url = git@github.com:luchangli03/onnxsim_large_model.git
@@ -285,8 +285,8 @@ if(USE_KUNLUN)
     message(STATUS "KUNLUN_HOME: ${KUNLUN_HOME}")

     include_directories("${KUNLUN_HOME}/include/")
-    find_library(KUNLUN_RT libxpurt.so "${KUNLUN_HOME}/so/")
-    find_library(KUNLUN_DNN libxpuapi.so "${KUNLUN_HOME}/so/")
+    find_library(KUNLUN_RT libxpurt.so "${KUNLUN_HOME}/lib64/")
+    find_library(KUNLUN_DNN libxpuapi.so "${KUNLUN_HOME}/lib64/")
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")

     if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))
@@ -9,7 +9,7 @@ find_path(XCCL_INCLUDE_DIRS # ${XCCL_INCLUDE_DIR}
     HINTS XCCL_INCLUDE_DIR)

 find_library(XCCL_LIBRARIES # ${XCCL_LIB_DIR}
-    NAMES so/libbkcl.so
+    NAMES lib64/libbkcl.so
     HINTS XCCL_LIB_DIR)

 message(STATUS "XCCL_INCLUDE_DIRS: ${XCCL_INCLUDE_DIRS}")
@@ -0,0 +1,39 @@
# Distributed scripts

## Running on the NVIDIA platform

#### 1. Run the PyTorch model to generate inputs and reference outputs, optionally exporting ONNX

Use `--export_onnx` to set the directory the ONNX model is exported to (default is the current directory `./`). Without this flag, the script only runs the computation and generates the inputs and outputs.

```bash
python run_pytorch.py --model gpt2 --batch_size 1 --length 1 --export_onnx ./
```

This writes the input and output files `test_inputs.npy` and `test_results.npy` to the current directory. Only a single input and a single output are currently supported.

#### 2. Run the InfiniTensor distributed script

```bash
python cuda_launch.py --model "/XXX/XXX.onnx" --nproc_per_node 4
```

## Running on the Cambricon platform

**The scripts `run_pytorch.py` and `cuda_launch.py` above have been adapted for the Cambricon platform; see `run_pytorch_mlu.py` and `bang_launch.py` for details.**

#### 1. Run the PyTorch model to generate inputs and reference outputs, optionally exporting ONNX

Use `--export_onnx` to set the directory the ONNX model is exported to (default is the current directory `./`). Without this flag, the script only runs the computation and generates the inputs and outputs.

```bash
python run_pytorch_mlu.py --model gpt2 --batch_size 1 --length 1 --export_onnx ./
```

This writes the input and output files `test_inputs.npy` and `test_results.npy` to the current directory. Only a single input and a single output are currently supported.

#### 2. Run the InfiniTensor distributed script

```bash
python bang_launch.py --model "/XXX/XXX.onnx" --nproc_per_node 4
```
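For a quick standalone check of an InfiniTensor run against the PyTorch reference, a comparison along the lines of the `getDiff`/`np_assert` helpers used by the launch scripts below can be run on the saved arrays. This is only a sketch; `infini_output.npy` is a hypothetical dump of the InfiniTensor output (the scripts in this diff print the comparison rather than saving it):

```python
# Sketch: compare a saved InfiniTensor output against the PyTorch reference
# written by run_pytorch.py. The file name infini_output.npy is hypothetical.
import numpy as np

reference = np.load("test_results.npy")    # reference output from run_pytorch.py
candidate = np.load("infini_output.npy")   # hypothetical InfiniTensor output dump

abs_diff = np.abs(candidate.astype(np.float64) - reference.astype(np.float64))
rel_diff = abs_diff.sum() / (np.abs(reference.astype(np.float64)).sum() + 1e-9)
print("max abs diff:", abs_diff.max())
print("relative diff:", rel_diff)
```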
@ -1,35 +1,39 @@
|
|||
import sys
|
||||
sys.path.append('../')
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import multiprocessing as mp
|
||||
from pyinfinitensor.onnx import OnnxStub, backend
|
||||
import onnx
|
||||
from onnx.external_data_helper import convert_model_to_external_data
|
||||
from onnx.shape_inference import infer_shapes_path
|
||||
import numpy as np
|
||||
from parallel_opt import parallel_model
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="launch distributed infinitensor")
|
||||
parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
|
||||
parser.add_argument(
|
||||
"--nproc_per_node", type=int, default=2, help="number of processes per node"
|
||||
"--nproc_per_node", type=int, default=1, help="number of processes per node"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--name", type=str, default="test", help="name of this instance."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model", type=str, default="/data/onnx_models/llama2/llama_bs1_seq1024.onnx",
|
||||
help="path to the ONNX model file."
|
||||
"--model", type=str, required=True, help="path to the ONNX model file."
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
|
||||
parser.add_argument("--length", type=int, default=1, help="sequence length.")
|
||||
parser.add_argument(
|
||||
"--gen_std",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="whether to generate the standard results.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--type", type=str, choices=["fp32", "fp16", "tf32"], default="fp32", help="data type"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print("arg setting: ", args)
|
||||
return (
|
||||
|
@ -40,39 +44,46 @@ def parse_args():
|
|||
args.batch_size,
|
||||
args.length,
|
||||
args.gen_std,
|
||||
args.type,
|
||||
)
|
||||
|
||||
|
||||
def run_model(model, runtime, world_size=1, rank=0, n=10):
|
||||
stub = OnnxStub(model, runtime)
|
||||
def run_model(model, runtime, world_size=1, rank=0, n=10, data_type="default"):
|
||||
stub = OnnxStub(model, runtime, matmul_compute_type=data_type)
|
||||
load_inputs(stub, world_size, rank)
|
||||
# stub.tune()
|
||||
stub.run()
|
||||
# get outputs
|
||||
time.sleep(0.01)
|
||||
outputs = next(stub.outputs.values().__iter__()).copyout_numpy()
|
||||
|
||||
# bench
|
||||
begin = time.time()
|
||||
for _ in range(n):
|
||||
stub.run()
|
||||
begin = time.time()
|
||||
for _ in range(n * 2):
|
||||
stub.run()
|
||||
end = time.time()
|
||||
avg_time = (end - begin) / n
|
||||
avg_time = (end - begin) / (n * 2)
|
||||
print(f"average time: {avg_time}")
|
||||
return outputs
|
||||
|
||||
def load_inputs(stub, world_size=1, rank=0):
|
||||
for i, (name, tensor) in enumerate(stub.inputs.items()):
|
||||
input = np.load(f"./data/input_{i}.npy")
|
||||
if all(x == y for x,y in zip(input.shape,tensor.shape())):
|
||||
tensor.copyin_numpy(input)
|
||||
else:
|
||||
tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
|
||||
|
||||
def run_and_compare(name, model, runtime, world_size=1, rank = 0):
|
||||
|
||||
def run_and_compare(name, model, runtime, world_size=1, rank=0, data_type="default"):
|
||||
results = np.load(f"./data/output.npy")
|
||||
outputs = run_model(model, runtime, world_size, rank)
|
||||
print("answer argmax:", np.argmax(results))
|
||||
print("output argmax:", np.argmax(outputs))
|
||||
#np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-3)
|
||||
getDiff(results, outputs)
|
||||
|
||||
outputs = run_model(model, runtime, world_size, rank, data_type=data_type)
|
||||
print("outputs abs mean:", abs(outputs).mean())
|
||||
print("max abs diff:", abs(outputs - results).max())
|
||||
|
||||
def start_worker(
|
||||
name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
|
||||
name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto, data_type: str
|
||||
):
|
||||
dist_name = name + "_dist"
|
||||
model = parallel_model(model, world_size, rank)
|
||||
|
@ -85,7 +96,7 @@ def start_worker(
|
|||
save_as_external_data=True,
|
||||
location=extern_path,
|
||||
)
|
||||
infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
|
||||
#infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
|
||||
runtime = backend.BangRuntime(local_rank)
|
||||
# print("init comm")
|
||||
runtime.init_comm(
|
||||
|
@ -93,13 +104,12 @@ def start_worker(
|
|||
world_size,
|
||||
rank,
|
||||
)
|
||||
run_and_compare(name, model, runtime, world_size, rank)
|
||||
run_and_compare(name, model, runtime, world_size, rank, data_type)
|
||||
|
||||
|
||||
def start_single(name, model):
|
||||
def start_single(name, model, data_type):
|
||||
runtime = backend.BangRuntime(0)
|
||||
run_and_compare(name, model, runtime)
|
||||
|
||||
run_and_compare(name, model, runtime, data_type=data_type)
|
||||
|
||||
def generate_input_output(model):
|
||||
os.makedirs(os.path.dirname("./data/"), exist_ok=True)
|
||||
|
@ -132,55 +142,36 @@ def generate_input_output(model):
|
|||
np.save(f"./data/output", output)
|
||||
|
||||
|
||||
def load_inputs(stub, world_size=1, rank=0):
|
||||
for i, (name, tensor) in enumerate(stub.inputs.items()):
|
||||
input = np.load(f"./data/input_{i}.npy")
|
||||
if all(x == y for x,y in zip(input.shape,tensor.shape())):
|
||||
tensor.copyin_numpy(input)
|
||||
else:
|
||||
tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
|
||||
|
||||
def getDiff(base, test):
|
||||
absolute_diff = np.abs(np.subtract(base, test))
|
||||
max_absolute_diff = np.max(absolute_diff)
|
||||
|
||||
baseCopy = base.astype(np.float64).ravel()
|
||||
testCopy = test.astype(np.float64).ravel()
|
||||
upValue = np.sum(np.abs(baseCopy - testCopy))
|
||||
downValue = np.sum(np.abs(baseCopy)) + np.float64(1e-9)
|
||||
max_relative_diff = upValue / downValue
|
||||
print(f"Max absolute difference: {max_absolute_diff}\n"
|
||||
f"Max relative difference: {max_relative_diff}")
|
||||
return max_absolute_diff, max_relative_diff
|
||||
|
||||
|
||||
def main():
|
||||
nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args()
|
||||
|
||||
nnodes, nproc_per_node, name, model_path, bs, length, gen_std, data_type = parse_args()
|
||||
data_type = "default" if data_type == "fp32" else data_type
|
||||
|
||||
model = onnx.load(model_path)
|
||||
|
||||
# generate standard output
|
||||
if gen_std:
|
||||
print("Generate inputs and outputs.")
|
||||
p = mp.Process(target=generate_input_output, args=[model])
|
||||
p.start()
|
||||
p.join()
|
||||
print(f"generate standard data for {name}.")
|
||||
# a small vocabulary size to fit all LLM.
|
||||
generate_input_output(model)
|
||||
return
|
||||
|
||||
# run single process.
|
||||
# use standalone process to isolate cuda.
|
||||
print("run model by single MLU.")
|
||||
p = mp.Process(target=start_single, args=(name, model))
|
||||
p.start()
|
||||
p.join()
|
||||
if nproc_per_node == 1:
|
||||
# run single process.
|
||||
# use standalone process to isolate bang.
|
||||
print("run model by single MLU.")
|
||||
# p = mp.Process(target=start_single, args=(name, model, data_type))
|
||||
# p.start()
|
||||
# p.join()
|
||||
start_single(name, model, data_type)
|
||||
return
|
||||
|
||||
# run distributed parallel.
|
||||
world_size = nnodes * nproc_per_node
|
||||
print(f"run model by {world_size} MLUs in parallel.")
|
||||
print(f"run model by {world_size} MLU in parallel.")
|
||||
workers = [
|
||||
mp.Process(
|
||||
target=start_worker,
|
||||
args=(name, world_size, rank, rank % nproc_per_node, model),
|
||||
args=(name, world_size, rank, rank % nproc_per_node, model, data_type),
|
||||
)
|
||||
for rank in range(world_size)
|
||||
]
|
|
@ -0,0 +1,249 @@
|
|||
import argparse
|
||||
import torch
|
||||
import torch_mlu
|
||||
from transformers import BertModel, BertConfig
|
||||
from transformers import GPT2Model, GPT2Config
|
||||
from transformers import OPTModel, OPTConfig
|
||||
from transformers import AlbertModel, AlbertConfig
|
||||
from transformers import LlamaModel, LlamaConfig
|
||||
import time
|
||||
import numpy as np
|
||||
import onnx
|
||||
import sys
|
||||
import os
|
||||
from onnx.external_data_helper import convert_model_to_external_data
|
||||
from onnxsim import simplify
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Run pytorch gpt2/bert/opt and optionally export onnx.")
|
||||
parser.add_argument(
|
||||
"--model", type=str, choices=["gpt2", "bert", "opt", "llama", "albert"], required=True, help="model type"
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
|
||||
parser.add_argument("--length", type=int, default=1, help="sequence length.")
|
||||
parser.add_argument(
|
||||
"--export_onnx",
|
||||
type=str,
|
||||
nargs="?",
|
||||
default=None,
|
||||
const="./",
|
||||
help="whether and where to export onnx file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--type", type=str, choices=["fp32", "fp16", "tf32"], required=True, help="model data type"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print("arg setting: ", args)
|
||||
return (
|
||||
args.model,
|
||||
args.batch_size,
|
||||
args.length,
|
||||
args.export_onnx,
|
||||
args.type
|
||||
)
|
||||
|
||||
|
||||
def get_model(modelname):
|
||||
match modelname:
|
||||
case "albert":
|
||||
model = AlbertModel.from_pretrained("albert/albert-base-v2")
|
||||
voc_size = AlbertConfig().vocab_size
|
||||
case "bert":
|
||||
model = BertModel.from_pretrained("bert-base-uncased", add_pooling_layer=False, hidden_act="gelu_new") # erf is not impl by infini
|
||||
voc_size = BertConfig().vocab_size
|
||||
case "gpt2":
|
||||
model = GPT2Model.from_pretrained("gpt2")
|
||||
voc_size = GPT2Config().vocab_size
|
||||
case "opt":
|
||||
model = OPTModel.from_pretrained("facebook/opt-125m")
|
||||
voc_size = OPTConfig().vocab_size
|
||||
case "llama":
|
||||
model = LlamaModel.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
voc_size = LlamaConfig().vocab_size
|
||||
case _:
|
||||
raise KeyError(modelname)
|
||||
|
||||
model = model.eval()
|
||||
return model, voc_size
|
||||
|
||||
def run_pytorch(torch_model, voc_size, batchsize, len, dtype="fp32"):
|
||||
data = np.random.randint(0, voc_size, (batchsize, len), dtype=np.int32)
|
||||
os.makedirs(os.path.dirname("./data/"), exist_ok=True)
|
||||
np.save("./data/input_0", data)
|
||||
inputs = torch.from_numpy(data).to("mlu")
|
||||
torch_model = torch_model.to("mlu")
|
||||
if dtype == "fp16":
|
||||
torch_model = torch_model.half()
|
||||
|
||||
n_iter = 20
|
||||
with torch.no_grad():
|
||||
for _ in range(10):
|
||||
outputs = torch_model(inputs)
|
||||
torch.mlu.synchronize()
|
||||
begin = time.time()
|
||||
with torch.no_grad():
|
||||
for _ in range(n_iter):
|
||||
torch.mlu.synchronize()
|
||||
outputs = torch_model(inputs)
|
||||
torch.mlu.synchronize()
|
||||
torch.mlu.synchronize()
|
||||
end = time.time()
|
||||
|
||||
avg_time = (end - begin) / n_iter
|
||||
outputs = outputs.last_hidden_state.to("cpu")
|
||||
print("outputs abs mean:", abs(np.array(outputs)).mean())
|
||||
print(f"average time: {avg_time}")
|
||||
# torch.mlu.memory.empty_cache()
|
||||
np.save("./data/output", np.array(outputs))
|
||||
print("Save input & output into ./data.")
|
||||
|
||||
|
||||
def export_onnx(modelname, model, data, path, extern=False, dtype="fp32"):
|
||||
data = data.to("mlu")
|
||||
model = model.to("mlu")
|
||||
if dtype == "fp16":
|
||||
model = model.half()
|
||||
torch.onnx.export(model, data, path, verbose=False, do_constant_folding=True)
|
||||
if modelname != "llama":
|
||||
# use onnxsim to simplify
|
||||
onnx_model = onnx.load(path)
|
||||
onnx_model, check = simplify(onnx_model, skipped_optimizers=['eliminate_duplicate_initializer'])
|
||||
# onnx_model, check = simplify(onnx_model, skipped_optimizers=['fuse_qkv', 'eliminate_duplicate_initializer'])
|
||||
assert check
|
||||
add_value_info_for_constants(onnx_model)
|
||||
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
|
||||
if extern:
|
||||
extern_path = path.replace('.onnx', '.pb')
|
||||
if os.path.exists(extern_path):
|
||||
os.remove(extern_path)
|
||||
extern_path = extern_path.split("/")[-1]
|
||||
convert_model_to_external_data(
|
||||
onnx_model,
|
||||
all_tensors_to_one_file=True,
|
||||
location=extern_path,
|
||||
size_threshold=1024,
|
||||
convert_attribute=False,
|
||||
)
|
||||
onnx.save(onnx_model, path)
|
||||
else:
|
||||
# use third party tool to simplify llama
|
||||
# reference: https://github.com/luchangli03/onnxsim_large_model/
|
||||
sys.path.append("onnxsim_large_model")
|
||||
from onnx_utils import set_onnx_input_shape
|
||||
from compress_model import SIZE_1MB, compress_onnx_model, uncompress_onnx_model
|
||||
|
||||
in_model_path = path
|
||||
out_model_path = path
|
||||
if not out_model_path:
|
||||
out_model_path = in_model_path[:-5] + ".sim.onnx"
|
||||
if os.path.isdir(out_model_path):
|
||||
out_model_path = os.path.join(out_model_path, os.path.basename(in_model_path))
|
||||
|
||||
onnx_model = onnx.load(in_model_path)
|
||||
print(f"load model from {in_model_path} success")
|
||||
|
||||
size_th_bytes = 1024 * 1024
|
||||
|
||||
onnx_model, removed_inits = compress_onnx_model(onnx_model, size_th_bytes=size_th_bytes)
|
||||
print(f"compress model success")
|
||||
|
||||
onnx_model = set_onnx_input_shape(onnx_model, "")
|
||||
|
||||
tensor_size_threshold = f"1024KB"
|
||||
skipped_optimizers = []
|
||||
skipped_optimizers.append("eliminate_duplicate_initializer")
|
||||
onnx_model, check = simplify(onnx_model, skipped_optimizers=skipped_optimizers,
|
||||
tensor_size_threshold=tensor_size_threshold)
|
||||
if not check:
|
||||
raise ValueError(f"simplify compressed model {in_model_path} failed")
|
||||
|
||||
print(f"simplify model success")
|
||||
|
||||
onnx_model = uncompress_onnx_model(onnx_model, removed_inits)
|
||||
print(f"uncompress model success")
|
||||
|
||||
add_value_info_for_constants(onnx_model)
|
||||
|
||||
onnx.save(onnx_model, out_model_path, save_as_external_data=True)
|
||||
|
||||
|
||||
def add_value_info_for_constants(model : onnx.ModelProto):
|
||||
"""
|
||||
Currently onnx.shape_inference doesn't use the shape of initializers, so add
|
||||
that info explicitly as ValueInfoProtos.
|
||||
Mutates the model.
|
||||
Args:
|
||||
model: The ModelProto to update.
|
||||
"""
|
||||
# All (top-level) constants will have ValueInfos before IRv4 as they are all inputs
|
||||
if model.ir_version < 4:
|
||||
return
|
||||
|
||||
def add_const_value_infos_to_graph(graph : onnx.GraphProto):
|
||||
inputs = {i.name for i in graph.input}
|
||||
existing_info = {vi.name: vi for vi in graph.value_info}
|
||||
for init in graph.initializer:
|
||||
# Check it really is a constant, not an input
|
||||
if init.name in inputs:
|
||||
continue
|
||||
|
||||
# The details we want to add
|
||||
elem_type = init.data_type
|
||||
shape = init.dims
|
||||
|
||||
# Get existing or create new value info for this constant
|
||||
vi = existing_info.get(init.name)
|
||||
if vi is None:
|
||||
vi = graph.value_info.add()
|
||||
vi.name = init.name
|
||||
|
||||
# Even though it would be weird, we will not overwrite info even if it doesn't match
|
||||
tt = vi.type.tensor_type
|
||||
if tt.elem_type == onnx.TensorProto.UNDEFINED:
|
||||
tt.elem_type = elem_type
|
||||
if not tt.HasField("shape"):
|
||||
# Ensure we set an empty list if the const is scalar (zero dims)
|
||||
tt.shape.dim.extend([])
|
||||
for dim in shape:
|
||||
tt.shape.dim.add().dim_value = dim
|
||||
|
||||
# Handle subgraphs
|
||||
for node in graph.node:
|
||||
for attr in node.attribute:
|
||||
# Ref attrs refer to other attrs, so we don't need to do anything
|
||||
if attr.ref_attr_name != "":
|
||||
continue
|
||||
|
||||
if attr.type == onnx.AttributeProto.GRAPH:
|
||||
add_const_value_infos_to_graph(attr.g)
|
||||
if attr.type == onnx.AttributeProto.GRAPHS:
|
||||
for g in attr.graphs:
|
||||
add_const_value_infos_to_graph(g)
|
||||
|
||||
|
||||
return add_const_value_infos_to_graph(model.graph)
|
||||
|
||||
|
||||
def main():
|
||||
torch.backends.mlu.matmul.allow_tf32 = False
|
||||
torch.backends.cnnl.allow_tf32 = False
|
||||
modelname, batchsize, seqlen, export_path, dtype = parse_args()
|
||||
if dtype == "tf32":
|
||||
torch.backends.mlu.matmul.allow_tf32 = True
|
||||
else:
|
||||
os.environ["CAMBRICON_TF32_OVERRIDE"] = "0"
|
||||
|
||||
model, voc_size = get_model(modelname)
|
||||
if export_path is not None:
|
||||
filename = "{}_{}_{}_{}.onnx".format(modelname, batchsize, seqlen, dtype)
|
||||
path = os.path.join(export_path, filename)
|
||||
if not os.path.exists(path):
|
||||
param = torch.zeros((batchsize, seqlen), dtype=torch.int)
|
||||
export_onnx(modelname, model, param, path, True, dtype)
|
||||
else:
|
||||
print("Onnx path exists, skipping export.")
|
||||
|
||||
run_pytorch(model, voc_size, batchsize, seqlen, dtype)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -10,9 +10,6 @@ import numpy as np
|
|||
from parallel_opt import parallel_model
|
||||
|
||||
|
||||
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="launch distributed infinitensor")
|
||||
parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
|
||||
|
@ -32,6 +29,9 @@ def parse_args():
|
|||
action="store_true",
|
||||
help="whether to generate the standard results.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--type", type=str, choices=["fp32", "fp16", "tf32"], default="fp32", help="data type"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print("arg setting: ", args)
|
||||
return (
|
||||
|
@ -42,12 +42,13 @@ def parse_args():
|
|||
args.batch_size,
|
||||
args.length,
|
||||
args.gen_std,
|
||||
args.type,
|
||||
)
|
||||
|
||||
|
||||
def run_model(model, runtime, inputs, n=10):
|
||||
stub = OnnxStub(model, runtime)
|
||||
for tensor, input in zip(stub.inputs.values(), inputs):
|
||||
def run_model(model, runtime, inputs, n=10, data_type = "default"):
|
||||
stub = OnnxStub(model, runtime, matmul_compute_type=data_type)
|
||||
for tensor, input in zip(stub.inputs.values(), inputs, strict=False):
|
||||
tensor.copyin_numpy(input)
|
||||
# stub.tune()
|
||||
stub.run()
|
||||
|
@ -55,7 +56,7 @@ def run_model(model, runtime, inputs, n=10):
|
|||
outputs = next(stub.outputs.values().__iter__()).copyout_numpy()
|
||||
|
||||
# bench
|
||||
for tensor, input in zip(stub.inputs.values(), inputs):
|
||||
for tensor, input in zip(stub.inputs.values(), inputs, strict=False):
|
||||
tensor.copyin_numpy(input)
|
||||
begin = time.time()
|
||||
for _ in range(n):
|
||||
|
@ -66,17 +67,17 @@ def run_model(model, runtime, inputs, n=10):
|
|||
return outputs
|
||||
|
||||
|
||||
def run_and_compare(name, model, runtime):
|
||||
def run_and_compare(name, model, runtime, data_type):
|
||||
input_ids = np.load(f"{name}_inputs.npy")
|
||||
position_ids = np.arange(input_ids.shape[-1])
|
||||
results = np.load(f"{name}_results.npy")
|
||||
outputs = run_model(model, runtime, (input_ids, position_ids))
|
||||
outputs = run_model(model, runtime, (input_ids, position_ids), data_type=data_type)
|
||||
print("outputs abs mean:", abs(outputs).mean())
|
||||
np.testing.assert_allclose(outputs, results, rtol=1e-6, atol=1e-3)
|
||||
print("max abs diff:", abs(outputs - results).max())
|
||||
|
||||
|
||||
def start_worker(
|
||||
name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
|
||||
name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto, data_type: str
|
||||
):
|
||||
dist_name = name + "_dist"
|
||||
model = parallel_model(model, world_size, rank)
|
||||
|
@ -89,7 +90,7 @@ def start_worker(
|
|||
save_as_external_data=True,
|
||||
location=extern_path,
|
||||
)
|
||||
infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
|
||||
#infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
|
||||
runtime = backend.CudaRuntime(local_rank)
|
||||
# print("init comm")
|
||||
runtime.init_comm(
|
||||
|
@ -97,12 +98,12 @@ def start_worker(
|
|||
world_size,
|
||||
rank,
|
||||
)
|
||||
run_and_compare(name, model, runtime)
|
||||
run_and_compare(name, model, runtime, data_type)
|
||||
|
||||
|
||||
def start_single(name, model):
|
||||
def start_single(name, model, data_type):
|
||||
runtime = backend.CudaRuntime(0)
|
||||
run_and_compare(name, model, runtime)
|
||||
run_and_compare(name, model, runtime, data_type)
|
||||
|
||||
|
||||
def gen_standard(name, model, voc_size, bs, len):
|
||||
|
@ -117,8 +118,10 @@ def gen_standard(name, model, voc_size, bs, len):
|
|||
|
||||
|
||||
def main():
|
||||
nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args()
|
||||
|
||||
nnodes, nproc_per_node, name, model_path, bs, length, gen_std, data_type = parse_args()
|
||||
data_type = "default" if data_type == "fp32" else data_type
|
||||
if data_type != "tf32":
|
||||
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
|
||||
model = onnx.load(model_path)
|
||||
|
||||
# generate standard output
|
||||
|
@ -132,7 +135,7 @@ def main():
|
|||
# run single process.
|
||||
# use standalone process to isolate cuda.
|
||||
print("run model by single GPU.")
|
||||
p = mp.Process(target=start_single, args=(name, model))
|
||||
p = mp.Process(target=start_single, args=(name, model, data_type))
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
|
@ -142,7 +145,7 @@ def main():
|
|||
workers = [
|
||||
mp.Process(
|
||||
target=start_worker,
|
||||
args=(name, world_size, rank, rank % nproc_per_node, model),
|
||||
args=(name, world_size, rank, rank % nproc_per_node, model, data_type),
|
||||
)
|
||||
for rank in range(world_size)
|
||||
]
|
|
@ -0,0 +1,188 @@
|
|||
import argparse
|
||||
import torch
|
||||
from transformers import BertModel, BertConfig
|
||||
from transformers import GPT2Model, GPT2Config
|
||||
from transformers import OPTModel, OPTConfig
|
||||
import time
|
||||
import numpy as np
|
||||
import onnx
|
||||
import os
|
||||
from onnx.external_data_helper import convert_model_to_external_data
|
||||
from onnxsim import simplify
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Run pytorch gpt2/bert/opt and optionally export onnx.")
|
||||
parser.add_argument(
|
||||
"--model", type=str, choices=["gpt2", "bert", "opt"], required=True, help="model type"
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
|
||||
parser.add_argument("--length", type=int, default=1, help="sequence length.")
|
||||
parser.add_argument(
|
||||
"--export_onnx",
|
||||
type=str,
|
||||
nargs="?",
|
||||
default=None,
|
||||
const="./",
|
||||
help="whether and where to export onnx file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--type", type=str, choices=["fp32", "fp16", "tf32"], default="fp32", help="data type"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print("arg setting: ", args)
|
||||
return (
|
||||
args.model,
|
||||
args.batch_size,
|
||||
args.length,
|
||||
args.export_onnx,
|
||||
args.type,
|
||||
)
|
||||
|
||||
|
||||
def get_model(modelname):
|
||||
match modelname:
|
||||
case "bert":
|
||||
model = BertModel.from_pretrained("bert-base-uncased", add_pooling_layer=False, hidden_act="gelu_new") # erf is not impl by infini
|
||||
voc_size = BertConfig().vocab_size
|
||||
case "gpt2":
|
||||
model = GPT2Model.from_pretrained("gpt2")
|
||||
voc_size = GPT2Config().vocab_size
|
||||
case "opt":
|
||||
model = OPTModel.from_pretrained("./opt-125m")
|
||||
voc_size = OPTConfig().vocab_size
|
||||
case _:
|
||||
raise KeyError(modelname)
|
||||
|
||||
model = model.eval()
|
||||
return model, voc_size
|
||||
|
||||
def run_pytorch(torch_model, voc_size, batchsize, len):
|
||||
data = np.random.randint(0, voc_size, (batchsize, len), dtype=np.int32)
|
||||
np.save("test_inputs", data)
|
||||
inputs = torch.from_numpy(data).to("cuda")
|
||||
torch_model = torch_model.to("cuda")
|
||||
|
||||
n_iter = 20
|
||||
with torch.no_grad():
|
||||
for _ in range(10):
|
||||
outputs = torch_model(inputs)
|
||||
torch.cuda.synchronize()
|
||||
begin = time.time()
|
||||
with torch.no_grad():
|
||||
for _ in range(n_iter):
|
||||
torch.cuda.synchronize()
|
||||
outputs = torch_model(inputs)
|
||||
#
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
|
||||
avg_time = (end - begin) / n_iter
|
||||
outputs = outputs.last_hidden_state.to("cpu")
|
||||
print("outputs abs mean:", abs(np.array(outputs)).mean())
|
||||
print(f"average time: {avg_time}")
|
||||
torch.cuda.memory.empty_cache()
|
||||
np.save("test_results", np.array(outputs, dtype=np.float32))
|
||||
print("Save input & output as test_inputs.npy and test_results.npy")
|
||||
|
||||
|
||||
def export_onnx(model, data, path, extern=False):
|
||||
torch.onnx.export(model, data, path, verbose=False, do_constant_folding=True)
|
||||
onnx_model = onnx.load(path)
|
||||
onnx_model, check = simplify(onnx_model, skipped_optimizers=['eliminate_duplicate_initializer'])
|
||||
#onnx_model, check = simplify(onnx_model, skipped_optimizers=['fuse_qkv', 'eliminate_duplicate_initializer'])
|
||||
assert check
|
||||
add_value_info_for_constants(onnx_model)
|
||||
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
|
||||
if extern:
|
||||
extern_path = path.replace('.onnx', '.pb')
|
||||
if os.path.exists(extern_path):
|
||||
os.remove(extern_path)
|
||||
convert_model_to_external_data(
|
||||
onnx_model,
|
||||
all_tensors_to_one_file=True,
|
||||
location=extern_path,
|
||||
size_threshold=1024,
|
||||
convert_attribute=False,
|
||||
)
|
||||
onnx.save(onnx_model, path)
|
||||
|
||||
def add_value_info_for_constants(model : onnx.ModelProto):
|
||||
"""
|
||||
Currently onnx.shape_inference doesn't use the shape of initializers, so add
|
||||
that info explicitly as ValueInfoProtos.
|
||||
Mutates the model.
|
||||
Args:
|
||||
model: The ModelProto to update.
|
||||
"""
|
||||
# All (top-level) constants will have ValueInfos before IRv4 as they are all inputs
|
||||
if model.ir_version < 4:
|
||||
return
|
||||
|
||||
def add_const_value_infos_to_graph(graph : onnx.GraphProto):
|
||||
inputs = {i.name for i in graph.input}
|
||||
existing_info = {vi.name: vi for vi in graph.value_info}
|
||||
for init in graph.initializer:
|
||||
# Check it really is a constant, not an input
|
||||
if init.name in inputs:
|
||||
continue
|
||||
|
||||
# The details we want to add
|
||||
elem_type = init.data_type
|
||||
shape = init.dims
|
||||
|
||||
# Get existing or create new value info for this constant
|
||||
vi = existing_info.get(init.name)
|
||||
if vi is None:
|
||||
vi = graph.value_info.add()
|
||||
vi.name = init.name
|
||||
|
||||
# Even though it would be weird, we will not overwrite info even if it doesn't match
|
||||
tt = vi.type.tensor_type
|
||||
if tt.elem_type == onnx.TensorProto.UNDEFINED:
|
||||
tt.elem_type = elem_type
|
||||
if not tt.HasField("shape"):
|
||||
# Ensure we set an empty list if the const is scalar (zero dims)
|
||||
tt.shape.dim.extend([])
|
||||
for dim in shape:
|
||||
tt.shape.dim.add().dim_value = dim
|
||||
|
||||
# Handle subgraphs
|
||||
for node in graph.node:
|
||||
for attr in node.attribute:
|
||||
# Ref attrs refer to other attrs, so we don't need to do anything
|
||||
if attr.ref_attr_name != "":
|
||||
continue
|
||||
|
||||
if attr.type == onnx.AttributeProto.GRAPH:
|
||||
add_const_value_infos_to_graph(attr.g)
|
||||
if attr.type == onnx.AttributeProto.GRAPHS:
|
||||
for g in attr.graphs:
|
||||
add_const_value_infos_to_graph(g)
|
||||
|
||||
|
||||
return add_const_value_infos_to_graph(model.graph)
|
||||
|
||||
|
||||
def main():
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.backends.cudnn.allow_tf32 = False
|
||||
modelname, batchsize, seqlen, export_path, data_type = parse_args()
|
||||
if data_type == "tf32":
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
else:
|
||||
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
|
||||
|
||||
model, voc_size = get_model(modelname)
|
||||
if export_path is not None:
|
||||
filename = "{}_{}_{}.onnx".format(modelname, batchsize, seqlen)
|
||||
path = os.path.join(export_path, filename)
|
||||
param = torch.zeros((batchsize, seqlen), dtype=torch.int)
|
||||
export_onnx(model, param, path, True)
|
||||
|
||||
if data_type == "fp16":
|
||||
model = model.half()
|
||||
run_pytorch(model, voc_size, batchsize, seqlen)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@@ -0,0 +1,14 @@
export HF_ENDPOINT=https://hf-mirror.com

models=("bert" "gpt2" "llama")
batch_size=(1 32)
seq_len=(100 500)
nproc=(1 2 4)

for model in "${models[@]}"; do
  for bs in "${batch_size[@]}"; do
    for len in "${seq_len[@]}"; do
      python run_pytorch.py --model "$model" --batch_size "$bs" --length "$len" --export_onnx ../models/"$model" --export_only
    done
  done
done
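The exported files can be sanity-checked before launching the distributed runs. A minimal sketch, assuming the `../models/<model>/<model>_<bs>_<len>.onnx` layout used by the export loop above and by the kunlun launch scripts later in this diff:

```python
# Sketch: verify the ONNX files written by the export loop above.
# The path pattern is an assumption based on the scripts in this diff.
import os
import onnx

for model in ("bert", "gpt2", "llama"):
    for bs in (1, 32):
        for length in (100, 500):
            path = f"../models/{model}/{model}_{bs}_{length}.onnx"
            if not os.path.exists(path):
                continue
            onnx.checker.check_model(path)  # raises if the graph is malformed
            print(path, "ok")
```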
@ -0,0 +1,280 @@
|
|||
import sys
|
||||
sys.path.append('../')
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import multiprocessing as mp
|
||||
from pyinfinitensor.onnx import OnnxStub, backend
|
||||
import onnx
|
||||
from onnx.external_data_helper import convert_model_to_external_data
|
||||
from onnx.shape_inference import infer_shapes_path
|
||||
import numpy as np
|
||||
from parallel_opt import parallel_model
|
||||
from functools import wraps
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="launch distributed infinitensor")
|
||||
parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
|
||||
parser.add_argument(
|
||||
"--nproc_per_node", type=int, default=2, help="number of processes per node"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--name", type=str, choices=["gpt2", "bert", "llama"], help="name of model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model", type=str, default="", help="path to the ONNX model file."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen_std",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="whether to generate the standard results.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run_single",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="whether run model with single process with standard inputs"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_dir",
|
||||
default="./",
|
||||
help="path to save model input data"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--result_dir",
|
||||
default="./",
|
||||
help="path to save model standard output"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--internal_model_dir",
|
||||
default="./",
|
||||
help="path to save internal onnx model for parallel run"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# check path, mkdir if not exist
|
||||
check_exists(args.input_dir)
|
||||
check_exists(args.result_dir)
|
||||
check_exists(args.internal_model_dir)
|
||||
|
||||
print("arg setting: ", args)
|
||||
return (
|
||||
args.num_nodes,
|
||||
args.nproc_per_node,
|
||||
args.name,
|
||||
args.model,
|
||||
args.gen_std,
|
||||
args.run_single,
|
||||
args.input_dir,
|
||||
args.result_dir,
|
||||
args.internal_model_dir
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
utils function for this scripts
|
||||
"""
|
||||
def check_exists(path: str):
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
def np_assert(base, test, rtol=1e-2, atol=1e-1):
|
||||
# np.testing.assert_allclose(test, base, rtol, atol)
|
||||
print("max abs diff:", abs(base - test).max())
|
||||
|
||||
|
||||
"""
|
||||
Perf wrapper, run function n times
|
||||
then average
|
||||
"""
|
||||
def perf_it(n):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# warmup
|
||||
for _ in range(n):
|
||||
func(*args, **kwargs)
|
||||
|
||||
t_total = 0
|
||||
for _ in range(n):
|
||||
t0 = time.time()
|
||||
func(*args, **kwargs)
|
||||
t1 = time.time()
|
||||
t_total += t1 - t0
|
||||
avg_time = (t_total) / n
|
||||
print(f"Avg runtime of {n} time is {avg_time:.6f} seconds")
|
||||
return avg_time
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
"""
|
||||
Run InfiniTensor model with Standard input
|
||||
check=True: check with standard output gen by pytorch
|
||||
perf=True: run n times to get avg time
|
||||
"""
|
||||
def run_model(task_name,
|
||||
model,
|
||||
runtime,
|
||||
world_size=1,
|
||||
rank=0,
|
||||
n=10,
|
||||
check=True,
|
||||
perf=True):
|
||||
|
||||
stub = OnnxStub(model, runtime,
|
||||
use_naive_allocator=True \
|
||||
if task_name == "llama" else False)
|
||||
|
||||
# load in Onnx model inputs
|
||||
def load_inputs(stub: OnnxStub):
|
||||
# check exists
|
||||
inputs = []
|
||||
for i, (name, tensor) in enumerate(stub.inputs.items()):
|
||||
input_path = os.path.join(input_dir, \
|
||||
f"{task_name}_input_{i}.npy")
|
||||
print(input_path)
|
||||
if os.path.exists(input_path):
|
||||
input = np.load(input_path)
|
||||
else :
|
||||
raise KeyError(f"{i}-th input of the model does not exist")
|
||||
# check shape
|
||||
if all(x == y for x,y in zip(input.shape, tensor.shape())):
|
||||
tensor.copyin_numpy(input)
|
||||
else:
|
||||
tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
|
||||
|
||||
load_inputs(stub)
|
||||
# stub.tune()
|
||||
stub.run()
|
||||
time.sleep(0.01)
|
||||
output = next(stub.outputs.values().__iter__()).copyout_numpy()
|
||||
|
||||
# check output results with standard output
|
||||
if check:
|
||||
st_output_path = os.path.join(result_dir, \
|
||||
f"{task_name}_output.npy")
|
||||
assert os.path.exists(st_output_path), \
    "standard output does not exist"
|
||||
st_output = np.load(st_output_path)
|
||||
if np.isnan(output).any():
|
||||
print("Nan in output")
|
||||
exit()
|
||||
np_assert(st_output, output)
|
||||
|
||||
# perf
|
||||
if perf:
|
||||
@perf_it(n)
|
||||
def perf_infinitensor(stub: OnnxStub):
|
||||
stub.run()
|
||||
perf_infinitensor(stub)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
"""
|
||||
Start a worker in Parallel
|
||||
"""
|
||||
def start_worker(name: str,
|
||||
world_size: int,
|
||||
rank: int,
|
||||
local_rank: int,
|
||||
model: onnx.ModelProto):
|
||||
|
||||
dist_name = name + "_dist"
|
||||
# partial a onnx model to world_size part
|
||||
model = parallel_model(model, world_size, rank)
|
||||
onnx.save(model, os.path.join(internal_model_dir, \
|
||||
f"{dist_name}_rank{rank}.onnx"), save_as_external_data=True)
|
||||
runtime = backend.KUNLUNRuntime(local_rank)
|
||||
# print("init comm")
|
||||
runtime.init_comm(
|
||||
dist_name,
|
||||
world_size,
|
||||
rank,
|
||||
)
|
||||
run_model(name, model, runtime, world_size, rank)
|
||||
|
||||
|
||||
"""
|
||||
generate standard input/output with
|
||||
sigle card run
|
||||
"""
|
||||
def gen_standard(task_name: str, model: onnx.ModelProto):
|
||||
runtime = backend.KUNLUNRuntime(0)
|
||||
stub = OnnxStub(model, runtime)
|
||||
position_id = 0
|
||||
# generate random input for model
|
||||
for i, (name, tensor) in enumerate(stub.inputs.items()):
|
||||
input = tensor.copyout_numpy()
|
||||
if np.issubdtype(input.dtype, np.integer):
|
||||
if input.size == 1:
|
||||
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
|
||||
else:
|
||||
input = np.random.randint(0,2,size=input.shape, dtype=input.dtype)
|
||||
elif input.dtype == np.bool_:
|
||||
input = np.random.randint(0,2,size=input.shape) > 0
|
||||
else:
|
||||
if i == 0:
|
||||
input = np.ones(input.shape).astype(input.dtype)
|
||||
position_id = input.shape[-1] - 1
|
||||
else:
|
||||
input = np.random.rand(*input.shape).astype(input.dtype)
|
||||
tensor.copyin_numpy(input)
|
||||
np.save(os.path.join(input_dir, \
|
||||
f"{task_name}_input_{i}.npy"), input)
|
||||
stub.run()
|
||||
# print(stub.outputs)
|
||||
output = next(stub.outputs.values().__iter__()).copyout_numpy()
|
||||
if np.isnan(output).any():
|
||||
print("Nan in output")
|
||||
exit()
|
||||
np.save(os.path.join(result_dir, f"{task_name}_output.npy"), output)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
global input_dir, result_dir, internal_model_dir
|
||||
|
||||
nnodes, nproc_per_node, task_name, \
|
||||
model_path, gen_std, run_single, \
|
||||
input_dir, result_dir, internal_model_dir = parse_args()
|
||||
|
||||
# load input onnx model
|
||||
model = onnx.load(model_path)
|
||||
|
||||
# generate standard output
|
||||
if gen_std:
|
||||
print("Generate inputs and outputs.")
|
||||
gen_standard(task_name, model)
|
||||
return
|
||||
|
||||
if run_single:
|
||||
print("Run model by one GPU card.")
|
||||
runtime = backend.KUNLUNRuntime(0)
|
||||
run_model(task_name, model, runtime)
|
||||
return
|
||||
|
||||
# run distributed parallel.
|
||||
world_size = nnodes * nproc_per_node
|
||||
print(f"Run model by {world_size} GPU in parallel.")
|
||||
workers = [
|
||||
mp.Process(
|
||||
target=start_worker,
|
||||
args=(task_name, world_size, rank, rank % nproc_per_node, model),
|
||||
)
|
||||
for rank in range(world_size)
|
||||
]
|
||||
|
||||
for w in workers:
|
||||
w.start()
|
||||
|
||||
for w in workers:
|
||||
w.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,36 @@
|
|||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
# models=("bert" "gpt2" "llama")
|
||||
models=("bert" "gpt2")
|
||||
batch_size=(1 32)
|
||||
seq_len=(100 500)
|
||||
nproc=(1 2 4)
|
||||
|
||||
results_dir="results"
|
||||
|
||||
if [ -d "$results_dir" ]; then
|
||||
echo "directory ./$results_dir exists"
|
||||
else
|
||||
mkdir -p "$results_dir"
|
||||
echo "mkdir $results_dir, logs saved there"
|
||||
fi
|
||||
|
||||
|
||||
for model in "${models[@]}"; do
|
||||
for bs in "${batch_size[@]}"; do
|
||||
for len in "${seq_len[@]}"; do
|
||||
# run pytorch model
|
||||
echo "Run pytorch $model with batch_size=$bs length=$len ."
|
||||
python run_pytorch.py --model "$model" --batch_size "$bs" --length "$len" #> results/"$model"_"$bs"_"$len"_pytorch
|
||||
for n in "${nproc[@]}"; do
|
||||
# run infinitensor
|
||||
echo "Run $n parallel infinitensor "$model" with batch_size=$bs and length=$len ."
|
||||
python kunlun_launch.py --name "$model" --model ../models/"$model"/"$model"_"$bs"_"$len".onnx --nproc_per_node=$n # >> results/"$model"_"$bs"_"$len"_infini
|
||||
# delete internal files
|
||||
find ./ -type f -name "*.onnx" -delete
|
||||
find ./ -type f -name "*.pb" -delete
|
||||
done
|
||||
find ./ -type f -name "*.npy" -delete
|
||||
done
|
||||
done
|
||||
done
|
|
@ -0,0 +1,35 @@
|
|||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
# models=("bert" "gpt2" "llama")
|
||||
models=("llama")
|
||||
batch_size=(1 )
|
||||
seq_len=(100 500)
|
||||
nproc=(1 2 4)
|
||||
|
||||
results_dir="results"
|
||||
|
||||
if [ -d "$results_dir" ]; then
|
||||
echo "directory ./$results_dir exists"
|
||||
else
|
||||
mkdir -p "$results_dir"
|
||||
echo "mkdir $results_dir, logs saved there"
|
||||
fi
|
||||
|
||||
|
||||
for model in "${models[@]}"; do
|
||||
for bs in "${batch_size[@]}"; do
|
||||
for len in "${seq_len[@]}"; do
|
||||
echo "Run pytorch llama with batch_size="$bs" and length="$len""
|
||||
python run_pytorch.py --model "$model" --batch_size "$bs" --length "$len"
|
||||
for n in "${nproc[@]}"; do
|
||||
# run pytorch model
|
||||
echo "Run infinitensor llama with batch_size="$bs" and length="$len" and nproc="$n"."
|
||||
python kunlun_launch.py --name llama --model ../models/llama/llama_"$bs"_"$len"_fp32.onnx --nproc_per_node=$n
|
||||
# delete internal files
|
||||
find ./ -type f -name "*.onnx" -delete
|
||||
find ./ -type f -name "*0c" -delete
|
||||
done
|
||||
find ./ -type f -name "*.npy" -delete
|
||||
done
|
||||
done
|
||||
done
|
|
@ -0,0 +1,245 @@
|
|||
import argparse
|
||||
import torch
|
||||
from transformers import BertModel, BertConfig
|
||||
from transformers import GPT2Model, GPT2Config
|
||||
from transformers import OPTModel, OPTConfig
|
||||
from transformers import LlamaModel, LlamaConfig
|
||||
import time
|
||||
import numpy as np
|
||||
import onnx
|
||||
import os
|
||||
import sys
|
||||
from onnx.external_data_helper import convert_model_to_external_data
|
||||
from onnxsim import simplify
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.backends.cudnn.allow_tf32 = False
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Run pytorch gpt2/bert/opt and optionally export onnx.")
|
||||
parser.add_argument(
|
||||
"--model", type=str, choices=["gpt2", "bert", "opt", "llama"], required=True, help="model type"
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
|
||||
parser.add_argument("--length", type=int, default=1, help="sequence length.")
|
||||
parser.add_argument(
|
||||
"--export_onnx",
|
||||
type=str,
|
||||
nargs="?",
|
||||
default=None,
|
||||
const="./",
|
||||
help="whether and where to export onnx file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_dir",
|
||||
type=str,
|
||||
default="./",
|
||||
help="path to save pytorch model input data"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--result_dir",
|
||||
type=str,
|
||||
default="./",
|
||||
help="path to save pytorch model output data"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--export_only",
|
||||
action="store_true"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print("arg setting: ", args)
|
||||
return (
|
||||
args.model,
|
||||
args.batch_size,
|
||||
args.length,
|
||||
args.export_onnx,
|
||||
args.input_dir,
|
||||
args.result_dir,
|
||||
args.export_only
|
||||
)
|
||||
|
||||
|
||||
def get_model(modelname):
|
||||
if modelname == "bert":
|
||||
model = BertModel.from_pretrained("bert-base-uncased", add_pooling_layer=False, hidden_act="gelu_new") # erf is not impl by infini
|
||||
voc_size = BertConfig().vocab_size
|
||||
elif modelname == "gpt2":
|
||||
model = GPT2Model.from_pretrained("gpt2")
|
||||
voc_size = GPT2Config().vocab_size
|
||||
elif modelname == "opt":
|
||||
model = OPTModel.from_pretrained("./opt-125m")
|
||||
voc_size = OPTConfig().vocab_size
|
||||
elif modelname == "llama":
|
||||
model = LlamaModel.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
voc_size = LlamaConfig().vocab_size
|
||||
else :
|
||||
raise KeyError(modelname)
|
||||
|
||||
model = model.eval()
|
||||
return model, voc_size
|
||||
|
||||
def run_pytorch(torch_model, voc_size, batchsize, len, model_name):
|
||||
data = np.random.randint(0, voc_size, (batchsize, len), dtype=np.int32)
|
||||
np.save(os.path.join(input_dir, f"{model_name}_input_0.npy"), data)
|
||||
inputs = torch.from_numpy(data).to("cuda")
|
||||
torch_model = torch_model.to("cuda")
|
||||
|
||||
n_iter = 10
|
||||
with torch.no_grad():
|
||||
for _ in range(10):
|
||||
outputs = torch_model(inputs)
|
||||
torch.cuda.synchronize()
|
||||
begin = time.time()
|
||||
with torch.no_grad():
|
||||
for _ in range(n_iter):
|
||||
torch.cuda.synchronize()
|
||||
outputs = torch_model(inputs)
|
||||
#
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
|
||||
avg_time = (end - begin) / n_iter
|
||||
outputs = outputs.last_hidden_state.to("cpu")
|
||||
print("outputs abs mean:", abs(np.array(outputs)).mean())
|
||||
print(f"average time: {avg_time}")
|
||||
torch.cuda.memory.empty_cache()
|
||||
np.save(os.path.join(result_dir, f"{model_name}_output.npy"), \
|
||||
np.array(outputs))
|
||||
print(f"Save input & output as {model_name}_input_0.npy and {model_name}_output.npy")
|
||||
|
||||
|
||||
def export_onnx(model_name, model, data, path, extern=False):
|
||||
# torch.onnx.export(model, data, path, verbose=False, do_constant_folding=True)
|
||||
|
||||
if model_name != "llama":
|
||||
onnx_model = onnx.load(path)
|
||||
onnx_model, check = simplify(onnx_model,
|
||||
skipped_optimizers=['fuse_qkv', 'eliminate_duplicate_initializer'])
|
||||
# skipped_optimizers=['fuse_qkv'])
|
||||
assert check
|
||||
add_value_info_for_constants(onnx_model)
|
||||
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
|
||||
if extern:
|
||||
extern_path = path.replace('.onnx', '.pb')
|
||||
if os.path.exists(extern_path):
|
||||
os.remove(extern_path)
|
||||
convert_model_to_external_data(
|
||||
onnx_model,
|
||||
all_tensors_to_one_file=True,
|
||||
location=extern_path.split("/")[-1],
|
||||
size_threshold=1024,
|
||||
convert_attribute=False,
|
||||
)
|
||||
onnx.save(onnx_model, path)
|
||||
else:
|
||||
sys.path.append("onnxsim_large_model")
|
||||
from onnx_utils import set_onnx_input_shape
|
||||
from compress_model import SIZE_1MB, compress_onnx_model, uncompress_onnx_model
|
||||
|
||||
in_model_path = path
|
||||
out_model_path = in_model_path[:-5] + ".sim.onnx"
|
||||
|
||||
onnx_model = onnx.load(in_model_path)
|
||||
print(f"load model from {in_model_path} success")
|
||||
|
||||
size_th_bytes = 1024 * 1024
|
||||
onnx_model, removed_inits = compress_onnx_model(onnx_model, size_th_bytes=size_th_bytes)
|
||||
print("compress model success")
|
||||
|
||||
onnx_model = set_onnx_input_shape(onnx_model, "")
|
||||
tensor_size_threshold = f"1024KB"
|
||||
skipped_optimizers = []
|
||||
skipped_optimizers.append("eliminate_duplicate_initializer")
|
||||
onnx_model, check = simplify(onnx_model, skipped_optimizers=skipped_optimizers,
|
||||
tensor_size_threshold=tensor_size_threshold)
|
||||
if not check:
|
||||
raise ValueError(f"simplify compressed model {in_model_path} failed")
|
||||
|
||||
print(f"simplify model success")
|
||||
|
||||
onnx_model = uncompress_onnx_model(onnx_model, removed_inits)
|
||||
print(f"uncompress model success")
|
||||
|
||||
add_value_info_for_constants(onnx_model)
|
||||
|
||||
onnx.save(onnx_model, out_model_path, save_as_external_data=True)
|
||||
|
||||
|
||||
def add_value_info_for_constants(model : onnx.ModelProto):
|
||||
"""
|
||||
Currently onnx.shape_inference doesn't use the shape of initializers, so add
|
||||
that info explicitly as ValueInfoProtos.
|
||||
Mutates the model.
|
||||
Args:
|
||||
model: The ModelProto to update.
|
||||
"""
|
||||
# All (top-level) constants will have ValueInfos before IRv4 as they are all inputs
|
||||
if model.ir_version < 4:
|
||||
return
|
||||
|
||||
def add_const_value_infos_to_graph(graph : onnx.GraphProto):
|
||||
inputs = {i.name for i in graph.input}
|
||||
existing_info = {vi.name: vi for vi in graph.value_info}
|
||||
for init in graph.initializer:
|
||||
# Check it really is a constant, not an input
|
||||
if init.name in inputs:
|
||||
continue
|
||||
|
||||
# The details we want to add
|
||||
elem_type = init.data_type
|
||||
shape = init.dims
|
||||
|
||||
# Get existing or create new value info for this constant
|
||||
vi = existing_info.get(init.name)
|
||||
if vi is None:
|
||||
vi = graph.value_info.add()
|
||||
vi.name = init.name
|
||||
|
||||
# Even though it would be weird, we will not overwrite info even if it doesn't match
|
||||
tt = vi.type.tensor_type
|
||||
if tt.elem_type == onnx.TensorProto.UNDEFINED:
|
||||
tt.elem_type = elem_type
|
||||
if not tt.HasField("shape"):
|
||||
# Ensure we set an empty list if the const is scalar (zero dims)
|
||||
tt.shape.dim.extend([])
|
||||
for dim in shape:
|
||||
tt.shape.dim.add().dim_value = dim
|
||||
|
||||
# Handle subgraphs
|
||||
for node in graph.node:
|
||||
for attr in node.attribute:
|
||||
# Ref attrs refer to other attrs, so we don't need to do anything
|
||||
if attr.ref_attr_name != "":
|
||||
continue
|
||||
|
||||
if attr.type == onnx.AttributeProto.GRAPH:
|
||||
add_const_value_infos_to_graph(attr.g)
|
||||
if attr.type == onnx.AttributeProto.GRAPHS:
|
||||
for g in attr.graphs:
|
||||
add_const_value_infos_to_graph(g)
|
||||
|
||||
|
||||
return add_const_value_infos_to_graph(model.graph)
|
||||
|
||||
|
||||
def main():
|
||||
global input_dir, result_dir
|
||||
|
||||
modelname, batchsize, seqlen, \
|
||||
export_path, input_dir, result_dir, export_only = parse_args()
|
||||
|
||||
model, voc_size = get_model(modelname) # pytorch model
|
||||
|
||||
if export_path is not None:
|
||||
os.makedirs(export_path, exist_ok=True)
|
||||
filename = "{}_{}_{}.onnx".format(modelname, batchsize, seqlen)
|
||||
path = os.path.join(export_path, filename)
|
||||
param = torch.zeros((batchsize, seqlen), dtype=torch.int)
|
||||
export_onnx(modelname, model, param, path, True) # export pytorch model to onnx model
|
||||
if export_only:
|
||||
return
|
||||
|
||||
run_pytorch(model, voc_size, batchsize, seqlen, modelname)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,213 +0,0 @@
|
|||
import argparse
|
||||
import os
|
||||
import time
|
||||
import multiprocessing as mp
|
||||
from pyinfinitensor.onnx import OnnxStub, backend
|
||||
import onnx
|
||||
from onnx.external_data_helper import convert_model_to_external_data
|
||||
from onnx.shape_inference import infer_shapes_path
|
||||
import numpy as np
|
||||
from parallel_opt import parallel_model
|
||||
|
||||
st_input_dir = "standard/inputs/"
|
||||
st_output_dir = "standard/outputs/"
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="launch distributed infinitensor")
|
||||
parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
|
||||
parser.add_argument(
|
||||
"--nproc_per_node", type=int, default=2, help="number of processes per node"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--name", type=str, default="test", help="name of this instance."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model", type=str, default="/data1/shared/panzezhong/llama/fp32/my_llama_fp32.sim.onnx", help="path to the ONNX model file."
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
|
||||
parser.add_argument("--length", type=int, default=1, help="sequence length.")
|
||||
parser.add_argument(
|
||||
"--gen_std",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="whether to generate the standard results.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--run_single",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="whether run model with single process with standard inputs"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print("arg setting: ", args)
|
||||
return (
|
||||
args.num_nodes,
|
||||
args.nproc_per_node,
|
||||
args.name,
|
||||
args.model,
|
||||
args.batch_size,
|
        args.length,
        args.gen_std,
        args.run_single
    )


def run_model(model, runtime, world_size=1, rank=0, n=10):
    stub = OnnxStub(model, runtime)
    load_inputs(stub, world_size, rank)
    # stub.tune()
    stub.run()
    # get outputs
    time.sleep(0.01)
    outputs = next(stub.outputs.values().__iter__()).copyout_numpy()

    # bench
    begin = time.time()
    for _ in range(n):
        stub.run()
    end = time.time()
    avg_time = (end - begin) / n
    print(f"average time: {avg_time}")
    return outputs


def run_and_compare(name, model, runtime, world_size=1, rank=0):
    results = np.load(os.path.join(st_output_dir, f"output.npy"))
    outputs = run_model(model, runtime, world_size, rank)
    print(outputs[:100])
    if np.isnan(outputs).any():
        print("Nan in output")
    print("answer argmax:", np.argmax(results))
    print("output argmax:", np.argmax(outputs))
    # np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-3)
    getDiff(results, outputs)


def start_worker(
    name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
):
    dist_name = name + "_dist"
    model = parallel_model(model, world_size, rank)
    extern_path = f"./{dist_name}_rank{rank}.pb"
    if os.path.exists(extern_path):
        os.remove(extern_path)
    onnx.save_model(
        model,
        f"./{dist_name}_rank{rank}.onnx",
        save_as_external_data=True,
        location=extern_path,
    )
    infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
    runtime = backend.KUNLUNRuntime(local_rank)
    # print("init comm")
    runtime.init_comm(
        dist_name,
        world_size,
        rank,
    )
    run_and_compare(name, model, runtime, world_size, rank)


def start_single(name, model):
    runtime = backend.KUNLUNRuntime(0)
    run_and_compare(name, model, runtime)


def generate_input_output(model):
    runtime = backend.KUNLUNRuntime(0)
    stub = OnnxStub(model, runtime)
    position_id = 0
    for i, (name, tensor) in enumerate(stub.inputs.items()):
        input = tensor.copyout_numpy()
        if np.issubdtype(input.dtype, np.integer):
            if input.size == 1:
                # input = np.array([position_id])
                input = np.random.randint(0, 2, size=input.shape, dtype=input.dtype)
            else:
                input = np.random.randint(0, 2, size=input.shape, dtype=input.dtype)
        elif input.dtype == np.bool_:
            input = np.random.randint(0, 2, size=input.shape) > 0
        else:
            if i == 0:
                input = np.ones(input.shape).astype(input.dtype)
                position_id = input.shape[-1] - 1
            else:
                input = np.random.rand(*input.shape).astype(input.dtype)
        tensor.copyin_numpy(input)
        np.save(os.path.join(st_input_dir, f"input_{i}"), input)
    stub.run()
    # print(stub.outputs)
    time.sleep(0.01)
    output = next(stub.outputs.values().__iter__()).copyout_numpy()
    print(output[:100])
    if np.isnan(output).any():
        print("Nan in output")
    np.save(os.path.join(st_output_dir, f"output"), output)


def load_inputs(stub, world_size=1, rank=0):
    for i, (name, tensor) in enumerate(stub.inputs.items()):
        input = np.load(os.path.join(st_input_dir, f"input_{i}.npy"))
        if all(x == y for x, y in zip(input.shape, tensor.shape())):
            tensor.copyin_numpy(input)
        else:
            tensor.copyin_numpy(np.hsplit(input, world_size)[rank])


def getDiff(base, test):
    absolute_diff = np.abs(np.subtract(base, test))
    max_absolute_diff = np.max(absolute_diff)

    baseCopy = base.astype(np.float64).ravel()
    testCopy = test.astype(np.float64).ravel()
    upValue = np.sum(np.abs(baseCopy - testCopy))
    downValue = np.sum(np.abs(baseCopy)) + np.float64(1e-9)
    max_relative_diff = upValue / downValue
    print(f"Max absolute difference: {max_absolute_diff}\nMax relative difference: {max_relative_diff}")

    return max_absolute_diff, max_relative_diff


def main():
    nnodes, nproc_per_node, name, model_path, bs, length, gen_std, run_single = parse_args()

    model = onnx.load(model_path)

    # generate standard output
    if gen_std:
        print("Generate inputs and outputs.")
        p = mp.Process(target=generate_input_output, args=[model])
        p.start()
        p.join()
        return

    # run a single process.
    # use a standalone process to isolate the device runtime.
    if run_single:
        print("Run model on a single device.")
        p = mp.Process(target=start_single, args=(name, model))
        p.start()
        p.join()
        return

    # run distributed parallel.
    world_size = nnodes * nproc_per_node
    print(f"Run model on {world_size} devices in parallel.")
    workers = [
        mp.Process(
            target=start_worker,
            args=(name, world_size, rank, rank % nproc_per_node, model),
        )
        for rank in range(world_size)
    ]

    for w in workers:
        w.start()

    for w in workers:
        w.join()


if __name__ == "__main__":
    main()
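For a quick sanity check of the sharding rule used by `load_inputs` above, here is a minimal numpy sketch; the shapes and the world size are made-up values rather than anything produced by the script:

```python
import numpy as np

# Hypothetical full input saved by generate_input_output: batch 1, eight features.
full_input = np.arange(8, dtype=np.float32).reshape(1, 8)
world_size = 2

# load_inputs falls back to np.hsplit when the stored shape does not match the
# sharded tensor shape, so rank r receives the r-th column block.
shards = np.hsplit(full_input, world_size)
for rank, shard in enumerate(shards):
    print(f"rank {rank} gets shape {shard.shape}: {shard}")

# Concatenating the shards along the last axis recovers the original tensor.
assert np.array_equal(np.concatenate(shards, axis=1), full_input)
```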
|
|
@ -0,0 +1 @@
Subproject commit cbcf3fbf985a00494b0f136c92eaccd42031bf65
|
|
@ -110,7 +110,6 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
            s_dim = 0
        elif in_plc.dim == 2:
            s_dim = 1

        assert s_dim != -1
        assert out_dims[s_dim] % tp_world_size == 0, out_dims
        out_dims[s_dim] //= tp_world_size

@ -244,5 +243,5 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
    if tt.HasField("shape"):
        tt.ClearField("shape")
    model = helper.make_model(graph)
    model = onnx.shape_inference.infer_shapes(model)
    #model = onnx.shape_inference.infer_shapes(model)
    return model
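The first hunk above shrinks the sharded output dimension by the tensor-parallel world size; the following sketch shows only that bookkeeping step, with an illustrative helper and shapes that are not part of the module:

```python
# Illustrative only: mimic how a sharded output dimension is scaled down.
def shard_out_dims(out_dims, s_dim, tp_world_size):
    assert s_dim != -1
    assert out_dims[s_dim] % tp_world_size == 0, out_dims
    out_dims = list(out_dims)
    out_dims[s_dim] //= tp_world_size
    return out_dims

# A (batch, seq, hidden) output split across 4 ranks on the hidden dimension.
print(shard_out_dims([1, 1, 4096], s_dim=2, tp_world_size=4))  # [1, 1, 1024]
```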
|
||||
|
|
|
@ -30,12 +30,14 @@ class GraphHandlerObj {
                   int pw, int sh, int sw, int dh, int dw, int oph,
                   int opw);
    Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
                  Tensor bias, ActType act);
                  Tensor bias, ActType act,
                  std::string matmul_compute_type = "default");
    Tensor batchNormalization(Tensor input, Tensor output, Tensor mean,
                              Tensor var, Tensor scale, Tensor bias,
                              float momentum, float eps, bool training);
    Tensor layerNormalization(Tensor input, Tensor scale, Tensor output,
                              Tensor bias, float eps, int axis, int stash_type);
    Tensor rmsNorm(Tensor input, Tensor weight, Tensor output);

    Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
                   int ph, int pw, int sh, int sw, int ceilMode);
|
||||
|
|
|
@ -5,8 +5,8 @@
#include "utils/operator_utils.h"
#include <functional>
#include <nlohmann/json.hpp>
using json = nlohmann::json;
namespace infini {
using json = nlohmann::json;

class RuntimeObj; // Forward declaration for Kernel::compute
|
||||
|
||||
|
|
|
@ -156,8 +156,9 @@ struct OpType {
        Resize,
        ReverseSequence,
        RoiAlign,
        RoPE, // Fusion
        Round, // Unary
        RoPE,    // Fusion
        Round,   // Unary
        RMSNorm, // Fusion
        STFT,
        Scan,
        Scatter,
|
||||
|
|
|
@ -2,8 +2,8 @@
#include "core/graph.h"
#include "core/kernel.h"
#include <nlohmann/json_fwd.hpp>
using json = nlohmann::json;
namespace infini {
using json = nlohmann::json;

class PerfEngine {
  public:
|
||||
|
|
|
@ -13,4 +13,8 @@ void pow_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
void less_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
                 int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1,
                 int c2, int c3);

void div_const_kernel(int dType, void *a, void *b, void *c, size_t n);

void pow_const_kernel(int dType, void *a, void *b, void *c, size_t n);
}; // namespace infini
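The two `*_const_kernel` entry points cover the case where the second operand is a rank-0 tensor; in numpy terms the fast path computes the following (a hedged reference, not the CUDA implementation itself):

```python
import numpy as np

a = np.array([1.0, 4.0, 9.0], dtype=np.float32)
b = np.float32(2.0)  # a rank-0 operand selects the *_const_kernel path

div_ref = a / b           # what div_const_kernel produces elementwise
pow_ref = np.power(a, b)  # what pow_const_kernel produces elementwise
print(div_ref, pow_ref)
```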
|
||||
|
|
|
@ -7,4 +7,6 @@ void expandKernel(int dType, void *input, void *output, int nDims,
                  int outputsize, SmallArray inputShape,
                  SmallArray outputShape);

void expandRowKernel(int dType, void *input, void *output, int n_rows,
                     int row_len);
}; // namespace infini
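`expandRowKernel` replicates a single row across many output rows, which is how a linear-layer bias gets broadcast; a numpy view of that result, with illustrative sizes (the row length being a multiple of 32 matches the kernel's stated requirement):

```python
import numpy as np

row_len, n_rows = 64, 3  # row_len is a multiple of 32, as the kernel expects
bias_row = np.arange(row_len, dtype=np.float32)

# expandRowKernel writes the same row into every output row.
expanded = np.broadcast_to(bias_row, (n_rows, row_len)).copy()
assert expanded.shape == (n_rows, row_len)
assert np.array_equal(expanded[1], bias_row)
```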
|
||||
|
|
|
@ -0,0 +1,10 @@
#pragma once

#include "operators/rms_norm.h"

namespace infini {

void rmsnorm_kernel(int dType, void *input, void *weight, void *output,
                    int num_tokens, int hidden_size);

}; // namespace infini
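The kernel declared here normalizes each token by the root mean square of its hidden vector and then applies a per-channel weight; a numpy reference of that computation, assuming the 1e-5 epsilon used by the CUDA kernel later in this diff:

```python
import numpy as np

def rmsnorm_ref(x: np.ndarray, weight: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # x: (num_tokens, hidden_size), weight: (hidden_size,)
    variance = np.mean(x.astype(np.float64) ** 2, axis=-1, keepdims=True)
    return (x * (1.0 / np.sqrt(variance + eps)) * weight).astype(x.dtype)

x = np.random.rand(4, 8).astype(np.float32)
w = np.ones(8, dtype=np.float32)
print(rmsnorm_ref(x, w).shape)  # (4, 8)
```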
|
|
@ -21,7 +21,7 @@ class KUNLUNRuntimeObj : public RuntimeObj {
        ctx = xdnn::create_context();
        // 10GB for Longformer
        // size_t longformerNum = 3lu * (1 << 30);
        size_t workspaceSize = 3llu << 30; // 3 GB
        size_t workspaceSize = 2llu << 30; // 2 GB
        KUNLUNPtr wkspacePtr = alloc(workspaceSize);
        workspace =
            make_ref<WorkspaceObj<KUNLUNPtr>>(wkspacePtr, workspaceSize);
|
||||
|
@ -42,7 +42,7 @@ class KUNLUNRuntimeObj : public RuntimeObj {
    KUNLUNPtr alloc(size_t size) override {
        void *ptr;
        checkKUNLUNError(
            xpu_malloc_ex((void **)&ptr, size, XPUMemoryKind::XPU_MEM_MAIN));
            xpu_malloc((void **)&ptr, size, XPUMemoryKind::XPU_MEM_HBM));
        return ptr;
    }
    void dealloc(void *ptr) override { xpu_free(ptr); }
|
||||
|
|
|
@ -34,8 +34,8 @@ class XcclCommunicatorObj final : public CommunicatorObj {
            auto begin = std::chrono::steady_clock::now();
            while (!std::filesystem::exists(filePath)) {
                auto now = std::chrono::steady_clock::now();
                _IT_ASSERT_2(now < begin + std::chrono::seconds(10),
                             "time limit (10s) exceeded.");
                _IT_ASSERT_2(now < begin + std::chrono::seconds(100),
                             "time limit (100s) exceeded.");
                std::this_thread::sleep_for(std::chrono::milliseconds(100));
            }
            std::ifstream ifs(filePath, std::ios::binary);
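The loop above polls for a rendezvous file that another rank writes, now with a 100 s budget; the same waiting logic sketched in Python, with placeholder file name and timeout values:

```python
import os
import time

def wait_for_rendezvous(file_path: str, timeout_s: float = 100.0, poll_s: float = 0.1):
    # Poll until another process has created the file, or give up after timeout_s.
    deadline = time.monotonic() + timeout_s
    while not os.path.exists(file_path):
        if time.monotonic() >= deadline:
            raise TimeoutError(f"time limit ({timeout_s}s) exceeded waiting for {file_path}")
        time.sleep(poll_s)
```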
|
||||
|
|
|
@ -17,6 +17,9 @@ class MatmulObj : public OperatorObj {
    // Auxiliary attributes which are not a part of operator attributes.
    int b, m, n, k;

    // Specifies the data precision for the matrix multiply.
    std::string computeType = "default";

  public:
    /**
     * @brief Matmul operator with batch broadcast and tensor transpose

@ -38,10 +41,11 @@ class MatmulObj : public OperatorObj {
     * @param transB If matrix B should be transposed when computing.
     * @param bias The bias tensor.
     * @param act The activation function.
     * @param computeType Specifies the data precision for the matrix multiply.
     */
    MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C,
              bool transA = false, bool transB = false, Tensor bias = nullptr,
              ActType act = ActType::None);
              ActType act = ActType::None, std::string computeType = "default");
    OP_CLONE(MatmulObj);

    std::string toString() const override;

@ -60,6 +64,7 @@ class MatmulObj : public OperatorObj {
    int getN() const { return n; }
    int getK() const { return k; }
    auto getBMNK() const { return tuple{b, m, n, k}; }
    std::string getComputeType() const { return computeType; }

  private:
    vector<int> getWorkloadVector() const override;
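The new `computeType` attribute lets callers trade matmul precision for speed; a small illustrative helper (not part of the codebase) listing the strings this diff recognizes:

```python
# Illustrative helper, not part of the repository.
SUPPORTED_COMPUTE_TYPES = {"default", "tf32", "bf16", "fp16"}

def choose_compute_type(prefer_speed: bool) -> str:
    # "default" lets the backend derive the compute type from the tensor dtype;
    # the other strings request reduced-precision accumulation paths.
    return "tf32" if prefer_speed else "default"

assert choose_compute_type(True) in SUPPORTED_COMPUTE_TYPES
```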
|
||||
|
|
|
@ -0,0 +1,34 @@
#pragma once
#include "core/operator.h"

namespace infini {
/**
 * @brief Fused RMSNorm Operator
 *
 */
class RMSNormObj : public OperatorObj {
    int dim;

  public:
    /**
     * @brief Construct a new RMSNorm object.
     *
     * @param graph The computation graph that this operator belongs to.
     * @param input The input tensor.
     * @param output The output tensor.
     */
    RMSNormObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output);
    OP_CLONE(RMSNormObj);

    optional<vector<Shape>> inferShape(const TensorVec &inputs) override;

    std::string toString() const override;
    int numInputs() const override { return 2; }
    int numOutputs() const override { return 1; }
    int getDim() const { return dim; }

  private:
    vector<int> getWorkloadVector() const override;
    vector<int> getOpAttrVector() const override;
};
} // namespace infini
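Since the importer later in this diff accepts a custom `RMSNorm` node with a data input and a weight input, such a node can be built with the standard onnx helpers; a minimal sketch with made-up tensor names and shapes:

```python
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

hidden = 8
x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [4, hidden])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [4, hidden])
weight = numpy_helper.from_array(np.ones(hidden, dtype=np.float32), "weight")

node = helper.make_node("RMSNorm", inputs=["x", "weight"], outputs=["y"])
graph = helper.make_graph([node], "rmsnorm_demo", [x], [y], initializer=[weight])
model = helper.make_model(graph)
# Note: "RMSNorm" is not a standard ONNX op, so onnx.checker would reject it;
# it is only meant to be consumed by this repository's OnnxStub importer.
```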
|
|
@ -23,12 +23,13 @@ from onnx.checker import (
    ValidationError,
)
from onnx.shape_inference import infer_shapes
from onnx.numpy_helper import to_array
from onnx.numpy_helper import to_array, from_array
from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
from functools import reduce
from onnxsim import simplify
import copy
import warnings
import numpy as np


class OnnxStub:

@ -37,7 +38,13 @@ class OnnxStub:
    It can be generated from an Onnx model object.
    """

    def __init__(self, model: ModelProto, runtime, use_naive_allocator: bool = False):
    def __init__(
        self,
        model: ModelProto,
        runtime,
        use_naive_allocator: bool = False,
        matmul_compute_type: str = "default",
    ):
        # We use some user-defined operators for distributed inference
        try:
            # onnx simplifier performs inplace simplify
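With the widened constructor, a caller can request a faster matmul precision when building the stub; a usage sketch in which the model path is a placeholder and the runtime is simply the one used elsewhere in this diff:

```python
import onnx
from pyinfinitensor.onnx import OnnxStub, backend

model = onnx.load("/path/to/model.onnx")  # placeholder path
runtime = backend.KUNLUNRuntime(0)        # any device runtime exposed by `backend`
# "tf32" is honored by the CUDA and BANG matmul kernels touched in this diff;
# "default" keeps the compute type derived from the tensor dtype.
stub = OnnxStub(model, runtime, matmul_compute_type="tf32")
# feed inputs via stub.inputs[...].copyin_numpy(...) before running
stub.run()
```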
|
||||
|
@ -105,12 +112,6 @@ class OnnxStub:
|
|||
)
|
||||
tensors[input.name].set_input()
|
||||
|
||||
for output in model.graph.output:
|
||||
dims = _take_shape_dim(output.type.tensor_type.shape)
|
||||
tensors[output.name] = self.handler.tensor(
|
||||
dims, output.type.tensor_type.elem_type
|
||||
)
|
||||
tensors[output.name].set_output()
|
||||
|
||||
for node_idx in sorted_nodes:
|
||||
node = model.graph.node[node_idx]
|
||||
|
@ -215,6 +216,7 @@ class OnnxStub:
|
|||
False,
|
||||
None,
|
||||
backend.ActType.Linear,
|
||||
matmul_compute_type,
|
||||
)
|
||||
elif node.op_type == "Gemm":
|
||||
attributes = _parse_attribute(
|
||||
|
@ -234,6 +236,7 @@ class OnnxStub:
|
|||
transB == 1,
|
||||
tensors[node.input[2]] if len(node.input) > 2 else None,
|
||||
backend.ActType.Linear,
|
||||
matmul_compute_type,
|
||||
)
|
||||
elif node.op_type == "BatchNormalization":
|
||||
(input, mean, var, scale, bias) = (
|
||||
|
@ -277,6 +280,12 @@ class OnnxStub:
|
|||
axis,
|
||||
stash_type,
|
||||
)
|
||||
elif node.op_type == "RMSNorm":
|
||||
tensors[node.output[0]] = self.handler.RMSNorm(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "MaxPool":
|
||||
attributes = _parse_attribute(
|
||||
node,
|
||||
|
@ -618,7 +627,7 @@ class OnnxStub:
|
|||
keep_aspect_ratio_policy,
|
||||
nearest_mode,
|
||||
coordinate_transformation_mode,
|
||||
)
|
||||
)
|
||||
elif node.op_type == "Squeeze":
|
||||
axes = (
|
||||
_parse_data(data[node.input[1]])
|
||||
|
@ -933,13 +942,32 @@ class OnnxStub:
|
|||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Where":
|
||||
## If Y is single -inf, treat Where as Add
|
||||
## TODO: deal with cases where Y is single inf or 0
|
||||
if node.input[0] in data and node.input[2] in data:
|
||||
where_condition = to_array(data[node.input[0]])
|
||||
where_alt = to_array(data[node.input[2]])
|
||||
if where_alt.size == 1:
|
||||
if np.isneginf(where_alt) or np.all(where_alt < -3e38):
|
||||
node.input[0] = node.input[0] + "_alt"
|
||||
if node.input[0] not in data:
|
||||
where_value = np.where(where_condition, 0, -np.inf).astype(where_alt.dtype)
|
||||
data[node.input[0]] = from_array(where_value, node.input[0])
|
||||
tensors[node.input[0]] = self.handler.tensor(list(where_value.shape), data[node.input[0]].data_type)
|
||||
tensors[node.input[0]].set_weight()
|
||||
tensors[node.output[0]] = self.handler.add(
|
||||
tensors[node.input[1]],
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
continue
|
||||
tensors[node.output[0]] = self.handler.where(
|
||||
tensors[node.input[1]],
|
||||
tensors[node.input[2]],
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Constant":
|
||||
elif node.op_type in ["Constant", "ConstantOfShape"]:
|
||||
output_name = node.output[0]
|
||||
attributes = _parse_attribute(node)
|
||||
tensor = attributes["value"]
|
||||
|
@ -962,10 +990,12 @@ class OnnxStub:
|
|||
beta,
|
||||
bias,
|
||||
size,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||
|
||||
for output in model.graph.output:
|
||||
tensors[output.name].set_output()
|
||||
################################
|
||||
# Allocate memory space for data
|
||||
################################
|
||||
|
@ -1247,7 +1277,7 @@ class OnnxStub:
|
|||
axes,
|
||||
)
|
||||
)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||
elif ty == backend.OpTypeId.Concat:
|
||||
axis = backend.concat_axis_of(op)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "operators/reduce.h"
|
||||
#include "operators/reshape.h"
|
||||
#include "operators/resize.h"
|
||||
#include "operators/rms_norm.h"
|
||||
#include "operators/rope.h"
|
||||
#include "operators/send.h"
|
||||
#include "operators/slice.h"
|
||||
|
@ -73,15 +74,17 @@ Tensor GraphHandlerObj::convTransposed2d(Tensor input, Tensor weight,
|
|||
}
|
||||
|
||||
Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
|
||||
bool transB, Tensor bias, ActType act) {
|
||||
bool transB, Tensor bias, ActType act,
|
||||
std::string matmul_compute_type) {
|
||||
if (y) {
|
||||
g->addOpWithOutputs<MatmulObj>(std::move(a), std::move(b), y, transA,
|
||||
transB, std::move(bias), act);
|
||||
transB, std::move(bias), act,
|
||||
matmul_compute_type);
|
||||
return y;
|
||||
} else {
|
||||
return g
|
||||
->addOp<MatmulObj>(std::move(a), std::move(b), y, transA, transB,
|
||||
std::move(bias), act)
|
||||
std::move(bias), act, matmul_compute_type)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
@ -122,6 +125,17 @@ Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale,
|
|||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::rmsNorm(Tensor input, Tensor weight, Tensor output) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<RMSNormObj>(std::move(input), std::move(weight),
|
||||
output);
|
||||
return output;
|
||||
} else {
|
||||
return g->addOp<RMSNormObj>(std::move(input), std::move(weight), output)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw,
|
||||
int dh, int dw, int ph, int pw, int sh, int sw,
|
||||
int ceilMode) {
|
||||
|
|
|
@ -506,6 +506,7 @@ void init_graph_builder(py::module &m) {
|
|||
.def("matmul", &Handler::matmul, policy::move)
|
||||
.def("batchNormalization", &Handler::batchNormalization, policy::move)
|
||||
.def("layerNormalization", &Handler::layerNormalization, policy::move)
|
||||
.def("RMSNorm", &Handler::rmsNorm, policy::move)
|
||||
.def("maxPool", &Handler::maxPool, policy::move)
|
||||
.def("avgPool", &Handler::avgPool, policy::move)
|
||||
.def("add", &Handler::add, policy::move)
|
||||
|
|
|
@ -199,6 +199,24 @@ class CastCnnl : public BangKernelWithoutConfig {
|
|||
dim.data()));
|
||||
NlCastType = CNNL_CAST_UINT32_TO_INT64;
|
||||
break;
|
||||
case CastType::Float162Float:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
|
||||
CNNL_DTYPE_HALF, dim.size(),
|
||||
dim.data()));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
|
||||
CNNL_DTYPE_FLOAT, dim.size(),
|
||||
dim.data()));
|
||||
NlCastType = CNNL_CAST_HALF_TO_FLOAT;
|
||||
break;
|
||||
case CastType::Float2Float16:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
|
||||
CNNL_DTYPE_FLOAT, dim.size(),
|
||||
dim.data()));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
|
||||
CNNL_DTYPE_HALF, dim.size(),
|
||||
dim.data()));
|
||||
NlCastType = CNNL_CAST_FLOAT_TO_HALF;
|
||||
break;
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
|
|
|
@ -19,14 +19,16 @@ class LayerNormCnnl : public BangKernelWithoutConfig {
|
|||
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto inDims = op->getInputs(0)->getDims();
|
||||
auto fiterDims = op->getInputs(1)->getDims();
|
||||
auto outDims = op->getOutput()->getDims();
|
||||
auto fiterDims = op->getOutput(1)->getDims();
|
||||
|
||||
float eps = op->getEps();
|
||||
const int axis = op->getAxis();
|
||||
|
||||
cnnlTensorDescriptor_t inDesc, fiterDesc, outDesc;
|
||||
Shape outMeanDims(outDims);
|
||||
outMeanDims.erase(outMeanDims.begin() + axis);
|
||||
|
||||
cnnlTensorDescriptor_t inDesc, fiterDesc, outDesc, outMeanDesc;
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
inDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
|
||||
|
@ -39,15 +41,23 @@ class LayerNormCnnl : public BangKernelWithoutConfig {
|
|||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
outDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
|
||||
outDims.size(), outDims.data()));
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&outMeanDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
outMeanDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
|
||||
outMeanDims.size(), outMeanDims.data()));
|
||||
size_t wsSize;
|
||||
cnnlGetLayerNormOpWorkspaceSize(context->cnnlHandle(), axis, inDesc,
|
||||
&wsSize);
|
||||
BangPtr wsData = context->getWorkspace(wsSize);
|
||||
size_t meanSize =
|
||||
cnnlGetTensorElementNum(outMeanDesc) * op->getDType().getSize();
|
||||
BangPtr meanData = context->getWorkspace(meanSize);
|
||||
BangPtr rstdData = context->getWorkspace(meanSize);
|
||||
|
||||
cnnlStatus_t stat = cnnlLayerNormForward(
|
||||
context->cnnlHandle(), inDesc, inputData, axis, fiterDesc,
|
||||
scaleData, biasData, eps, wsData, wsSize, outDesc, outputData,
|
||||
inDesc, NULL, NULL);
|
||||
outMeanDesc, meanData, rstdData);
|
||||
|
||||
if (stat != CNNL_STATUS_SUCCESS)
|
||||
return;
|
||||
|
|
|
@ -66,6 +66,13 @@ class MatmulCnnl : public BangKernelWithoutConfig {
|
|||
cnnlSetMatMulDescAttr(bmm_desc, CNNL_MATMUL_DESC_TRANSB, &transB,
|
||||
sizeof(int32_t));
|
||||
|
||||
std::string computeTypeStr = op->getComputeType();
|
||||
if (computeTypeStr == "tf32") {
|
||||
int32_t tf32 = 1;
|
||||
cnnlSetMatMulDescAttr(bmm_desc, CNNL_MATMUL_ALLOW_TF32, &tf32,
|
||||
sizeof(int32_t));
|
||||
}
|
||||
|
||||
cnnlMatMulAlgo_t bmm_algo;
|
||||
cnnlMatMulAlgoCreate(&bmm_algo);
|
||||
|
||||
|
|
|
@ -115,6 +115,20 @@ class ElementWiseCuda : public CudaKernelWithoutConfig {
|
|||
auto a_dim = op->getInputs(0)->getDims();
|
||||
auto b_dim = op->getInputs(1)->getDims();
|
||||
auto c_dim = op->getOutput()->getDims();
|
||||
const int dType = _op->getDType().getIndex();
|
||||
|
||||
// Use optimized kernel if b is constant
|
||||
if (b_dim.size() == 0) {
|
||||
if (op->getOpType() == OpType::Div) {
|
||||
div_const_kernel(dType, aData, bData, cData,
|
||||
op->getOutput()->size());
|
||||
return;
|
||||
} else if (op->getOpType() == OpType::Pow) {
|
||||
pow_const_kernel(dType, aData, bData, cData,
|
||||
op->getOutput()->size());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (a_dim.size() > 4 || b_dim.size() > 4 || c_dim.size() > 4)
|
||||
IT_TODO_HALT();
|
||||
|
@ -127,7 +141,6 @@ class ElementWiseCuda : public CudaKernelWithoutConfig {
|
|||
std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
|
||||
std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size()));
|
||||
|
||||
const int dType = _op->getDType().getIndex();
|
||||
if (op->getOpType() == OpType::Div) {
|
||||
div_kernel(dType, aData, bData, cData, a[0], a[1], a[2], a[3], b[0],
|
||||
b[1], b[2], b[3], c[0], c[1], c[2], c[3]);
|
||||
|
|
|
@ -132,8 +132,8 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
|
|||
|
||||
#define CASE(OP, T) \
|
||||
_##OP##_kernel<DT_CUDA<T>::t> \
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
|
||||
(a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( \
|
||||
a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
|
||||
|
||||
#define SWITCH_DTYPE(OP, DTYPE) \
|
||||
switch (DTYPE) { \
|
||||
|
@ -177,7 +177,92 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
|
|||
IT_TODO_HALT(); \
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__global__ void _div_const_kernel(void const *__restrict__ x,
|
||||
void const *__restrict__ y,
|
||||
void *__restrict__ z, const size_t n) {
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (tid < n) {
|
||||
((T *)z)[tid] = ((T *)x)[tid] / *((T *)y);
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__global__ void _pow_const_kernel(void const *__restrict__ x,
|
||||
void const *__restrict__ y,
|
||||
void *__restrict__ z, const size_t n) {
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (tid < n) {
|
||||
((T *)z)[tid] = pow(((T *)x)[tid], *((T *)y));
|
||||
}
|
||||
}
|
||||
template <>
|
||||
__global__ void _pow_const_kernel<half>(void const *__restrict__ x,
|
||||
void const *__restrict__ y,
|
||||
void *__restrict__ z, const size_t n) {
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (tid < n) {
|
||||
((half *)z)[tid] = pow(((float)((half *)x)[tid]), *((half *)y));
|
||||
}
|
||||
}
|
||||
|
||||
#define CASE_CONST(OP, T) \
|
||||
_##OP##_const_kernel<DT_CUDA<T>::t> \
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(a, b, c, \
|
||||
n);
|
||||
|
||||
#define SWITCH_DTYPE_CONST(OP, DTYPE) \
|
||||
switch (DTYPE) { \
|
||||
case 1: \
|
||||
CASE_CONST(OP, 1) \
|
||||
break; \
|
||||
case 2: \
|
||||
CASE_CONST(OP, 2) \
|
||||
break; \
|
||||
case 3: \
|
||||
CASE_CONST(OP, 3) \
|
||||
break; \
|
||||
case 4: \
|
||||
CASE_CONST(OP, 4) \
|
||||
break; \
|
||||
case 5: \
|
||||
CASE_CONST(OP, 5) \
|
||||
break; \
|
||||
case 6: \
|
||||
CASE_CONST(OP, 6) \
|
||||
break; \
|
||||
case 7: \
|
||||
CASE_CONST(OP, 7) \
|
||||
break; \
|
||||
case 10: \
|
||||
CASE_CONST(OP, 10) \
|
||||
break; \
|
||||
case 11: \
|
||||
CASE_CONST(OP, 11) \
|
||||
break; \
|
||||
case 12: \
|
||||
CASE_CONST(OP, 12) \
|
||||
break; \
|
||||
case 13: \
|
||||
CASE_CONST(OP, 13) \
|
||||
break; \
|
||||
default: \
|
||||
IT_TODO_HALT(); \
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
void div_const_kernel(int dType, void *a, void *b, void *c, size_t n) {
|
||||
size_t blocksize = block_work_size();
|
||||
size_t gridsize = (n + block_work_size() - 1) / block_work_size();
|
||||
SWITCH_DTYPE_CONST(div, dType);
|
||||
}
|
||||
|
||||
void pow_const_kernel(int dType, void *a, void *b, void *c, size_t n) {
|
||||
size_t blocksize = block_work_size();
|
||||
size_t gridsize = (n + block_work_size() - 1) / block_work_size();
|
||||
SWITCH_DTYPE_CONST(pow, dType);
|
||||
}
|
||||
|
||||
void div_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
|
||||
int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
|
||||
int c3) {
|
||||
|
@ -204,12 +289,12 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
|
|||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
if (dType == 1) {
|
||||
_pow_kernel<float>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
|
||||
} else if (dType == 3) {
|
||||
_pow_kernel<int8_t>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
|
||||
} else if (dType == 10) {
|
||||
int a_size = a0 * a1 * a2 * a3;
|
||||
int b_size = b0 * b1 * b2 * b3;
|
||||
|
@ -224,9 +309,9 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
|
|||
b_float[i] = __half2float(((half *)b)[i]);
|
||||
}
|
||||
_pow_kernel<float>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(a_float.data(), b_float.data(), c_float.data(), a0, a1, a2, a3, b0,
|
||||
b1, b2, b3, c0, c1, c2, c3);
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
a_float.data(), b_float.data(), c_float.data(), a0, a1, a2, a3,
|
||||
b0, b1, b2, b3, c0, c1, c2, c3);
|
||||
for (int i = 0; i < c_size; ++i) {
|
||||
((half *)c)[i] = __float2half(c_float[i]);
|
||||
}
|
||||
|
|
|
@ -39,6 +39,14 @@ __global__ void _expandKernel(void *input, void *output, int nDims,
|
|||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
static __global__ void _expandRowKernel(void *__restrict__ dst,
|
||||
void const *__restrict__ src) {
|
||||
auto da = gridDim.x, db = blockDim.y, dx = blockDim.x, n = blockIdx.y,
|
||||
a = blockIdx.x, b = threadIdx.y, x = threadIdx.x;
|
||||
auto i = ((n * da + a) * db + b) * dx + x, j = (a * db + b) * dx + x;
|
||||
reinterpret_cast<T *>(dst)[i] = reinterpret_cast<T const *>(src)[j];
|
||||
}
|
||||
namespace infini {
|
||||
|
||||
#define CASE(T) \
|
||||
|
@ -96,4 +104,67 @@ void expandKernel(int dType, void *input, void *output, int nDims,
|
|||
SWITCH_DTYPE(dType)
|
||||
}
|
||||
|
||||
#define CASE_ROW(T) \
|
||||
_expandRowKernel<float> \
|
||||
<<<grid, block, 0, CUDAStream::getCurrentStream()>>>(output, input);
|
||||
|
||||
#define SWITCH_DTYPE_ROW(DTYPE) \
|
||||
switch (DTYPE) { \
|
||||
case 1: \
|
||||
CASE_ROW(1) \
|
||||
break; \
|
||||
case 2: \
|
||||
CASE_ROW(2) \
|
||||
break; \
|
||||
case 3: \
|
||||
CASE_ROW(3) \
|
||||
break; \
|
||||
case 4: \
|
||||
CASE_ROW(4) \
|
||||
break; \
|
||||
case 5: \
|
||||
CASE_ROW(5) \
|
||||
break; \
|
||||
case 6: \
|
||||
CASE_ROW(6) \
|
||||
break; \
|
||||
case 7: \
|
||||
CASE_ROW(7) \
|
||||
break; \
|
||||
case 10: \
|
||||
CASE_ROW(10) \
|
||||
break; \
|
||||
case 11: \
|
||||
CASE_ROW(11) \
|
||||
break; \
|
||||
case 12: \
|
||||
CASE_ROW(12) \
|
||||
break; \
|
||||
case 13: \
|
||||
CASE_ROW(13) \
|
||||
break; \
|
||||
case 16: \
|
||||
CASE_ROW(16) \
|
||||
break; \
|
||||
default: \
|
||||
IT_TODO_HALT(); \
|
||||
}
|
||||
|
||||
// Optimization for expanding a row vector. The row length must be a multiple of 32
|
||||
void expandRowKernel(int dType, void *input, void *output, int n_rows,
|
||||
int row_len) {
|
||||
// Factorize row_len: row_len = a x b x 32 (32 is the warp size), b<=32
|
||||
// input: 1 x (a x b x 32 x sizeT)
|
||||
// output: n_rows x (a x b x 32 x sizeT)
|
||||
// grid: n_rows x a
|
||||
// block: b x 32
|
||||
auto c = row_len / 32, b = c;
|
||||
if (b > 32) {
|
||||
for (b = 32; c % b != 0; --b);
|
||||
}
|
||||
auto a = c / b;
|
||||
dim3 grid(a, n_rows), block(32, b);
|
||||
SWITCH_DTYPE_ROW(dType)
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -33,6 +33,36 @@ constexpr cublasGemmAlgo_t ALGOS[N_ALGO] = {
|
|||
CUBLAS_GEMM_ALGO18, CUBLAS_GEMM_ALGO19, CUBLAS_GEMM_ALGO20,
|
||||
CUBLAS_GEMM_ALGO21, CUBLAS_GEMM_ALGO22, CUBLAS_GEMM_ALGO23,
|
||||
};
|
||||
|
||||
cublasComputeType_t cuDataType2ComputeType(cudaDataType_t cuDataType) {
|
||||
if (cuDataType == CUDA_R_16F) {
|
||||
return CUBLAS_COMPUTE_32F_FAST_16F;
|
||||
} else if (cuDataType == CUDA_R_16BF) {
|
||||
return CUBLAS_COMPUTE_32F_FAST_16BF;
|
||||
} else if (cuDataType == CUDA_R_32F) {
|
||||
return CUBLAS_COMPUTE_32F;
|
||||
} else if (cuDataType == CUDA_R_64F) {
|
||||
return CUBLAS_COMPUTE_64F;
|
||||
} else {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
}
|
||||
|
||||
cublasComputeType_t getCuComputeType(std::string computeTypeStr,
|
||||
cudaDataType_t cuDataType) {
|
||||
if (computeTypeStr == "tf32") {
|
||||
return CUBLAS_COMPUTE_32F_FAST_TF32;
|
||||
} else if (computeTypeStr == "bf16") {
|
||||
return CUBLAS_COMPUTE_32F_FAST_16BF;
|
||||
} else if (computeTypeStr == "fp16") {
|
||||
return CUBLAS_COMPUTE_32F_FAST_16F;
|
||||
} else if (computeTypeStr == "default") {
|
||||
return cuDataType2ComputeType(cuDataType);
|
||||
} else {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
}
|
||||
|
||||
class matmulCublas : public Kernel {
|
||||
bool do_compute(const Operator &_op, const PerfRecord &_record,
|
||||
const RuntimeObj *_context) const {
|
||||
|
@ -72,12 +102,25 @@ class matmulCublas : public Kernel {
|
|||
inputShape.data[i] = inC->getDims()[i - offset];
|
||||
}
|
||||
const int dType = dataType.getIndex();
|
||||
expandKernel(dType, inC->getRawDataPtr<void *>(),
|
||||
out->getRawDataPtr<void *>(), nDims, outputsize,
|
||||
inputShape, outputShape);
|
||||
|
||||
// Bias in linear layer is row vector of (1,n), n is the number of
|
||||
// features. If row vector and n % 32 == 0, use optimized kernel.
|
||||
if (inC->getRank() == 1 && inC->getDims()[0] % 32 == 0) {
|
||||
expandRowKernel(dType, inC->getRawDataPtr<void *>(),
|
||||
out->getRawDataPtr<void *>(),
|
||||
out->size() / inC->getDims()[0],
|
||||
inC->getDims()[0]);
|
||||
} else {
|
||||
expandKernel(dType, inC->getRawDataPtr<void *>(),
|
||||
out->getRawDataPtr<void *>(), nDims, outputsize,
|
||||
inputShape, outputShape);
|
||||
}
|
||||
}
|
||||
// TODO:use compute type
|
||||
cublasStatus_t stat;
|
||||
std::string computeTypeStr = op->getComputeType();
|
||||
auto cuComputeType = getCuComputeType(computeTypeStr, cuDataType);
|
||||
|
||||
if (b > 1) {
|
||||
// Support batch broadcast with zero stride
|
||||
int dimA = op->getInputs(0)->getRank();
|
||||
|
@ -99,14 +142,14 @@ class matmulCublas : public Kernel {
|
|||
context->cublasHandle(), opB, opA, n, m, k, &alpha_half,
|
||||
inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda,
|
||||
strideA, &beta_half, outData, cuDataType, ldc, m * n, b,
|
||||
cuDataType, (cublasGemmAlgo_t)record->algo);
|
||||
cuComputeType, (cublasGemmAlgo_t)record->algo);
|
||||
|
||||
} else {
|
||||
stat = cublasGemmStridedBatchedEx(
|
||||
context->cublasHandle(), opB, opA, n, m, k, &alpha_naive,
|
||||
inBData, cuDataType, ldb, strideB, inAData, cuDataType, lda,
|
||||
strideA, &beta_naive, outData, cuDataType, ldc, m * n, b,
|
||||
cuDataType, (cublasGemmAlgo_t)record->algo);
|
||||
cuComputeType, (cublasGemmAlgo_t)record->algo);
|
||||
}
|
||||
} else {
|
||||
if (dataType == DataType::Float16) {
|
||||
|
@ -115,13 +158,13 @@ class matmulCublas : public Kernel {
|
|||
stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
|
||||
&alpha_half, inBData, cuDataType, ldb,
|
||||
inAData, cuDataType, lda, &beta_half,
|
||||
outData, cuDataType, ldc, cuDataType,
|
||||
outData, cuDataType, ldc, cuComputeType,
|
||||
(cublasGemmAlgo_t)record->algo);
|
||||
} else {
|
||||
stat = cublasGemmEx(context->cublasHandle(), opB, opA, n, m, k,
|
||||
&alpha_naive, inBData, cuDataType, ldb,
|
||||
inAData, cuDataType, lda, &beta_naive,
|
||||
outData, cuDataType, ldc, cuDataType,
|
||||
outData, cuDataType, ldc, cuComputeType,
|
||||
(cublasGemmAlgo_t)record->algo);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
#include "operators/rms_norm.h"
|
||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||
#include "cuda/cuda_rmsnorm.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class RMSNormCuda : public CudaKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<RMSNormObj>(_op);
|
||||
|
||||
auto input = op->getInputs(0);
|
||||
auto weight = op->getInputs(1);
|
||||
auto output = op->getOutput();
|
||||
void *const inputData = input->getRawDataPtr<void *>();
|
||||
void *const weightData = weight->getRawDataPtr<void *>();
|
||||
void *const outputData = output->getRawDataPtr<void *>();
|
||||
const auto &inputShape = input->getDims();
|
||||
int nDims = input->getDims().size();
|
||||
|
||||
int hidden_size = inputShape[nDims - 1];
|
||||
int num_tokens = input->size() / hidden_size;
|
||||
IT_ASSERT(hidden_size == (int)weight->size());
|
||||
|
||||
const int dType = op->getDType().getIndex();
|
||||
rmsnorm_kernel(dType, inputData, weightData, outputData, num_tokens,
|
||||
hidden_size);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::RMSNorm, RMSNormCuda, "RMSNorm_CUDA");
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,112 @@
|
|||
#include "core/common.h"
|
||||
#include "cuda/cuda_common.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "utils/small_array.h"
|
||||
|
||||
template<class T>
|
||||
__inline__ __device__ T warpReduceSum(T val) {
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1)
|
||||
val += __shfl_xor_sync(uint32_t(-1), val, mask);
|
||||
return val;
|
||||
}
|
||||
|
||||
/* Calculate the sum of all elements in a block */
|
||||
template<class T>
|
||||
__inline__ __device__ T blockReduceSum(T val) {
|
||||
static __shared__ T shared[32];
|
||||
int lane = threadIdx.x & 0x1f;
|
||||
int wid = threadIdx.x >> 5;
|
||||
|
||||
val = warpReduceSum<T>(val);
|
||||
|
||||
if (lane == 0)
|
||||
shared[wid] = val;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
|
||||
// blockDim.x is not divided by 32
|
||||
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
|
||||
val = warpReduceSum<T>(val);
|
||||
return val;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__global__ void _rmsnorm_kernel(void *in, void *weight, void *out, int num_tokens, int hidden_size) {
|
||||
__shared__ float s_variance;
|
||||
float variance = 0.0f;
|
||||
|
||||
for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){
|
||||
const float x = ((T*) in)[blockIdx.x * hidden_size + idx];
|
||||
variance += x * x;
|
||||
}
|
||||
variance = blockReduceSum<float>(variance);
|
||||
if(threadIdx.x == 0){
|
||||
s_variance = rsqrtf(variance / hidden_size + 0.00001f);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){
|
||||
float x = ((T*) in)[blockIdx.x * hidden_size + idx];
|
||||
((T*)out)[blockIdx.x * hidden_size + idx] = ((T)(x * s_variance)) * ((T*)weight)[idx];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#define CASE(T) \
|
||||
_rmsnorm_kernel<DT_CUDA<T>::t> \
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
|
||||
(input, weight, output, num_tokens, hidden_size);
|
||||
|
||||
#define SWITCH_DTYPE(DTYPE) \
|
||||
switch (DTYPE) { \
|
||||
case 1: \
|
||||
CASE(1) \
|
||||
break; \
|
||||
case 2: \
|
||||
CASE(2) \
|
||||
break; \
|
||||
case 3: \
|
||||
CASE(3) \
|
||||
break; \
|
||||
case 4: \
|
||||
CASE(4) \
|
||||
break; \
|
||||
case 5: \
|
||||
CASE(5) \
|
||||
break; \
|
||||
case 6: \
|
||||
CASE(6) \
|
||||
break; \
|
||||
case 7: \
|
||||
CASE(7) \
|
||||
break; \
|
||||
case 10: \
|
||||
CASE(10) \
|
||||
break; \
|
||||
case 11: \
|
||||
CASE(11) \
|
||||
break; \
|
||||
case 12: \
|
||||
CASE(12) \
|
||||
break; \
|
||||
case 13: \
|
||||
CASE(13) \
|
||||
break; \
|
||||
case 16: \
|
||||
CASE(16) \
|
||||
break; \
|
||||
default: \
|
||||
IT_TODO_HALT(); \
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
void rmsnorm_kernel(int dType, void *input, void *weight, void *output,
|
||||
int num_tokens, int hidden_size) {
|
||||
dim3 blocksize = dim3(std::min(hidden_size, 1024));
|
||||
dim3 gridsize = dim3(num_tokens);
|
||||
SWITCH_DTYPE(dType)
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -22,7 +22,7 @@ class RoPECuda : public CudaKernelWithoutConfig {
|
|||
IT_ASSERT(nDims == 3 && pos->getDims().size() == 2);
|
||||
IT_ASSERT(inputShape[1] == pos->getDims()[1]);
|
||||
int dim_model = inputShape[2];
|
||||
int dim_head = dim_model / 32;
|
||||
int dim_head = 128;
|
||||
int hidden_stride = dim_model * inputShape[1];
|
||||
int pos_stride = inputShape[1];
|
||||
|
||||
|
|
|
@ -3,11 +3,6 @@
|
|||
#include "cuda/cuda_utility.h"
|
||||
#include "utils/small_array.h"
|
||||
|
||||
constexpr unsigned int num_threads() { return 32 * 4; }
|
||||
constexpr int thread_work_size() { return 4; }
|
||||
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
|
||||
|
||||
// gridDim (batch, seq_len, dim_model / 1024), blockDim (1024, 1, 1)
|
||||
template <class T>
|
||||
__global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model,
|
||||
int dim_head, int hidden_stride, int pos_stride) {
|
||||
|
@ -86,8 +81,8 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
|
|||
namespace infini {
|
||||
void rope_kernel(int dType, int * pos, void *input, void *output, int size,
|
||||
int dim_model, int dim_head, int hidden_stride, int pos_stride) {
|
||||
dim3 blocksize = dim3(1024,1,1);
|
||||
dim3 gridsize = dim3(1, 1, 4);
|
||||
dim3 blocksize = dim3(32,1,1);
|
||||
dim3 gridsize = dim3(1, 1, dim_model/32);
|
||||
SWITCH_DTYPE(dType)
|
||||
}
|
||||
|
||||
|
|
|
@ -315,6 +315,8 @@ void unary_kernel(const Operator &_op) {
|
|||
} else if (op->getOpType() == OpType::Silu) {
|
||||
if (_op->getDType() == DataType::Float32) {
|
||||
silu_kernel<float>((float *)inputData, (float *)outputData, num);
|
||||
} else if (_op->getDType() == DataType::Float16){
|
||||
silu_kernel<half>((half *)inputData, (half *)outputData, num);
|
||||
} else {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
|
|
|
@ -97,11 +97,14 @@ class DivXdnn : public KUNLUNKernelWithoutConfig {
|
|||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bSize = op->getInputs(1)->size();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
auto dtype = op->getDType();
|
||||
|
||||
// op input a, b is scalar while aDim and b Dim is empty
|
||||
if (bDim.size() == 0) {
|
||||
bDim.push_back(1);
|
||||
}
|
||||
if (aDim.size() == 0) {
|
||||
aDim.push_back(1);
|
||||
}
|
||||
|
||||
if (aSize == bSize) {
|
||||
// Do ElementWise Sub with no broadcast
|
||||
|
@ -109,23 +112,9 @@ class DivXdnn : public KUNLUNKernelWithoutConfig {
|
|||
(float *)aData, (float *)bData,
|
||||
(float *)cData, aSize));
|
||||
} else {
|
||||
// Do broadcast div
|
||||
Shape aligned = infer_broadcast(aDim, bDim);
|
||||
if (aligned == aDim) {
|
||||
// BData need to be broadcasted
|
||||
checkKUNLUNError(xdnn::broadcast_div<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)cData, aDim, bDim));
|
||||
} else {
|
||||
// Use workspace to broadcast aData
|
||||
KUNLUNPtr wks = context->getWorkspace(bSize * dtype.getSize());
|
||||
checkKUNLUNError(xdnn::broadcast<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)wks, aDim,
|
||||
bDim));
|
||||
checkKUNLUNError(xdnn::div<float>(context->KUNLUNHandle(),
|
||||
(float *)wks, (float *)bData,
|
||||
(float *)cData, bSize));
|
||||
}
|
||||
checkKUNLUNError(xdnn::broadcast_div<float>(
|
||||
context->KUNLUNHandle(), (float *)aData, (float *)bData,
|
||||
(float *)cData, aDim, bDim));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -570,6 +570,7 @@ REGISTER_KERNEL(Device::KUNLUN, OpType::Reciprocal, ReciprocalXdnn,
|
|||
REGISTER_KERNEL(Device::KUNLUN, OpType::Reshape, CopyXdnn, "Reshape_xdnn");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Flatten, CopyXdnn, "Flatten_xdnn");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Identity, CopyXdnn, "Identity_xdnn");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Squeeze, CopyXdnn, "Squeeze_xdnn");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Abs, AbsXdnn, "Abs_xdnn");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Atan, ATanXdnn, "Atan_xdnn");
|
||||
REGISTER_KERNEL(Device::KUNLUN, OpType::Log, LogXdnn, "Log_xdnn");
|
||||
|
|
|
@ -5,10 +5,11 @@
|
|||
namespace infini {
|
||||
|
||||
MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
|
||||
bool transB, [[maybe_unused]] Tensor bias, ActType act)
|
||||
bool transB, [[maybe_unused]] Tensor bias, ActType act,
|
||||
std::string computeType)
|
||||
: OperatorObj(OpType::MatMul,
|
||||
bias ? TensorVec{A, B, bias} : TensorVec{A, B}, {C}),
|
||||
transA(transA), transB(transB), act(act), b(1) {
|
||||
transA(transA), transB(transB), act(act), b(1), computeType(computeType) {
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
|
@ -17,7 +18,8 @@ string MatmulObj::toString() const {
|
|||
os << "Matmul([" << (transA ? "A^T" : "A") << "," << (transB ? "B^T" : "B")
|
||||
<< ",act=" << enum_to_underlying(act) << "],A=" << inputs[0]->getGuid()
|
||||
<< ",B=" << inputs[1]->getGuid() << ",C=" << outputs[0]->getGuid()
|
||||
<< ",bmnk=[" << b << "," << m << "," << n << "," << k << "])";
|
||||
<< ",bmnk=[" << b << "," << m << "," << n << "," << k << "])"
|
||||
<< ",computeType=" << computeType;
|
||||
return os.str();
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,36 @@
#include "operators/rms_norm.h"

namespace infini {
RMSNormObj::RMSNormObj(GraphObj *graph, Tensor input, Tensor weight,
                       Tensor output)
    : OperatorObj(OpType::RMSNorm, {input, weight}, {output}) {
    IT_ASSERT(checkValid(graph));
}

optional<vector<Shape>> RMSNormObj::inferShape(const TensorVec &inputs) {
    const auto A = inputs[0];
    auto input_dim = A->getDims();
    auto output_dim = input_dim;
    return {{output_dim}};
}

std::string RMSNormObj::toString() const {
    std::ostringstream os;
    os << type.toString() << "[" << getGuid() << "]";
    os << "(";
    os << vecToString(inputs[0]->getDims()) << ",";
    os << "input=" << inputs[0]->getGuid() << ",";
    os << "output=" << outputs[0]->getGuid() << ")";
    return os.str();
}

vector<int> RMSNormObj::getWorkloadVector() const {
    vector<int> ret{type.underlying()};
    const Shape shape = outputs[0]->getDims();
    ret.insert(ret.end(), shape.begin(), shape.end());
    return ret;
}

vector<int> RMSNormObj::getOpAttrVector() const { return {type.underlying()}; }

}; // namespace infini