Add MLU platform distributed acceptance test scripts (#223)

* Add MLU platform distributed acceptance test scripts

* add fp16 test, fix cast

* fix

* add onnxsim for llama

* add matmul tf32 for mlu

* add submodule: onnxsim_large_model

* fix

* modified bang_launch.py, start_single

* add test for albert/opt

* change file path

---------

Co-authored-by: xgqdut2016 <kenan_gewei@163.com>
Bolun Zhang 2024-04-28 11:24:09 +08:00 committed by GitHub
parent 985d0dee5f
commit fac28c25f6
8 changed files with 363 additions and 62 deletions

.gitmodules

@@ -13,3 +13,6 @@
 [submodule "example"]
     path = examples/NNmodel
     url = git@github.com:wanghailu0717/NNmodel.git
+[submodule "examples/distributed/onnxsim_large_model"]
+    path = examples/distributed/onnxsim_large_model
+    url = git@github.com:luchangli03/onnxsim_large_model.git


@@ -1,5 +1,7 @@
 # Distributed scripts
 
+## Running on the NVIDIA platform
+
 #### 1. Run the PyTorch model and generate inputs and reference outputs (optionally export ONNX)
 
 Use `--export_onnx` to set the directory the ONNX model is exported to (defaults to the current path `./`); without this flag, only computation and input/output generation are performed.
@@ -15,3 +17,23 @@ python run_pytorch.py --model gpt2 --batch_size 1 --length 1 --export_onnx ./
 ```bash
 python cuda_launch.py --model "/XXX/XXX.onnx" --nproc_per_node 4
 ```
+
+## Running on the Cambricon platform
+
+**The scripts `run_pytorch.py` and `cuda_launch.py` above have been adapted for the Cambricon platform; see `run_pytorch_mlu.py` and `bang_launch.py`.**
+
+#### 1. Run the PyTorch model and generate inputs and reference outputs (optionally export ONNX)
+
+Use `--export_onnx` to set the directory the ONNX model is exported to (defaults to the current path `./`); without this flag, only computation and input/output generation are performed.
+
+```bash
+python run_pytorch_mlu.py --model gpt2 --batch_size 1 --length 1 --export_onnx ./
+```
+
+This generates the input and output files `test_inputs.npy` and `test_results.npy` in the current directory; currently only a single input and a single output are supported.
+
+#### 2. Run the InfiniTensor distributed script
+
+```bash
+python bang_launch.py --model "/XXX/XXX.onnx" --nproc_per_node 4
+```
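As a quick sanity check (not part of this PR), the reference data saved by the scripts in this diff can be inspected offline. A minimal sketch, assuming the default `./data/` paths that `run_pytorch_mlu.py` writes to:

```python
import numpy as np

# Inspect the reference pair saved by run_pytorch_mlu.py (./data/input_0.npy, ./data/output.npy).
inp = np.load("./data/input_0.npy")
ref = np.load("./data/output.npy")
print("input:", inp.shape, inp.dtype)
print("reference output:", ref.shape, ref.dtype, "abs mean:", np.abs(ref).mean())
```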


@@ -1,35 +1,39 @@
+import sys
+sys.path.append('../')
 import argparse
 import os
 import time
 import multiprocessing as mp
 from pyinfinitensor.onnx import OnnxStub, backend
 import onnx
-from onnx.external_data_helper import convert_model_to_external_data
 from onnx.shape_inference import infer_shapes_path
 import numpy as np
 from parallel_opt import parallel_model


 def parse_args():
     parser = argparse.ArgumentParser(description="launch distributed infinitensor")
     parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
     parser.add_argument(
-        "--nproc_per_node", type=int, default=2, help="number of processes per node"
+        "--nproc_per_node", type=int, default=1, help="number of processes per node"
     )
     parser.add_argument(
         "--name", type=str, default="test", help="name of this instance."
     )
     parser.add_argument(
-        "--model", type=str, default="/data/onnx_models/llama2/llama_bs1_seq1024.onnx",
-        help="path to the ONNX model file."
+        "--model", type=str, required=True, help="path to the ONNX model file."
     )
     parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
     parser.add_argument("--length", type=int, default=1, help="sequence length.")
     parser.add_argument(
         "--gen_std",
+        default=False,
         action="store_true",
         help="whether to generate the standard results.",
     )
+    parser.add_argument(
+        "--type", type=str, choices=["fp32", "fp16", "tf32"], default="fp32", help="data type"
+    )
     args = parser.parse_args()
     print("arg setting: ", args)
     return (
@@ -40,39 +44,46 @@ def parse_args():
         args.batch_size,
         args.length,
         args.gen_std,
+        args.type,
     )


-def run_model(model, runtime, world_size=1, rank=0, n=10):
-    stub = OnnxStub(model, runtime)
+def run_model(model, runtime, world_size=1, rank=0, n=10, data_type="default"):
+    stub = OnnxStub(model, runtime, matmul_compute_type=data_type)
     load_inputs(stub, world_size, rank)
     # stub.tune()
     stub.run()
     # get outputs
-    time.sleep(0.01)
     outputs = next(stub.outputs.values().__iter__()).copyout_numpy()

     # bench
-    begin = time.time()
     for _ in range(n):
         stub.run()
+    begin = time.time()
+    for _ in range(n * 2):
+        stub.run()
     end = time.time()
-    avg_time = (end - begin) / n
+    avg_time = (end - begin) / (n * 2)
     print(f"average time: {avg_time}")
     return outputs


+def load_inputs(stub, world_size=1, rank=0):
+    for i, (name, tensor) in enumerate(stub.inputs.items()):
+        input = np.load(f"./data/input_{i}.npy")
+        if all(x == y for x,y in zip(input.shape,tensor.shape())):
+            tensor.copyin_numpy(input)
+        else:
+            tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
+
+
-def run_and_compare(name, model, runtime, world_size=1, rank = 0):
+def run_and_compare(name, model, runtime, world_size=1, rank=0, data_type="default"):
     results = np.load(f"./data/output.npy")
-    outputs = run_model(model, runtime, world_size, rank)
-    print("answer argmax:", np.argmax(results))
-    print("output argmax:", np.argmax(outputs))
-    #np.testing.assert_allclose(outputs, results, rtol=1e-3, atol=1e-3)
-    getDiff(results, outputs)
+    outputs = run_model(model, runtime, world_size, rank, data_type=data_type)
+    print("outputs abs mean:", abs(outputs).mean())
+    print("max abs diff:", abs(outputs - results).max())


 def start_worker(
-    name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto
+    name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto, data_type: str
 ):
     dist_name = name + "_dist"
     model = parallel_model(model, world_size, rank)
@@ -85,7 +96,7 @@ def start_worker(
         save_as_external_data=True,
         location=extern_path,
     )
-    infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
+    #infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
     runtime = backend.BangRuntime(local_rank)
     # print("init comm")
     runtime.init_comm(
@@ -93,13 +104,12 @@ def start_worker(
         world_size,
         rank,
     )
-    run_and_compare(name, model, runtime, world_size, rank)
+    run_and_compare(name, model, runtime, world_size, rank, data_type)


-def start_single(name, model):
+def start_single(name, model, data_type):
     runtime = backend.BangRuntime(0)
-    run_and_compare(name, model, runtime)
+    run_and_compare(name, model, runtime, data_type=data_type)


 def generate_input_output(model):
     os.makedirs(os.path.dirname("./data/"), exist_ok=True)
@@ -132,55 +142,36 @@ def generate_input_output(model):
     np.save(f"./data/output", output)


-def load_inputs(stub, world_size=1, rank=0):
-    for i, (name, tensor) in enumerate(stub.inputs.items()):
-        input = np.load(f"./data/input_{i}.npy")
-        if all(x == y for x,y in zip(input.shape,tensor.shape())):
-            tensor.copyin_numpy(input)
-        else:
-            tensor.copyin_numpy(np.hsplit(input, world_size)[rank])
-
-
-def getDiff(base, test):
-    absolute_diff = np.abs(np.subtract(base, test))
-    max_absolute_diff = np.max(absolute_diff)
-    baseCopy = base.astype(np.float64).ravel()
-    testCopy = test.astype(np.float64).ravel()
-    upValue = np.sum(np.abs(baseCopy - testCopy))
-    downValue = np.sum(np.abs(baseCopy)) + np.float64(1e-9)
-    max_relative_diff = upValue / downValue
-    print(f"Max absolute difference: {max_absolute_diff}\n"
-          f"Max relative difference: {max_relative_diff}")
-    return max_absolute_diff, max_relative_diff


 def main():
-    nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args()
+    nnodes, nproc_per_node, name, model_path, bs, length, gen_std, data_type = parse_args()
+    data_type = "default" if data_type == "fp32" else data_type

     model = onnx.load(model_path)

     # generate standart output
     if gen_std:
-        print("Generate inputs and outputs.")
-        p = mp.Process(target=generate_input_output, args=[model])
-        p.start()
-        p.join()
+        print(f"generate standard data for {name}.")
+        # a small vocabulary size to fit all LLM.
+        generate_input_output(model)
         return

-    # run single process.
-    # use standalone process to isolate cuda.
-    print("run model by single MLU.")
-    p = mp.Process(target=start_single, args=(name, model))
-    p.start()
-    p.join()
+    if nproc_per_node == 1:
+        # run single process.
+        # use standalone process to isolate bang.
+        print("run model by single MLU.")
+        # p = mp.Process(target=start_single, args=(name, model, data_type))
+        # p.start()
+        # p.join()
+        start_single(name, model, data_type)
+        return

     # run distributed parallel.
     world_size = nnodes * nproc_per_node
-    print(f"run model by {world_size} MLUs in parallel.")
+    print(f"run model by {world_size} MLU in parallel.")
     workers = [
         mp.Process(
             target=start_worker,
-            args=(name, world_size, rank, rank % nproc_per_node, model),
+            args=(name, world_size, rank, rank % nproc_per_node, model, data_type),
         )
         for rank in range(world_size)
     ]
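For context, `load_inputs` above shards an input across ranks with `np.hsplit` whenever the saved input shape does not match the per-rank tensor shape. A standalone illustration of that splitting rule (toy shapes, not tied to any particular model):

```python
import numpy as np

world_size = 4
full_input = np.arange(2 * 8).reshape(2, 8)   # e.g. (batch, hidden)
shards = np.hsplit(full_input, world_size)    # 4 column-wise shards of shape (2, 2)
for rank, shard in enumerate(shards):
    print(f"rank {rank}: shape {shard.shape}")  # each rank copies in only its own shard
```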


@@ -0,0 +1,249 @@
import argparse
import torch
import torch_mlu
from transformers import BertModel, BertConfig
from transformers import GPT2Model, GPT2Config
from transformers import OPTModel, OPTConfig
from transformers import AlbertModel, AlbertConfig
from transformers import LlamaModel, LlamaConfig
import time
import numpy as np
import onnx
import sys
import os
from onnx.external_data_helper import convert_model_to_external_data
from onnxsim import simplify


def parse_args():
    parser = argparse.ArgumentParser(description="Run pytorch gpt2/bert/opt and optionally export onnx.")
    parser.add_argument(
        "--model", type=str, choices=["gpt2", "bert", "opt", "llama", "albert"], required=True, help="model type"
    )
    parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
    parser.add_argument("--length", type=int, default=1, help="sequence length.")
    parser.add_argument(
        "--export_onnx",
        type=str,
        nargs="?",
        default=None,
        const="./",
        help="whether and where to export onnx file",
    )
    parser.add_argument(
        "--type", type=str, choices=["fp32", "fp16", "tf32"], required=True, help="model data type"
    )
    args = parser.parse_args()
    print("arg setting: ", args)
    return (
        args.model,
        args.batch_size,
        args.length,
        args.export_onnx,
        args.type
    )


def get_model(modelname):
    match modelname:
        case "albert":
            model = AlbertModel.from_pretrained("albert/albert-base-v2")
            voc_size = AlbertConfig().vocab_size
        case "bert":
            model = BertModel.from_pretrained("bert-base-uncased", add_pooling_layer=False, hidden_act="gelu_new") # erf is not impl by infini
            voc_size = BertConfig().vocab_size
        case "gpt2":
            model = GPT2Model.from_pretrained("GPT2")
            voc_size = GPT2Config().vocab_size
        case "opt":
            model = OPTModel.from_pretrained("facebook/opt-125m")
            voc_size = OPTConfig().vocab_size
        case "llama":
            model = LlamaModel.from_pretrained("meta-llama/Llama-2-7b-hf")
            voc_size = LlamaConfig().vocab_size
        case _:
            raise KeyError(modelname)
    model = model.eval()
    return model, voc_size


def run_pytorch(torch_model, voc_size, batchsize, len, dtype="fp32"):
    data = np.random.randint(0, voc_size, (batchsize, len), dtype=np.int32)
    os.makedirs(os.path.dirname("./data/"), exist_ok=True)
    np.save("./data/input_0", data)
    inputs = torch.from_numpy(data).to("mlu")
    torch_model = torch_model.to("mlu")
    if dtype == "fp16":
        torch_model = torch_model.half()

    n_iter = 20
    with torch.no_grad():
        for _ in range(10):
            outputs = torch_model(inputs)
    torch.mlu.synchronize()
    begin = time.time()
    with torch.no_grad():
        for _ in range(n_iter):
            torch.mlu.synchronize()
            outputs = torch_model(inputs)
            torch.mlu.synchronize()
    torch.mlu.synchronize()
    end = time.time()

    avg_time = (end - begin) / n_iter
    outputs = outputs.last_hidden_state.to("cpu")
    print("outputs abs mean:", abs(np.array(outputs)).mean())
    print(f"average time: {avg_time}")
    # torch.mlu.memory.empty_cache()
    np.save("./data/output", np.array(outputs))
    print("Save input & output into ./data.")


def export_onnx(modelname, model, data, path, extern=False, dtype="fp32"):
    data = data.to("mlu")
    model = model.to("mlu")
    if dtype == "fp16":
        model = model.half()
    torch.onnx.export(model, data, path, verbose=False, do_constant_folding=True)
    if modelname != "llama":
        # use onnxsim to simplify
        onnx_model = onnx.load(path)
        onnx_model, check = simplify(onnx_model, skipped_optimizers=['eliminate_duplicate_initializer'])
        # onnx_model, check = simplify(onnx_model, skipped_optimizers=['fuse_qkv', 'eliminate_duplicate_initializer'])
        assert check
        add_value_info_for_constants(onnx_model)
        onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
        if extern:
            extern_path = path.replace('.onnx', '.pb')
            if os.path.exists(extern_path):
                os.remove(extern_path)
            extern_path = extern_path.split("/")[-1]
            convert_model_to_external_data(
                onnx_model,
                all_tensors_to_one_file=True,
                location=extern_path,
                size_threshold=1024,
                convert_attribute=False,
            )
        onnx.save(onnx_model, path)
    else:
        # use third party tool to simplify llama
        # reference: https://github.com/luchangli03/onnxsim_large_model/
        sys.path.append("onnxsim_large_model")
        from onnx_utils import set_onnx_input_shape
        from compress_model import SIZE_1MB, compress_onnx_model, uncompress_onnx_model

        in_model_path = path
        out_model_path = path
        if not out_model_path:
            out_model_path = in_model_path[:-5] + ".sim.onnx"
        if os.path.isdir(out_model_path):
            out_model_path = os.path.join(out_model_path, os.path.basename(in_model_path))

        onnx_model = onnx.load(in_model_path)
        print(f"load model from {in_model_path} success")

        size_th_bytes = 1024 * 1024
        onnx_model, removed_inits = compress_onnx_model(onnx_model, size_th_bytes=size_th_bytes)
        print(f"compress model success")

        onnx_model = set_onnx_input_shape(onnx_model, "")

        tensor_size_threshold = f"1024KB"
        skipped_optimizers = []
        skipped_optimizers.append("eliminate_duplicate_initializer")
        onnx_model, check = simplify(onnx_model, skipped_optimizers=skipped_optimizers,
                                     tensor_size_threshold=tensor_size_threshold)
        if not check:
            raise ValueError(f"simplify compressed model {in_model_path} failed")
        print(f"simplify model success")

        onnx_model = uncompress_onnx_model(onnx_model, removed_inits)
        print(f"uncompress model success")

        add_value_info_for_constants(onnx_model)

        onnx.save(onnx_model, out_model_path, save_as_external_data=True)


def add_value_info_for_constants(model : onnx.ModelProto):
    """
    Currently onnx.shape_inference doesn't use the shape of initializers, so add
    that info explicitly as ValueInfoProtos.
    Mutates the model.
    Args:
        model: The ModelProto to update.
    """
    # All (top-level) constants will have ValueInfos before IRv4 as they are all inputs
    if model.ir_version < 4:
        return

    def add_const_value_infos_to_graph(graph : onnx.GraphProto):
        inputs = {i.name for i in graph.input}
        existing_info = {vi.name: vi for vi in graph.value_info}
        for init in graph.initializer:
            # Check it really is a constant, not an input
            if init.name in inputs:
                continue

            # The details we want to add
            elem_type = init.data_type
            shape = init.dims

            # Get existing or create new value info for this constant
            vi = existing_info.get(init.name)
            if vi is None:
                vi = graph.value_info.add()
                vi.name = init.name

            # Even though it would be weird, we will not overwrite info even if it doesn't match
            tt = vi.type.tensor_type
            if tt.elem_type == onnx.TensorProto.UNDEFINED:
                tt.elem_type = elem_type
            if not tt.HasField("shape"):
                # Ensure we set an empty list if the const is scalar (zero dims)
                tt.shape.dim.extend([])
                for dim in shape:
                    tt.shape.dim.add().dim_value = dim

        # Handle subgraphs
        for node in graph.node:
            for attr in node.attribute:
                # Ref attrs refer to other attrs, so we don't need to do anything
                if attr.ref_attr_name != "":
                    continue

                if attr.type == onnx.AttributeProto.GRAPH:
                    add_const_value_infos_to_graph(attr.g)
                if attr.type == onnx.AttributeProto.GRAPHS:
                    for g in attr.graphs:
                        add_const_value_infos_to_graph(g)

    return add_const_value_infos_to_graph(model.graph)


def main():
    torch.backends.mlu.matmul.allow_tf32 = False
    torch.backends.cnnl.allow_tf32 = False
    modelname, batchsize, seqlen, export_path, dtype = parse_args()
    if dtype == "tf32":
        torch.backends.mlu.matmul.allow_tf32 = True
    else:
        os.environ["CAMBRICON_TF32_OVERRIDE"] = "0"

    model, voc_size = get_model(modelname)
    if export_path is not None:
        filename = "{}_{}_{}_{}.onnx".format(modelname, batchsize, seqlen, dtype)
        path = os.path.join(export_path, filename)
        if not os.path.exists(path):
            param = torch.zeros((batchsize, seqlen), dtype=torch.int)
            export_onnx(modelname, model, param, path, True, dtype)
        else:
            print("Onnx path exists, skipping export.")

    run_pytorch(model, voc_size, batchsize, seqlen, dtype)


if __name__ == "__main__":
    main()
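The `add_value_info_for_constants` helper above exists because `onnx.shape_inference` ignores the shapes of initializers. A minimal sketch of its effect on a toy graph (the names `W`, `X`, `Y` are made up for illustration and reuse the function defined in the file above):

```python
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

# Toy graph: MatMul of a runtime input X against a constant initializer W.
weight = numpy_helper.from_array(np.ones((4, 2), dtype=np.float32), name="W")
node = helper.make_node("MatMul", ["X", "W"], ["Y"])
graph = helper.make_graph(
    [node], "toy",
    inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 4])],
    outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2])],
    initializer=[weight],
)
model = helper.make_model(graph)

add_value_info_for_constants(model)                # registers W's type/shape as value_info
print([vi.name for vi in model.graph.value_info])  # expected: ['W']
model = onnx.shape_inference.infer_shapes(model)   # inference can now see the constant's shape
```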

@@ -0,0 +1 @@
Subproject commit cbcf3fbf985a00494b0f136c92eaccd42031bf65


@@ -199,6 +199,24 @@ class CastCnnl : public BangKernelWithoutConfig {
                                                    dim.data()));
             NlCastType = CNNL_CAST_UINT32_TO_INT64;
             break;
+        case CastType::Float162Float:
+            checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
+                                                   CNNL_DTYPE_HALF, dim.size(),
+                                                   dim.data()));
+            checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
+                                                   CNNL_DTYPE_FLOAT, dim.size(),
+                                                   dim.data()));
+            NlCastType = CNNL_CAST_HALF_TO_FLOAT;
+            break;
+        case CastType::Float2Float16:
+            checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
+                                                   CNNL_DTYPE_FLOAT, dim.size(),
+                                                   dim.data()));
+            checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
+                                                   CNNL_DTYPE_HALF, dim.size(),
+                                                   dim.data()));
+            NlCastType = CNNL_CAST_FLOAT_TO_HALF;
+            break;
         default:
             IT_TODO_HALT();
         }
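The two new cases map `CastType::Float2Float16` and `CastType::Float162Float` onto the corresponding CNNL cast kernels. Purely for numerical intuition (plain numpy, not the CNNL path), the round trip loses precision because fp16 keeps only a 10-bit mantissa:

```python
import numpy as np

x = np.float32(0.1) * np.arange(1, 9, dtype=np.float32)  # 0.1 ... 0.8, not exactly representable
half = x.astype(np.float16)     # analogous to Float2Float16
back = half.astype(np.float32)  # analogous to Float162Float
print("max round-trip error:", np.abs(back - x).max())  # on the order of 1e-4
```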


@@ -19,14 +19,16 @@ class LayerNormCnnl : public BangKernelWithoutConfig {
         void *const outputData = (op->getOutput()->getRawDataPtr<void *>());

         auto inDims = op->getInputs(0)->getDims();
+        auto fiterDims = op->getInputs(1)->getDims();
         auto outDims = op->getOutput()->getDims();
-        auto fiterDims = op->getOutput(1)->getDims();

         float eps = op->getEps();
         const int axis = op->getAxis();

-        cnnlTensorDescriptor_t inDesc, fiterDesc, outDesc;
+        Shape outMeanDims(outDims);
+        outMeanDims.erase(outMeanDims.begin() + axis);
+
+        cnnlTensorDescriptor_t inDesc, fiterDesc, outDesc, outMeanDesc;
         checkCnnlError(cnnlCreateTensorDescriptor(&inDesc));
         checkCnnlError(cnnlSetTensorDescriptor(
             inDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
@@ -39,15 +41,23 @@ class LayerNormCnnl : public BangKernelWithoutConfig {
         checkCnnlError(cnnlSetTensorDescriptor(
             outDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
             outDims.size(), outDims.data()));
+        checkCnnlError(cnnlCreateTensorDescriptor(&outMeanDesc));
+        checkCnnlError(cnnlSetTensorDescriptor(
+            outMeanDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(op->getDType()),
+            outMeanDims.size(), outMeanDims.data()));
         size_t wsSize;
         cnnlGetLayerNormOpWorkspaceSize(context->cnnlHandle(), axis, inDesc,
                                         &wsSize);
         BangPtr wsData = context->getWorkspace(wsSize);
+        size_t meanSize =
+            cnnlGetTensorElementNum(outMeanDesc) * op->getDType().getSize();
+        BangPtr meanData = context->getWorkspace(meanSize);
+        BangPtr rstdData = context->getWorkspace(meanSize);

         cnnlStatus_t stat = cnnlLayerNormForward(
             context->cnnlHandle(), inDesc, inputData, axis, fiterDesc,
             scaleData, biasData, eps, wsData, wsSize, outDesc, outputData,
-            inDesc, NULL, NULL);
+            outMeanDesc, meanData, rstdData);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
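The new workspace buffers hold the per-row mean and rstd that `cnnlLayerNormForward` also writes, and `outMeanDims` is the output shape with the normalized axis removed. A plain-numpy reference for the common case where `axis` is the last dimension (for intuition only, not the kernel itself):

```python
import numpy as np

def layernorm_ref(x, scale, bias, eps=1e-5):
    # Normalize over the last axis; mean/rstd keep only the remaining dims.
    mean = x.mean(axis=-1, keepdims=True)
    rstd = 1.0 / np.sqrt(x.var(axis=-1, keepdims=True) + eps)
    y = (x - mean) * rstd * scale + bias
    return y, mean.squeeze(-1), rstd.squeeze(-1)

x = np.random.rand(2, 4, 8).astype(np.float32)
y, mean, rstd = layernorm_ref(x, np.ones(8, np.float32), np.zeros(8, np.float32))
print(y.shape, mean.shape, rstd.shape)  # (2, 4, 8) (2, 4) (2, 4): axis dim dropped for mean/rstd
```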


@@ -66,6 +66,13 @@ class MatmulCnnl : public BangKernelWithoutConfig {
         cnnlSetMatMulDescAttr(bmm_desc, CNNL_MATMUL_DESC_TRANSB, &transB,
                               sizeof(int32_t));

+        std::string computeTypeStr = op->getComputeType();
+        if (computeTypeStr == "tf32") {
+            int32_t tf32 = 1;
+            cnnlSetMatMulDescAttr(bmm_desc, CNNL_MATMUL_ALLOW_TF32, &tf32,
+                                  sizeof(int32_t));
+        }
+
         cnnlMatMulAlgo_t bmm_algo;
         cnnlMatMulAlgoCreate(&bmm_algo);
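Setting `CNNL_MATMUL_ALLOW_TF32` lets the MLU matmul round fp32 inputs to TF32 (8-bit exponent, 10-bit mantissa) internally. A rough software model of that rounding, just to get a feel for the accuracy cost (an approximation, not how the hardware actually computes or accumulates):

```python
import numpy as np

def round_to_tf32(x: np.ndarray) -> np.ndarray:
    # Zero the low 13 mantissa bits of each float32 value (TF32 keeps 10 of 23).
    bits = x.astype(np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFFE000)).view(np.float32)

a = np.random.rand(256, 256).astype(np.float32)
b = np.random.rand(256, 256).astype(np.float32)
exact = a @ b
approx = round_to_tf32(a) @ round_to_tf32(b)
print("max abs diff vs fp32:", np.abs(exact - approx).max())
```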