Add: enable mutator search in python

This commit is contained in:
Liyan Zheng 2023-06-25 20:18:18 +08:00
parent d25b606e12
commit c6c445991a
12 changed files with 223 additions and 143 deletions

View File

@ -19,12 +19,15 @@ class NMutator : public Mutator {
// If in RuleBased mode, use derivationRules in derivator
const std::vector<int> derivationRules;
bool searchFilter = false;
bool enableRules = false; // Enable operator-level transformation rules
public:
NMutator(Mode mode = Mode::Normal,
Runtime runtime = NativeCpuRuntimeObj::getInstance());
Runtime runtime = NativeCpuRuntimeObj::getInstance(),
bool enableRules = false);
NMutator(Mode mode, const std::vector<int> &derivationRules,
Runtime runtime = NativeCpuRuntimeObj::getInstance());
Runtime runtime = NativeCpuRuntimeObj::getInstance(),
bool enableRules = false);
~NMutator();
vector<Graph> run(const Graph &in_graph) override;

View File

@ -16,6 +16,7 @@ void initializeGraphTensors(Graph g, double l, double r, bool useInt);
Graph convertNCHWtoNHWCModel(Runtime runtime, Graph inG);
Graph optimizeWithDepthConstraint(Graph g, Runtime _runtime, int maxDepth);
Graph optimizeModel(Graph g, Runtime _runtime, string name);
Graph optimizeModelWithRules(Graph g, Runtime _runtime, vector<int> rules);
} // namespace infini

View File

@ -82,10 +82,11 @@ Tensor GraphHandlerObj::convTransposed2dNHWC(Tensor input, Tensor weight,
oph, opw);
return output;
} else {
return g->addOp<ConvTransposed2dNHWCObj>(std::move(input),
std::move(weight), output, ph,
pw, sh, sw, dh, dw, oph, opw)
->getOutput();
return g
->addOp<ConvTransposed2dNHWCObj>(std::move(input),
std::move(weight), output, ph, pw,
sh, sw, dh, dw, oph, opw)
->getOutput();
}
}

View File

@ -134,7 +134,7 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling,
t->freeData();
}
// FIXME: ignore trnapose when necessary
// FIXME: ignore trnapose when necessary
// op->getOpType() != OpType::Transpose &&
// op->getOpType() != OpType::ReduceMean
if (op->getOpType() != OpType::Reshape)

View File

@ -357,25 +357,30 @@ std::vector<Graph> SearchEngine::searchMutation(const MetaGraph &metaGraph) {
auto mutatedGraphs = mutator->run(node.graph);
if (mutator->hasTunedKernel)
chooseBestMutation = false;
if (searchFilter == 1) {
std::sort(mutatedGraphs.begin(), mutatedGraphs.end(),
graphTimeComparer);
if (mutatedGraphs.size() >= 10)
mutatedGraphs.resize(10);
mutatedGraphs = {mutatedGraphs[0]};
} else if (chooseBestMutation && mutatedGraphs.size() >= 2) {
std::sort(mutatedGraphs.begin(), mutatedGraphs.end(),
graphTimeComparer);
if (mutatedGraphs.size() >= 10)
mutatedGraphs.resize(10);
mutatedGraphs = {mutatedGraphs[0]};
} else { // avoid repeated kernel genreation
if (mutatedGraphs.size() >= 2) // INFOGAN
mutatedGraphs = {mutatedGraphs[1]};
// if (mutatedGraphs.size() > 2) {
// mutatedGraphs.resize(2);
// }
}
std::sort(mutatedGraphs.begin(), mutatedGraphs.end(),
graphTimeComparer);
if (mutatedGraphs.size() >= 10)
mutatedGraphs.resize(10);
mutatedGraphs = {mutatedGraphs[0]};
// if (searchFilter == 1) {
// std::sort(mutatedGraphs.begin(), mutatedGraphs.end(),
// graphTimeComparer);
// if (mutatedGraphs.size() >= 10)
// mutatedGraphs.resize(10);
// mutatedGraphs = {mutatedGraphs[0]};
// } else if (chooseBestMutation && mutatedGraphs.size() >= 2) {
// std::sort(mutatedGraphs.begin(), mutatedGraphs.end(),
// graphTimeComparer);
// if (mutatedGraphs.size() >= 10)
// mutatedGraphs.resize(10);
// mutatedGraphs = {mutatedGraphs[0]};
// } else { // avoid repeated kernel genreation
// if (mutatedGraphs.size() >= 2) // INFOGAN
// mutatedGraphs = {mutatedGraphs[1]};
// // if (mutatedGraphs.size() > 2) {
// // mutatedGraphs.resize(2);
// // }
// }
for (auto graph : graphs) {
for (auto mutatedGraph : mutatedGraphs) {

View File

@ -432,7 +432,8 @@ void export_test_model(py::module &m) {
"l"_a = -0.1, "r"_a = 0.1, "useInt"_a = false)
.def("convertNCHWtoNHWCModel", &convertNCHWtoNHWCModel)
.def("optimizeWithDepthConstraint", &optimizeWithDepthConstraint)
.def("optimizeModel", &optimizeModel);
.def("optimizeModel", &optimizeModel)
.def("optimizeModelWithRules", &optimizeModelWithRules);
#endif
}

View File

@ -1,7 +1,6 @@
#include "operators/matmul.h"
#include "core/kernel.h"
#include "cuda/cuda_runtime.h"
#include "nnet/dbg.h"
namespace infini {

View File

@ -358,12 +358,20 @@ Graph convertNCHWtoNHWCModel(Runtime runtime, Graph inG) {
return g;
}
Graph optimizeModelWithRules(Graph g, Runtime _runtime, vector<int> rules) {
auto runtime = as<CudaRuntimeObj>(_runtime);
// make_ref<NMutator>(NMutator::Mode::RuleBased, metaRules, runtime);
Ref<NMutator> mutator =
make_ref<NMutator>(NMutator::Mode::RuleBased, rules, runtime);
vector<Graph> bestGraphs;
SearchEngine searchEngine(runtime, mutator);
g->dataFree();
return searchEngine.run(g);
}
Graph optimizeModel(Graph g, Runtime _runtime, string name) {
auto runtime = as<CudaRuntimeObj>(_runtime);
Runtime cpu = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(cpu);
Ref<NMutator> mutator =
make_ref<NMutator>(NMutator::Mode::RuleBased, metaRules, runtime);
Ref<NMutator> mutator = make_ref<NMutator>(NMutator::Mode::Normal, runtime);
vector<Graph> bestGraphs;
SearchEngine searchEngine(runtime, mutator);
g->dataFree();

View File

@ -549,7 +549,9 @@ void Derivator::printDerivationRules() {
if (!startGuided && ruleId != 4)
++cntNonGuideRules;
}
printf("#Steps w/o converging derivation %d, #Steps w/ converging derivation %d\n", cntRules, cntNonGuideRules);
printf("#Steps w/o converging derivation %d, #Steps w/ converging "
"derivation %d\n",
cntRules, cntNonGuideRules);
exit(0);
}

View File

@ -22,15 +22,15 @@
namespace infini {
NMutator::NMutator(Mode mode, Runtime runtime)
: Mutator(10, runtime), mode{mode} {
NMutator::NMutator(Mode mode, Runtime runtime, bool enableRules)
: Mutator(10, runtime), mode{mode}, enableRules{enableRules} {
IT_ASSERT(mode != Mode::RuleBased, "Specify rules for the RuleBased mode.");
}
NMutator::NMutator(Mode mode, const std::vector<int> &derivationRules,
Runtime runtime)
: Mutator(10, runtime), mode{Mode::RuleBased}, derivationRules{
derivationRules} {
Runtime runtime, bool enableRules)
: Mutator(10, runtime), mode{Mode::RuleBased},
derivationRules{derivationRules}, enableRules{enableRules} {
IT_ASSERT(mode == Mode::RuleBased);
}
@ -94,32 +94,38 @@ void NMutator::runSingleOpToNaiveMembound(Graph in_graph,
void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
OpVec computeOps = in_graph->getComputeOps();
IT_ASSERT(computeOps.size() == 1);
if (Graph g = transformConvtransposed1x1(computeOps[0])) {
out_graphs.emplace_back(g);
printf("Mutator states enableRules = %d, mode = %d\n", int(enableRules),
int(mode));
if (enableRules) {
// TODO: unify rules
if (Graph g = transformConvtransposed1x1(computeOps[0])) {
out_graphs.emplace_back(g);
}
for (auto g : transformConv1x1(computeOps[0]))
out_graphs.emplace_back(g);
for (auto g : transformConv1xk(computeOps[0]))
out_graphs.emplace_back(g);
for (auto g : transformConv3x3ONNX(computeOps[0]))
out_graphs.emplace_back(g);
if (Graph g = transformG2bmm(computeOps[0])) {
out_graphs.emplace_back(g);
}
if (Graph g = transformGbmm(computeOps[0])) {
out_graphs.emplace_back(g);
}
if (infini::Graph g = transformDialtedConv(computeOps[0])) {
out_graphs.emplace_back(g);
}
if (infini::Graph g = transformConvToGEMMReduce(computeOps[0])) {
out_graphs.emplace_back(g);
}
if (infini::Graph g =
transformConvTranposeToGEMMReduce(computeOps[0])) {
out_graphs.emplace_back(g);
}
if (out_graphs.size() > 1)
return;
}
for (auto g : transformConv1x1(computeOps[0]))
out_graphs.emplace_back(g);
for (auto g : transformConv1xk(computeOps[0]))
out_graphs.emplace_back(g);
for (auto g : transformConv3x3ONNX(computeOps[0]))
out_graphs.emplace_back(g);
if (Graph g = transformG2bmm(computeOps[0])) {
out_graphs.emplace_back(g);
}
if (Graph g = transformGbmm(computeOps[0])) {
out_graphs.emplace_back(g);
}
if (infini::Graph g = transformDialtedConv(computeOps[0])) {
out_graphs.emplace_back(g);
}
if (infini::Graph g = transformConvToGEMMReduce(computeOps[0])) {
out_graphs.emplace_back(g);
}
if (infini::Graph g = transformConvTranposeToGEMMReduce(computeOps[0])) {
out_graphs.emplace_back(g);
}
if (out_graphs.size() > 1)
return;
const set<OpType> opSet{OpType::Conv, OpType::ConvTransNHWC, OpType::G2BMM,
OpType::GBMM};
@ -140,7 +146,7 @@ void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
} else
IT_TODO_HALT_MSG("Unknown NMutator search mode.");
const auto &candidates = derivator.getCandidates();
// dbg(candidates.size());
dbg(candidates.size());
// derivator.print();
for (const auto &candidate : candidates) {
// dbg(nnet::FullPrinterVisitor().print(candidate.root));

View File

@ -89,10 +89,10 @@ TEST(cuDNN_Conv, run) {
TEST(cuDNN_Conv, runNHWC) {
testConvNHWCCudnn(OneGenerator(),
vector<float>{12., 12., 12., 12., 18., 18., 18., 18.});
vector<float>{12., 12., 12., 12., 18., 18., 18., 18.});
testConvNHWCCudnn(
IncrementalGenerator(),
vector<float>{3350, 7562, 2306, 5546, 9480, 24546, 7185, 20793});
vector<float>{3350, 7562, 2306, 5546, 9480, 24546, 7185, 20793});
}
TEST(cuDNN_Conv, tune) {

View File

@ -6,6 +6,7 @@ import pandas as pd
import pyinfinitensor as pit
from pyinfinitensor import backend as ft
from pyinfinitensor.onnx import OnnxStub
from pyinfinitensor.tensorrt_backend import get_trt_time
def to_pytorch_tensor(tensor) -> torch.Tensor:
@ -28,9 +29,11 @@ def load_onnx(runtime, filename: str) -> ft.Graph:
def run_and_evaluate(runtime, g):
ft.initializeGraphTensors(g)
runtime.run(g, True)
# print(f'getPerfTime = {runtime.getPerfTime(g, True, False, False)}')
# print(f'Non-ctc time = {runtime.timeNonCtcOperators(g, 10, 10)}')
print(f'Cuda graph time = {runtime.timeWithCudaGraph(g, 10)}')
print(f'Op perf time = {runtime.getPerfTime(g, True, False, False)}')
print(f'Graph perf time = {runtime.timeNonCtcOperators(g, 10, 10)}')
t = runtime.timeWithCudaGraph(g, 100)
print(f'Cuda graph time = {t}')
return t
def run_graph_get_output_as_torch_tensor(runtime, g):
@ -101,16 +104,32 @@ def construct_convTranspose2d(runtime, n, c, h, w, f, r, s, pad, stride, dilatio
return handler.getGraph()
def construct_conv(runtime, n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw):
def construct_gemm(runtime, b, m, n, k, transA, transB):
handler = ft.GraphHandler(runtime)
input = handler.tensor([b, k, m] if transA else [b, m, k],
tensor_type=ft.TensorType.Input)
w = handler.tensor([b, n, k] if transB else [b, k, n],
tensor_type=ft.TensorType.Initialized)
handler.matmul(input, w, None, transA, transB, None, ft.Linear)
return handler.getGraph()
def construct_conv(runtime, n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, bias=False, relu=False):
handler = ft.GraphHandler(runtime)
# input = handler.tensor([1, 56, 32, 32], tensor_type=ft.TensorType.Input)
# w = handler.tensor([12, 56, 1, 1], tensor_type=ft.TensorType.Initialized)
# handler.conv(input, w, None, 0, 0, 1, 1, 1, 1)
input = handler.tensor([n, c, h, w], tensor_type=ft.TensorType.Input)
w = handler.tensor([f, c, r, s], tensor_type=ft.TensorType.Initialized)
handler.conv(input, w, None, ph, pw, sh, sw, dh, dw)
x = handler.conv(input, w, None, ph, pw, sh, sw, dh, dw)
if bias:
bias = handler.tensor([f, 1, 1], tensor_type=ft.TensorType.Initialized)
x = handler.add(x, bias, None)
if relu:
x = handler.relu(x, None)
return handler.getGraph()
def construct_conv_nhwc(runtime, n, c, h, w, f, r, s, pad, stride, dilation):
handler = ft.GraphHandler(runtime)
# input = handler.tensor([1, 56, 32, 32], tensor_type=ft.TensorType.Input)
@ -118,14 +137,17 @@ def construct_conv_nhwc(runtime, n, c, h, w, f, r, s, pad, stride, dilation):
# handler.conv(input, w, None, 0, 0, 1, 1, 1, 1)
input = handler.tensor([n, h, w, c], tensor_type=ft.TensorType.Input)
w = handler.tensor([f, r, s, c], tensor_type=ft.TensorType.Initialized)
handler.convNHWC(input, w, None, pad, pad, stride, stride, dilation, dilation)
handler.convNHWC(input, w, None, pad, pad, stride,
stride, dilation, dilation)
return handler.getGraph()
def construct_convtranposed_nhwc(runtime, n, c, h, w, f, r, s, pad, stride, dilation):
handler = ft.GraphHandler(runtime)
input = handler.tensor([n, h, w, c], tensor_type=ft.TensorType.Input)
w = handler.tensor([f, r, s, c], tensor_type=ft.TensorType.Initialized)
handler.convtransposed2dNHWC(input, w, None, pad, pad, stride, stride, dilation, dilation)
handler.convtransposed2dNHWC(
input, w, None, pad, pad, stride, stride, dilation, dilation)
return handler.getGraph()
@ -160,33 +182,112 @@ def search_depth_exp():
# print(f'getPerfTime = {runtime.getPerfTime(g, True, True, False)}')
# print(f'Non-ctc time = {runtime.timeNonCtcOperators(g, 10, 10)}')
# save_onnx(g, f"opt_{name}_depth{i}.onnx")
print(f'{name} Depth = {i}: {runtime.getPerfTime(g, True, True, False)} ms')
def model_e2e_exp():
print(
f'{name} Depth = {i}: {runtime.getPerfTime(g, True, True, False)} ms')
def get_e2e_time(runtime, g, name: str):
if name.startswith('resnet'):
return get_trt_time(g)
else:
return run_and_evaluate(runtime, g)
def model_e2e_exp(allow_tf32: bool):
runtime = ft.cuda_runtime()
model_evaluation =[
(lambda : ft.getGANGraph(1, runtime, 5, 0), 'InfoGAN.bs1'),
(lambda : ft.getGANGraph(16, runtime, 5, 0), 'InfoGAN.bs16'),
(lambda : ft.getGANGraph(1, runtime, 5, 1), 'DCGAN.bs1'),
(lambda : ft.getGANGraph(16, runtime, 5, 1), 'DCGAN.bs16'),
(lambda : ft.getFSRCNNGraph(1, runtime), "fsrcnn.bs1"),
(lambda : ft.getFSRCNNGraph(16, runtime), "fsrcnn.bs16"),
(lambda : load_onnx(runtime, '/mnt/auxHome/models/einnet/gcn.bs1.onnx'), 'gcn.bs1'),
(lambda : load_onnx(runtime, '/mnt/auxHome/models/einnet/gcn.bs16.onnx'), 'gcn.bs16'),
(lambda : load_onnx(runtime, '/mnt/auxHome/models/einnet/csrnet.bs1.onnx'), 'csrnet.bs1'),
(lambda : load_onnx(runtime, '/mnt/auxHome/models/einnet/csrnet.bs16.onnx'), 'csrnet.bs16'),
(lambda : ft.getLongformer(runtime, 1), 'longformer.bs1'),
(lambda : ft.getLongformer(runtime, 16), 'longformer.bs16'),
]
runtime.setEnableTF32(allow_tf32)
model_evaluation = [
# (lambda: construct_conv(runtime, 1, 512, 7,
# 7, 512, 3, 3, 1, 1, 1, 1, 1, 1), 'ResNet-conv3x3'),
# (lambda: construct_conv(runtime, 1, 512, 7,
# 7, 512, 3, 3, 1, 1, 1, 1, 1, 1, True, True), 'ResNet-conv3x3-BiasRelu'),
# (lambda: construct_conv(runtime, 1, 1, 7,
# 7, 1, 3, 3, 1, 1, 1, 1, 1, 1), 'ResNet-conv3x3-c1'),
# (lambda: construct_conv(runtime, 1, 3, 7,
# 7, 3, 3, 3, 1, 1, 1, 1, 1, 1), 'ResNet-conv3x3-c3'),
# (lambda: construct_conv(runtime, 1, 32, 7,
# 7, 32, 3, 3, 1, 1, 1, 1, 1, 1), 'ResNet-conv3x3-c32'),
# (lambda: construct_conv(runtime, 1, 128, 7,
# 7, 128, 3, 3, 1, 1, 1, 1, 1, 1), 'ResNet-conv3x3-c128'),
# (lambda: ft.getGANGraph(1, runtime, 5, 0), 'InfoGAN.bs1'),
# (lambda: ft.getGANGraph(16, runtime, 5, 0), 'InfoGAN.bs16'),
# (lambda: ft.getGANGraph(1, runtime, 5, 1), 'DCGAN.bs1'),
# (lambda: ft.getGANGraph(16, runtime, 5, 1), 'DCGAN.bs16'),
# (lambda: ft.getFSRCNNGraph(1, runtime), "fsrcnn.bs1"),
# (lambda: ft.getFSRCNNGraph(16, runtime), "fsrcnn.bs16"),
# (lambda: load_onnx(runtime, '/mnt/auxHome/models/einnet/gcn.bs1.onnx'), 'gcn.bs1'),
# (lambda: load_onnx(runtime, '/mnt/auxHome/models/einnet/gcn.bs16.onnx'), 'gcn.bs16'),
(lambda: load_onnx(runtime, '/mnt/auxHome/models/einnet/resnet18.bs1.onnx'), 'resnet.bs1'),
# (lambda: load_onnx(runtime, '/mnt/auxHome/models/einnet/resnet18.bs16.onnx'), 'resnet.bs16'),
# (lambda: load_onnx(runtime, '/mnt/auxHome/models/einnet/csrnet.bs1.onnx'), 'csrnet.bs1'),
# (lambda: load_onnx(runtime, '/mnt/auxHome/models/einnet/csrnet.bs16.onnx'), 'csrnet.bs16'),
# (lambda : ft.getLongformer(runtime, 1), 'longformer.bs1'),
# (lambda : ft.getLongformer(runtime, 16), 'longformer.bs16'),
# (lambda : load_onnx(runtime, '/home/whj/workspace/InfiniTensor/cuda-build/efficientnet-b1_bs1.onnx'), 'efficientnet.b1'),
# (lambda : load_onnx(runtime, '/home/whj/workspace/InfiniTensor/cuda-build/mobilenet_v2_bs1.onnx'), 'mobilenet_v2.bs1'),
]
print("Figure 12")
for graph_ctor, name in model_evaluation:
t_orig, t_opt = 99999999, 99999999
print(f"=== {name}")
original_g = graph_ctor()
# original_g = ft.convertNCHWtoNHWCModel(runtime, original_g)
# save_onnx(original_g, f"orig_{name}.onnx")
# print('Time:', get_e2e_time(runtime, original_g, name))
t_orig = run_and_evaluate(runtime, original_g)
g = ft.optimizeModel(original_g, runtime, name)
# g = ft.optimizeGraph(original_g, runtime, False, ft.NMutatorMode.RuleBased,
# [3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90]) # Convtranspose2gemm
# save_onnx(g, f"opt_{name}.onnx")
run_and_evaluate(runtime, g)
# g = ft.optimizeModelWithRules(original_g, runtime,
# [3, 2, 2, 5, 8, 8, 6, 90]) # Conv2Gemm
save_onnx(g, f"opt_{name}.onnx")
# run_and_evaluate(runtime, g)
# print(get_e2e_time(runtime, g, name))
t_opt = run_and_evaluate(runtime, g)
print(
f'=== {name} orig/opt=speedup {t_orig:.3f} {t_opt:.3f} {t_orig/t_opt:.2f}')
verify_graphs(runtime, original_g, g)
def test_gemm_tf32(allow_tf32: bool):
configs = [
[1, 1024, 196, 85],
[1, 128, 3136, 256],
[1, 128, 784, 512],
[1, 196, 231, 1024],
[1, 196, 231, 21],
[1, 196, 425, 1024],
[1, 196, 896, 1024],
[1, 196, 896, 128],
[1, 2048, 49, 128],
[1, 21, 50176, 21],
[1, 231, 3136, 21],
[1, 231, 3136, 256],
[1, 256, 3136, 64],
[1, 425, 196, 1024],
[1, 425, 196, 85],
[1, 425, 784, 512],
[1, 49, 231, 2048],
[1, 49, 231, 21],
[1, 49, 896, 128],
[1, 512, 784, 128],
[1, 64, 3136, 256],
[1, 784, 231, 21],
[1, 784, 231, 512],
[1, 896, 196, 128],
[1, 896, 49, 2048],
]
runtime = ft.cuda_runtime()
runtime.setEnableTF32(allow_tf32)
for config in configs:
for transA, transB in ((False, False), (False, True), (True, False), (True, True)):
s = 16
align_config = [config[0], config[1]*16, config[2], config[3]]
align_config = [config[0]]+[(v+s-1)//s*s for v in align_config[1:]]
# align_config = config
g = construct_gemm(runtime, *align_config, transA, transB)
print(
f"{allow_tf32} {transA} {transB} {align_config} {run_and_evaluate(runtime, g)}")
def perf_test():
@ -196,56 +297,9 @@ def perf_test():
g = ft.getLongformer(runtime, 1)
run_and_evaluate(runtime, g)
if __name__ == "__main__":
# perf_test()
model_e2e_exp()
exit()
runtime = ft.cuda_runtime()
graphs = [
# (construct_conv(runtime, 16, 56, 32, 32, 12, 1, 1, 0, 1, 1), 'conv1x1'), # FSRCNN Conv_2 1x1
# (construct_conv(runtime, 1, 12, 32, 32, 12, 3, 3, 1, 1, 1), 'conv3x3'), # FSRCNN Conv_4 3x3
# (construct_conv(runtime, 1, 12, 32, 32, 12, 3, 1, 1, 0, 1, 1, 1, 1), 'conv3x1'), #
# (construct_conv(runtime, 1, 12, 32, 32, 12, 1, 11, 0, 5, 1, 1, 1, 1), 'conv1x11'), #
# (construct_conv(runtime, 16, 12, 32, 32, 12, 1, 11, 0, 5, 1, 1, 1, 1), 'conv1x11_bs16'), #
# (construct_conv(runtime, 16,32,224,224, 1, 5, 5, 2, 2, 1, 1, 1, 1), 'conv5x5'), #
# (ft.getLongformer(runtime, 1), 'longformer.bs1'),
# (ft.getLongformer(runtime, 16), 'longformer.bs16'),
# construct_convTranspose2d(runtime)
# (load_onnx(runtime, '/mnt/auxHome/models/einnet/fsrcnn.bs1.onnx'), 'fsrcnn.bs1'),
# (ft.getFSRCNNGraph(1, runtime), "fsrcnn.bs1"),
# (ft.getFSRCNNGraph(16, runtime), "fsrcnn.bs16"),
# (construct_conv_nhwc(runtime, 1, 56, 32, 32, 12, 1, 1, 0, 1, 1), 'conv1x1')
# (load_onnx(runtime, '/mnt/auxHome/models/einnet/gcn.bs1.onnx'), 'gcn.bs1'),
# (load_onnx(runtime, '/mnt/auxHome/models/einnet/gcn.bs16.onnx'), 'gcn.bs16'),
# (load_onnx(runtime, '/mnt/auxHome/models/einnet/csrnet.bs1.onnx'), 'csrnet.bs1'),
# (load_onnx(runtime, '/mnt/auxHome/models/einnet/csrnet.bs16.onnx'), 'csrnet.bs16'),
(ft.getLongformer(runtime, 1), 'longformer.bs1'),
# (ft.getLongformer(runtime, 16), 'longformer.bs16'),
# (load_onnx(runtime, '/mnt/auxHome/models/einnet/resnet18.bs1.onnx'), 'resnet18.bs1'),
# (load_onnx(runtime, '/mnt/auxHome/models/einnet/resnet18.bs16.onnx'), 'resnet18.bs16'),
# (ft.getGANGraph(1, runtime, 5, 0), 'InfoGAN.bs1'),
]
for original_g, name in graphs:
print(f"=== {name}")
# save_onnx(original_g, f"orig_{name}.onnx")
# original_g = ft.convertNCHWtoNHWCModel(runtime, original_g)
# save_onnx(dlt_g, f"dlt_{name}.onnx")
# exit()
# run_and_evaluate(runtime, original_g)
# g = ft.optimizeGraph(original_g, runtime, False, ft.NMutatorMode.RuleBased,
# [1, 7, 7, 2, 8, 6, 6]) # G2BMM/GBMM
# g = ft.optimizeGraph(original_g, runtime, False, ft.NMutatorMode.RuleBased,
# [3, 2, 2, 5, 8, 8, 6, 90]) # Conv2conv
g = ft.optimizeGraph(original_g, runtime, False, ft.NMutatorMode.RuleBased,
[3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90]) # Convtranspose2gemm
# g = ft.optimizeGraph(original_g, runtime, False, ft.NMutatorMode.Normal)
# g = ft.convertNCHWtoNHWCModel(original_g, runtime, i)
# run_and_evaluate(runtime, original_g)
run_and_evaluate(runtime, g)
save_onnx(g, f"opt_{name}.onnx")
# verify_graphs(runtime, original_g, g)
# run_and_evaluate(runtime, g)
for b in [False]:
model_e2e_exp(b)
# test_gemm_tf32(b)