forked from jiuyuan/InfiniTensor
添加 MLU 平台分布式验收脚本 (#223)
* 添加 MLU 平台分布式验收脚本 * add fp16 test, fix cast * fix * add onnxsim for llama * add matmul tf32 for mlu * add submodule: onnxsim_large_model * fix * modified bang_launch.py, start_single * add test for albert/opt * change file path --------- Co-authored-by: xgqdut2016 <kenan_gewei@163.com>
This commit is contained in:
parent
985d0dee5f
commit
fac28c25f6
|
@ -13,3 +13,6 @@
|
||||||
[submodule "example"]
|
[submodule "example"]
|
||||||
path = examples/NNmodel
|
path = examples/NNmodel
|
||||||
url = git@github.com:wanghailu0717/NNmodel.git
|
url = git@github.com:wanghailu0717/NNmodel.git
|
||||||
|
[submodule "examples/distributed/onnxsim_large_model"]
|
||||||
|
path = examples/distributed/onnxsim_large_model
|
||||||
|
url = git@github.com:luchangli03/onnxsim_large_model.git
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
# 分布式脚本
|
# 分布式脚本
|
||||||
|
|
||||||
|
## 英伟达平台运行方式
|
||||||
|
|
||||||
#### 1. 运行pytorch模型并生成输入和标准输出,可选择导出onnx
|
#### 1. 运行pytorch模型并生成输入和标准输出,可选择导出onnx
|
||||||
|
|
||||||
使用 `--export_onnx` 设置导出onnx的目录,默认为当前路径 `./`,不使用这个flag则只进行计算和生成输入输出。
|
使用 `--export_onnx` 设置导出onnx的目录,默认为当前路径 `./`,不使用这个flag则只进行计算和生成输入输出。
|
||||||
|
@ -15,3 +17,23 @@ python run_pytorch.py --model gpt2 --batch_size 1 --length 1 --export_onnx ./
|
||||||
```bash
|
```bash
|
||||||
python cuda_launch.py --model "/XXX/XXX.onnx" --nproc_per_node 4
|
python cuda_launch.py --model "/XXX/XXX.onnx" --nproc_per_node 4
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## 寒武纪平台运行方式
|
||||||
|
|
||||||
|
**将上述运行脚本 `run_pytorch.py` 以及 `cuda_launch.py` 针对寒武纪平台做了相应的适配,具体见 `run_pytorch_mlu.py` 以及 `bang_launch.py`。**
|
||||||
|
|
||||||
|
#### 1. 运行pytorch模型并生成输入和标准输出,可选择导出onnx
|
||||||
|
|
||||||
|
使用 `--export_onnx` 设置导出onnx的目录,默认为当前路径 `./`,不使用这个flag则只进行计算和生成输入输出。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python run_pytorch_mlu.py --model gpt2 --batch_size 1 --length 1 --export_onnx ./ --type fp32
|
||||||
|
```
|
||||||
|
|
||||||
|
会在当前目录下的 `./data/` 目录中生成输入输出文件 `input_0.npy` 和 `output.npy`,目前只支持单一输入输出。
|
||||||
|
|
||||||
|
#### 2. 运行InfiniTensor分布式脚本
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python bang_launch.py --model "/XXX/XXX.onnx" --nproc_per_node 4
|
||||||
|
```
|
|
@ -1,35 +1,39 @@
|
||||||
|
import sys
|
||||||
|
sys.path.append('../')
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
from pyinfinitensor.onnx import OnnxStub, backend
|
from pyinfinitensor.onnx import OnnxStub, backend
|
||||||
import onnx
|
import onnx
|
||||||
|
from onnx.external_data_helper import convert_model_to_external_data
|
||||||
from onnx.shape_inference import infer_shapes_path
|
from onnx.shape_inference import infer_shapes_path
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from parallel_opt import parallel_model
|
from parallel_opt import parallel_model
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description="launch distributed infinitensor")
|
parser = argparse.ArgumentParser(description="launch distributed infinitensor")
|
||||||
parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
|
parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--nproc_per_node", type=int, default=2, help="number of processes per node"
|
"--nproc_per_node", type=int, default=1, help="number of processes per node"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--name", type=str, default="test", help="name of this instance."
|
"--name", type=str, default="test", help="name of this instance."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model", type=str, default="/data/onnx_models/llama2/llama_bs1_seq1024.onnx",
|
"--model", type=str, required=True, help="path to the ONNX model file."
|
||||||
help="path to the ONNX model file."
|
|
||||||
)
|
)
|
||||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
|
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
|
||||||
parser.add_argument("--length", type=int, default=1, help="sequence length.")
|
parser.add_argument("--length", type=int, default=1, help="sequence length.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--gen_std",
|
"--gen_std",
|
||||||
default=False,
|
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="whether to generate the standard results.",
|
help="whether to generate the standard results.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--type", type=str, choices=["fp32", "fp16", "tf32"], default="fp32", help="data type"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
print("arg setting: ", args)
|
print("arg setting: ", args)
|
||||||
return (
|
return (
|
||||||
|
@ -40,39 +44,46 @@ def parse_args():
|
||||||
args.batch_size,
|
args.batch_size,
|
||||||
args.length,
|
args.length,
|
||||||
args.gen_std,
|
args.gen_std,
|
||||||
|
args.type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def run_model(model, runtime, world_size=1, rank=0, n=10):
|
def run_model(model, runtime, world_size=1, rank=0, n=10, data_type="default"):
|
||||||
stub = OnnxStub(model, runtime)
|
stub = OnnxStub(model, runtime, matmul_compute_type=data_type)
|
||||||
load_inputs(stub, world_size, rank)
|
load_inputs(stub, world_size, rank)
|
||||||
# stub.tune()
|
# stub.tune()
|
||||||
stub.run()
|
stub.run()
|
||||||
# get outputs
|
# get outputs
|
||||||
time.sleep(0.01)
|
|
||||||
outputs = next(stub.outputs.values().__iter__()).copyout_numpy()
|
outputs = next(stub.outputs.values().__iter__()).copyout_numpy()
|
||||||
|
|
||||||
# bench
|
# bench
|
||||||
begin = time.time()
|
|
||||||
for _ in range(n):
|
for _ in range(n):
|
||||||
stub.run()
|
stub.run()
|
||||||
|
begin = time.time()
|
||||||
|
for _ in range(n * 2):
|
||||||
|
stub.run()
|
||||||
end = time.time()
|
end = time.time()
|
||||||
avg_time = (end - begin) / n
|
avg_time = (end - begin) / (n * 2)
|
||||||
print(f"average time: {avg_time}")
|
print(f"average time: {avg_time}")
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
def load_inputs(stub, world_size=1, rank=0):
    """Fill every input tensor of ``stub`` from ``./data/input_{i}.npy``.

    The i-th input tensor is fed from the i-th saved array. When the saved
    array's dimensions match the tensor's, the array is copied in whole;
    otherwise the array is split horizontally into ``world_size`` equal parts
    and part ``rank`` is copied in.

    Args:
        stub: Object exposing ``inputs``, a mapping whose values provide
            ``shape()`` and ``copyin_numpy(array)``.
        world_size: Number of horizontal parts a mismatching array is cut into.
        rank: Index of the part this process copies in.
    """
    # Only the tensors are used; the input names are not needed.
    for i, tensor in enumerate(stub.inputs.values()):
        # Named ``array`` rather than ``input`` so the builtin is not shadowed.
        array = np.load(f"./data/input_{i}.npy")
        # zip() stops at the shorter shape, so only the common leading
        # dimensions are compared.
        if all(x == y for x, y in zip(array.shape, tensor.shape())):
            tensor.copyin_numpy(array)
        else:
            tensor.copyin_numpy(np.hsplit(array, world_size)[rank])
|
||||||
|
|
||||||
def run_and_compare(name, model, runtime, world_size=1, rank = 0):
|
|
||||||
|
def run_and_compare(name, model, runtime, world_size=1, rank=0, data_type="default"):
|
||||||
results = np.load(f"./data/output.npy")
|
results = np.load(f"./data/output.npy")
|
||||||
outputs = run_model(model, runtime, world_size, rank)
|
outputs = run_model(model, runtime, world_size, rank, data_type=data_type)
|
||||||
print("answer argmax:", np.argmax(results))
|
print("outputs abs mean:", abs(outputs).mean())
|
||||||
print("output argmax:", np.argmax(outputs))
|
print("max abs diff:", abs(outputs - results).max())
|
||||||
#np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-3)
|
|
||||||
getDiff(results, outputs)
|
|
||||||
|
|
||||||
|
|
||||||
def start_worker(
|
def start_worker(
|
||||||
name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
|
name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto, data_type: str
|
||||||
):
|
):
|
||||||
dist_name = name + "_dist"
|
dist_name = name + "_dist"
|
||||||
model = parallel_model(model, world_size, rank)
|
model = parallel_model(model, world_size, rank)
|
||||||
|
@ -85,7 +96,7 @@ def start_worker(
|
||||||
save_as_external_data=True,
|
save_as_external_data=True,
|
||||||
location=extern_path,
|
location=extern_path,
|
||||||
)
|
)
|
||||||
infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
|
#infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
|
||||||
runtime = backend.BangRuntime(local_rank)
|
runtime = backend.BangRuntime(local_rank)
|
||||||
# print("init comm")
|
# print("init comm")
|
||||||
runtime.init_comm(
|
runtime.init_comm(
|
||||||
|
@ -93,13 +104,12 @@ def start_worker(
|
||||||
world_size,
|
world_size,
|
||||||
rank,
|
rank,
|
||||||
)
|
)
|
||||||
run_and_compare(name, model, runtime, world_size, rank)
|
run_and_compare(name, model, runtime, world_size, rank, data_type)
|
||||||
|
|
||||||
|
|
||||||
def start_single(name, model):
|
def start_single(name, model, data_type):
|
||||||
runtime = backend.BangRuntime(0)
|
runtime = backend.BangRuntime(0)
|
||||||
run_and_compare(name, model, runtime)
|
run_and_compare(name, model, runtime, data_type=data_type)
|
||||||
|
|
||||||
|
|
||||||
def generate_input_output(model):
|
def generate_input_output(model):
|
||||||
os.makedirs(os.path.dirname("./data/"), exist_ok=True)
|
os.makedirs(os.path.dirname("./data/"), exist_ok=True)
|
||||||
|
@ -132,55 +142,36 @@ def generate_input_output(model):
|
||||||
np.save(f"./data/output", output)
|
np.save(f"./data/output", output)
|
||||||
|
|
||||||
|
|
||||||
def load_inputs(stub, world_size=1, rank=0):
|
|
||||||
for i, (name, tensor) in enumerate(stub.inputs.items()):
|
|
||||||
input = np.load(f"./data/input_{i}.npy")
|
|
||||||
if all(x == y for x,y in zip(input.shape,tensor.shape())):
|
|
||||||
tensor.copyin_numpy(input)
|
|
||||||
else:
|
|
||||||
tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
|
|
||||||
|
|
||||||
def getDiff(base, test):
|
|
||||||
absolute_diff = np.abs(np.subtract(base, test))
|
|
||||||
max_absolute_diff = np.max(absolute_diff)
|
|
||||||
|
|
||||||
baseCopy = base.astype(np.float64).ravel()
|
|
||||||
testCopy = test.astype(np.float64).ravel()
|
|
||||||
upValue = np.sum(np.abs(baseCopy - testCopy))
|
|
||||||
downValue = np.sum(np.abs(baseCopy)) + np.float64(1e-9)
|
|
||||||
max_relative_diff = upValue / downValue
|
|
||||||
print(f"Max absolute difference: {max_absolute_diff}\n"
|
|
||||||
f"Max relative difference: {max_relative_diff}")
|
|
||||||
return max_absolute_diff, max_relative_diff
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args()
|
nnodes, nproc_per_node, name, model_path, bs, length, gen_std, data_type = parse_args()
|
||||||
|
data_type = "default" if data_type == "fp32" else data_type
|
||||||
|
|
||||||
model = onnx.load(model_path)
|
model = onnx.load(model_path)
|
||||||
|
|
||||||
# generate standard output
|
# generate standard output
|
||||||
if gen_std:
|
if gen_std:
|
||||||
print("Generate inputs and outputs.")
|
print(f"generate standard data for {name}.")
|
||||||
p = mp.Process(target=generate_input_output, args=[model])
|
# a small vocabulary size to fit all LLM.
|
||||||
p.start()
|
generate_input_output(model)
|
||||||
p.join()
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# run single process.
|
if nproc_per_node == 1:
|
||||||
# use standalone process to isolate cuda.
|
# run single process.
|
||||||
print("run model by single MLU.")
|
# use standalone process to isolate bang.
|
||||||
p = mp.Process(target=start_single, args=(name, model))
|
print("run model by single MLU.")
|
||||||
p.start()
|
# p = mp.Process(target=start_single, args=(name, model, data_type))
|
||||||
p.join()
|
# p.start()
|
||||||
|
# p.join()
|
||||||
|
start_single(name, model, data_type)
|
||||||
|
return
|
||||||
|
|
||||||
# run distributed parallel.
|
# run distributed parallel.
|
||||||
world_size = nnodes * nproc_per_node
|
world_size = nnodes * nproc_per_node
|
||||||
print(f"run model by {world_size} MLUs in parallel.")
|
print(f"run model by {world_size} MLU in parallel.")
|
||||||
workers = [
|
workers = [
|
||||||
mp.Process(
|
mp.Process(
|
||||||
target=start_worker,
|
target=start_worker,
|
||||||
args=(name, world_size, rank, rank % nproc_per_node, model),
|
args=(name, world_size, rank, rank % nproc_per_node, model, data_type),
|
||||||
)
|
)
|
||||||
for rank in range(world_size)
|
for rank in range(world_size)
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,249 @@
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
import torch_mlu
|
||||||
|
from transformers import BertModel, BertConfig
|
||||||
|
from transformers import GPT2Model, GPT2Config
|
||||||
|
from transformers import OPTModel, OPTConfig
|
||||||
|
from transformers import AlbertModel, AlbertConfig
|
||||||
|
from transformers import LlamaModel, LlamaConfig
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
import onnx
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from onnx.external_data_helper import convert_model_to_external_data
|
||||||
|
from onnxsim import simplify
|
||||||
|
|
||||||
|
def parse_args():
    """Parse the command-line options of the PyTorch reference runner.

    ``--type`` is optional and defaults to ``fp32``, the same default the
    launcher's ``--type`` option uses, so invocations that omit it keep
    working.

    Returns:
        A tuple ``(model, batch_size, length, export_onnx, type)``.
        ``export_onnx`` is ``None`` when the flag is absent, ``"./"`` when the
        flag is given without a value, and the given directory otherwise.
    """
    parser = argparse.ArgumentParser(description="Run pytorch gpt2/bert/opt and optionally export onnx.")
    parser.add_argument(
        "--model", type=str, choices=["gpt2", "bert", "opt", "llama", "albert"], required=True, help="model type"
    )
    parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
    parser.add_argument("--length", type=int, default=1, help="sequence length.")
    parser.add_argument(
        "--export_onnx",
        type=str,
        # nargs="?" lets the flag appear bare, in which case ``const`` is used.
        nargs="?",
        default=None,
        const="./",
        help="whether and where to export onnx file",
    )
    parser.add_argument(
        "--type", type=str, choices=["fp32", "fp16", "tf32"], default="fp32", help="model data type"
    )
    args = parser.parse_args()
    print("arg setting: ", args)
    return (
        args.model,
        args.batch_size,
        args.length,
        args.export_onnx,
        args.type,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(modelname):
    """Load the pretrained Hugging Face model selected by ``modelname``.

    Args:
        modelname: One of "albert", "bert", "gpt2", "opt" or "llama".

    Returns:
        A tuple ``(model, voc_size)``: the model switched to eval mode and the
        vocabulary size recorded in the loaded checkpoint's configuration.

    Raises:
        KeyError: If ``modelname`` is not one of the supported names.
    """
    # model name -> (model class, checkpoint id, extra from_pretrained kwargs)
    checkpoints = {
        "albert": (AlbertModel, "albert/albert-base-v2", {}),
        # "gelu_new" is requested because erf is not impl by infini.
        "bert": (
            BertModel,
            "bert-base-uncased",
            {"add_pooling_layer": False, "hidden_act": "gelu_new"},
        ),
        "gpt2": (GPT2Model, "GPT2", {}),
        "opt": (OPTModel, "facebook/opt-125m", {}),
        "llama": (LlamaModel, "meta-llama/Llama-2-7b-hf", {}),
    }
    # An unsupported name raises KeyError(modelname) from the lookup itself.
    model_cls, checkpoint, extra_kwargs = checkpoints[modelname]
    model = model_cls.from_pretrained(checkpoint, **extra_kwargs)

    model = model.eval()
    # Read the size from the checkpoint that was actually loaded rather than
    # from a default-constructed config, which only matches by coincidence.
    voc_size = model.config.vocab_size
    return model, voc_size
|
||||||
|
|
||||||
|
def run_pytorch(torch_model, voc_size, batchsize, len, dtype="fp32"):
    """Run ``torch_model`` on the "mlu" device, time it and save input/output.

    Random token ids in ``[0, voc_size)`` of shape ``(batchsize, len)`` are
    generated and stored as ``./data/input_0``; the model's
    ``last_hidden_state`` from the final timed run is stored as
    ``./data/output``.

    Args:
        torch_model: Model to run; it is moved to the "mlu" device.
        voc_size: Exclusive upper bound of the random token ids.
        batchsize: Number of rows of the generated input.
        len: Number of columns (sequence length) of the generated input.
        dtype: ``"fp16"`` converts the model with ``half()``; any other value
            leaves the model's precision untouched.
    """
    token_ids = np.random.randint(0, voc_size, (batchsize, len), dtype=np.int32)
    os.makedirs(os.path.dirname("./data/"), exist_ok=True)
    np.save("./data/input_0", token_ids)

    device_inputs = torch.from_numpy(token_ids).to("mlu")
    torch_model = torch_model.to("mlu")
    if dtype == "fp16":
        torch_model = torch_model.half()

    warmup_iters = 10
    timed_iters = 20

    # Untimed warm-up runs.
    with torch.no_grad():
        for _ in range(warmup_iters):
            outputs = torch_model(device_inputs)
    torch.mlu.synchronize()

    # Timed runs; the device is synchronized around every forward pass.
    begin = time.time()
    with torch.no_grad():
        for _ in range(timed_iters):
            torch.mlu.synchronize()
            outputs = torch_model(device_inputs)
            torch.mlu.synchronize()
    torch.mlu.synchronize()
    end = time.time()

    avg_time = (end - begin) / timed_iters
    outputs = outputs.last_hidden_state.to("cpu")
    result = np.array(outputs)
    print("outputs abs mean:", abs(result).mean())
    print(f"average time: {avg_time}")
    # torch.mlu.memory.empty_cache()
    np.save("./data/output", result)
    print("Save input & output into ./data.")
|
||||||
|
|
||||||
|
|
||||||
|
def export_onnx(modelname, model, data, path, extern=False, dtype="fp32"):
    """Export ``model`` to an ONNX file at ``path`` and simplify the result.

    The model and the example input are moved to the "mlu" device (and the
    model converted with ``half()`` for ``dtype == "fp16"``) before
    ``torch.onnx.export`` writes ``path``. The exported file is then loaded,
    simplified, annotated with value infos for its constants and written back
    to ``path``.

    Args:
        modelname: Model name; ``"llama"`` selects the compress / simplify /
            uncompress flow built on the ``onnxsim_large_model`` tools, every
            other name uses plain ``onnxsim``.
        model: The PyTorch model to export.
        data: Example input tensor used for the export.
        path: Destination ``.onnx`` file.
        extern: For non-llama models, store the tensors in a sibling ``.pb``
            file next to ``path``. The llama flow always saves with external
            data, regardless of this flag.
        dtype: ``"fp16"`` exports a half-precision model.

    Raises:
        ValueError: If the simplifier reports that its check failed.
    """
    data = data.to("mlu")
    model = model.to("mlu")
    if dtype == "fp16":
        model = model.half()
    torch.onnx.export(model, data, path, verbose=False, do_constant_folding=True)

    if modelname != "llama":
        # use onnxsim to simplify
        onnx_model = onnx.load(path)
        # Alternative left by the original author: additionally skip 'fuse_qkv'.
        onnx_model, check = simplify(
            onnx_model, skipped_optimizers=["eliminate_duplicate_initializer"]
        )
        # Raise instead of assert so the check survives ``python -O`` and
        # matches the error style of the llama branch below.
        if not check:
            raise ValueError(f"simplify model {path} failed")
        add_value_info_for_constants(onnx_model)
        onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
        if extern:
            # Swap only the file extension; a plain str.replace would also
            # rewrite ".onnx" inside directory names, and would leave the path
            # unchanged (pointing at the model itself) when the suffix is absent.
            extern_path = os.path.splitext(path)[0] + ".pb"
            if os.path.exists(extern_path):
                os.remove(extern_path)
            convert_model_to_external_data(
                onnx_model,
                all_tensors_to_one_file=True,
                # The location is given as a bare file name.
                location=os.path.basename(extern_path),
                size_threshold=1024,
                convert_attribute=False,
            )
        onnx.save(onnx_model, path)
    else:
        # use third party tool to simplify llama
        # reference: https://github.com/luchangli03/onnxsim_large_model/
        sys.path.append("onnxsim_large_model")
        from onnx_utils import set_onnx_input_shape
        from compress_model import compress_onnx_model, uncompress_onnx_model

        onnx_model = onnx.load(path)
        print(f"load model from {path} success")

        # Initializers above this size are removed before simplification and
        # restored afterwards by uncompress_onnx_model.
        size_th_bytes = 1024 * 1024
        onnx_model, removed_inits = compress_onnx_model(onnx_model, size_th_bytes=size_th_bytes)
        print("compress model success")

        onnx_model = set_onnx_input_shape(onnx_model, "")

        onnx_model, check = simplify(
            onnx_model,
            skipped_optimizers=["eliminate_duplicate_initializer"],
            tensor_size_threshold="1024KB",
        )
        if not check:
            raise ValueError(f"simplify compressed model {path} failed")
        print("simplify model success")

        onnx_model = uncompress_onnx_model(onnx_model, removed_inits)
        print("uncompress model success")

        add_value_info_for_constants(onnx_model)

        # The simplified model overwrites the exported file in place.
        onnx.save(onnx_model, path, save_as_external_data=True)
|
||||||
|
|
||||||
|
|
||||||
|
def add_value_info_for_constants(model : onnx.ModelProto):
    """
    Add a ValueInfoProto for each constant initializer of the model.

    The original author notes that onnx.shape_inference currently doesn't use
    the shape of initializers, so that information is recorded explicitly.
    Mutates the model in place, including the subgraphs held in node
    attributes.

    Args:
        model: The ModelProto to update.
    """
    # All (top-level) constants will have ValueInfos before IRv4 as they are all inputs
    if model.ir_version < 4:
        return

    def annotate_graph(graph : onnx.GraphProto):
        input_names = {graph_input.name for graph_input in graph.input}
        known_infos = {info.name: info for info in graph.value_info}

        for initializer in graph.initializer:
            # An initializer that is also a graph input is not a constant.
            if initializer.name in input_names:
                continue

            # Reuse the value info recorded for this constant, or create one.
            info = known_infos.get(initializer.name)
            if info is None:
                info = graph.value_info.add()
                info.name = initializer.name

            # Only missing details are filled in; existing ones are never
            # overwritten, even if they disagree with the initializer.
            tensor_type = info.type.tensor_type
            if tensor_type.elem_type == onnx.TensorProto.UNDEFINED:
                tensor_type.elem_type = initializer.data_type
            if not tensor_type.HasField("shape"):
                # Ensure an empty dim list is set for a scalar (zero dims).
                tensor_type.shape.dim.extend([])
                for extent in initializer.dims:
                    tensor_type.shape.dim.add().dim_value = extent

        # Recurse into the subgraphs carried by node attributes.
        for node in graph.node:
            for attr in node.attribute:
                # Ref attrs refer to other attrs, so there is nothing to do.
                if attr.ref_attr_name != "":
                    continue

                if attr.type == onnx.AttributeProto.GRAPH:
                    annotate_graph(attr.g)
                if attr.type == onnx.AttributeProto.GRAPHS:
                    for subgraph in attr.graphs:
                        annotate_graph(subgraph)

    return annotate_graph(model.graph)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Entry point: load the model, optionally export ONNX, then run PyTorch.

    TF32 is switched off on both backend flags first. ``--type tf32`` turns
    the matmul flag back on; every other type sets the
    ``CAMBRICON_TF32_OVERRIDE`` environment variable to ``"0"`` instead. When
    an export directory is given and the target file does not exist yet, the
    model is exported before the reference run.
    """
    torch.backends.mlu.matmul.allow_tf32 = False
    torch.backends.cnnl.allow_tf32 = False
    modelname, batchsize, seqlen, export_path, dtype = parse_args()
    if dtype != "tf32":
        os.environ["CAMBRICON_TF32_OVERRIDE"] = "0"
    else:
        torch.backends.mlu.matmul.allow_tf32 = True

    model, voc_size = get_model(modelname)
    if export_path is not None:
        onnx_file = os.path.join(
            export_path, f"{modelname}_{batchsize}_{seqlen}_{dtype}.onnx"
        )
        if os.path.exists(onnx_file):
            print("Onnx path exists, skipping export.")
        else:
            # Zero-filled ids of the requested shape serve as the export input.
            example_ids = torch.zeros((batchsize, seqlen), dtype=torch.int)
            export_onnx(modelname, model, example_ids, onnx_file, True, dtype)

    # The reference run happens whether or not an export was requested.
    run_pytorch(model, voc_size, batchsize, seqlen, dtype)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit cbcf3fbf985a00494b0f136c92eaccd42031bf65
|
|
@ -199,6 +199,24 @@ class CastCnnl : public BangKernelWithoutConfig {
|
||||||
dim.data()));
|
dim.data()));
|
||||||
NlCastType = CNNL_CAST_UINT32_TO_INT64;
|
NlCastType = CNNL_CAST_UINT32_TO_INT64;
|
||||||
break;
|
break;
|
||||||
|
case CastType::Float162Float:
|
||||||
|
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
|
||||||
|
CNNL_DTYPE_HALF, dim.size(),
|
||||||
|
dim.data()));
|
||||||
|
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
|
||||||
|
CNNL_DTYPE_FLOAT, dim.size(),
|
||||||
|
dim.data()));
|
||||||
|
NlCastType = CNNL_CAST_HALF_TO_FLOAT;
|
||||||
|
break;
|
||||||
|
case CastType::Float2Float16:
|
||||||
|
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
|
||||||
|
CNNL_DTYPE_FLOAT, dim.size(),
|
||||||
|
dim.data()));
|
||||||
|
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
|
||||||
|
CNNL_DTYPE_HALF, dim.size(),
|
||||||
|
dim.data()));
|
||||||
|
NlCastType = CNNL_CAST_FLOAT_TO_HALF;
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
IT_TODO_HALT();
|
IT_TODO_HALT();
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,14 +19,16 @@ class LayerNormCnnl : public BangKernelWithoutConfig {
|
||||||
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
|
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
|
||||||
|
|
||||||
auto inDims = op->getInputs(0)->getDims();
|
auto inDims = op->getInputs(0)->getDims();
|
||||||
|
auto fiterDims = op->getInputs(1)->getDims();
|
||||||
auto outDims = op->getOutput()->getDims();
|
auto outDims = op->getOutput()->getDims();
|
||||||
auto fiterDims = op->getOutput(1)->getDims();
|
|
||||||
|
|
||||||
float eps = op->getEps();
|
float eps = op->getEps();
|
||||||
const int axis = op->getAxis();
|
const int axis = op->getAxis();
|
||||||
|
|
||||||
cnnlTensorDescriptor_t inDesc, fiterDesc, outDesc;
|
Shape outMeanDims(outDims);
|
||||||
|
outMeanDims.erase(outMeanDims.begin() + axis);
|
||||||
|
|
||||||
|
cnnlTensorDescriptor_t inDesc, fiterDesc, outDesc, outMeanDesc;
|
||||||
checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
|
checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
inDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
|
inDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
|
||||||
|
@ -39,15 +41,23 @@ class LayerNormCnnl : public BangKernelWithoutConfig {
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
outDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
|
outDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
|
||||||
outDims.size(), outDims.data()));
|
outDims.size(), outDims.data()));
|
||||||
|
checkCnnlError(cnnlCreateTensorDescriptor(&outMeanDesc));
|
||||||
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
|
outMeanDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
|
||||||
|
outMeanDims.size(), outMeanDims.data()));
|
||||||
size_t wsSize;
|
size_t wsSize;
|
||||||
cnnlGetLayerNormOpWorkspaceSize(context->cnnlHandle(), axis, inDesc,
|
cnnlGetLayerNormOpWorkspaceSize(context->cnnlHandle(), axis, inDesc,
|
||||||
&wsSize);
|
&wsSize);
|
||||||
BangPtr wsData = context->getWorkspace(wsSize);
|
BangPtr wsData = context->getWorkspace(wsSize);
|
||||||
|
size_t meanSize =
|
||||||
|
cnnlGetTensorElementNum(outMeanDesc) * op->getDType().getSize();
|
||||||
|
BangPtr meanData = context->getWorkspace(meanSize);
|
||||||
|
BangPtr rstdData = context->getWorkspace(meanSize);
|
||||||
|
|
||||||
cnnlStatus_t stat = cnnlLayerNormForward(
|
cnnlStatus_t stat = cnnlLayerNormForward(
|
||||||
context->cnnlHandle(), inDesc, inputData, axis, fiterDesc,
|
context->cnnlHandle(), inDesc, inputData, axis, fiterDesc,
|
||||||
scaleData, biasData, eps, wsData, wsSize, outDesc, outputData,
|
scaleData, biasData, eps, wsData, wsSize, outDesc, outputData,
|
||||||
inDesc, NULL, NULL);
|
outMeanDesc, meanData, rstdData);
|
||||||
|
|
||||||
if (stat != CNNL_STATUS_SUCCESS)
|
if (stat != CNNL_STATUS_SUCCESS)
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -66,6 +66,13 @@ class MatmulCnnl : public BangKernelWithoutConfig {
|
||||||
cnnlSetMatMulDescAttr(bmm_desc, CNNL_MATMUL_DESC_TRANSB, &transB,
|
cnnlSetMatMulDescAttr(bmm_desc, CNNL_MATMUL_DESC_TRANSB, &transB,
|
||||||
sizeof(int32_t));
|
sizeof(int32_t));
|
||||||
|
|
||||||
|
std::string computeTypeStr = op->getComputeType();
|
||||||
|
if (computeTypeStr == "tf32") {
|
||||||
|
int32_t tf32 = 1;
|
||||||
|
cnnlSetMatMulDescAttr(bmm_desc, CNNL_MATMUL_ALLOW_TF32, &tf32,
|
||||||
|
sizeof(int32_t));
|
||||||
|
}
|
||||||
|
|
||||||
cnnlMatMulAlgo_t bmm_algo;
|
cnnlMatMulAlgo_t bmm_algo;
|
||||||
cnnlMatMulAlgoCreate(&bmm_algo);
|
cnnlMatMulAlgoCreate(&bmm_algo);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue