Add: Python verification

Liyan Zheng 2023-04-21 13:07:58 +08:00
parent 0cb8729bc1
commit f0fcbe825f
6 changed files with 92 additions and 27 deletions

View File

@@ -0,0 +1,16 @@
+#ifdef USE_CUDA
+#include "core/graph.h"
+#include "core/runtime.h"
+#include "core/search_engine.h"
+
+namespace infini {
+Graph getInfoGAN(int batch, Runtime runtime, int nLayers);
+vector<Tensor> runInfoGAN(int nLayers);
+Graph getConvtransposedNHWC(Runtime runtime, Shape shape, int layerId);
+Graph optimizeGraph(Graph g, Runtime runtime, bool tuning);
+void initializeGraphTensors(Graph g, double l, double r, bool useInt);
+} // namespace infini
+#endif

View File

@@ -878,6 +878,16 @@ class OnnxStub:
                     ctx.push_data_input(name, "max", TensorProto.FLOAT, [], [])
                 )
                 ctx.push_node(make_node(ty.name, inputs, outputs, name))
+            elif ty == backend.OpType.ConvTransNHWC:
+                ctx.push_node(
+                    make_node(
+                        ty.name,
+                        inputs,
+                        outputs,
+                        name,
+                        domain="nnet",
+                    )
+                )
             elif ty == backend.OpType.MemBound:
                 ctx.push_node(
                     make_node(
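The ConvTransNHWC branch above tags the node with a custom `nnet` domain, since the op is not part of the standard ONNX op set. A minimal standalone sketch of the same pattern using the public onnx.helper API (the tensor and node names here are made up):

from onnx.helper import make_node

# A node for a custom op in the "nnet" domain, mirroring the exporter above.
node = make_node(
    "ConvTransNHWC",   # op type outside the standard ONNX op set
    ["x", "w"],        # hypothetical input names
    ["y"],             # hypothetical output name
    "convtrans0",      # hypothetical node name
    domain="nnet",     # custom domain, as in the diff above
)
print(node)

Note that a ModelProto containing such nodes also needs an opset_imports entry for the `nnet` domain before it will load cleanly.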

View File

@@ -156,6 +156,7 @@ double CudaRuntimeObj::timeWithCudaGraph(Graph graph) {
     IT_ASSERT(numCudaGraphNodes == kernels.size(),
               std::to_string(numCudaGraphNodes) +
                   " != " + std::to_string(kernels.size()));
+    printf("numCudaGraphNodes = %lu\n", numCudaGraphNodes);
     return timeit(
         [&, cudaGraphInstance = cudaGraphInstance, stream = getStream()]() {
             checkCudaError(cudaGraphLaunch(cudaGraphInstance, stream));
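The added printf reports how many nodes were captured into the CUDA graph; `timeit` then replays the instantiated graph on the stream. The exact `timeit` semantics are not shown in this hunk, but the usual pattern is warmup followed by an averaged loop, sketched below in Python (a generic illustration, not the project's implementation):

import time

def time_launch(launch, warmup=10, rounds=100):
    # `launch` must block until the GPU work completes
    # (e.g. by ending with a stream/device synchronize).
    for _ in range(warmup):
        launch()                   # warm caches, JIT, clocks
    t0 = time.perf_counter()
    for _ in range(rounds):
        launch()
    return (time.perf_counter() - t0) / rounds * 1e3  # ms per replay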

View File

@@ -2,6 +2,7 @@
 #include "core/mutator.h"
 #include "core/search_engine.h"
 #include "nnet/nmutator.h"
+#include "nnet/test_models.h"
 #include "operators/batch_norm.h"
 #include "operators/concat.h"
 #include "operators/conv.h"
@@ -378,20 +379,17 @@ void init_graph_builder(py::module &m) {
         .def("topo_sort", &GraphObj::topo_sort);
 }
-#ifdef USE_CUDA
-Graph getInfoGAN(int batch, Runtime runtime, int nLayers);
-vector<Tensor> runInfoGAN(int nLayers);
-Graph getConvtransposedNHWC(Runtime runtime, Shape shape, int layerId);
-Graph optimizeGraph(Graph g, Runtime runtime, bool tuning);
 void export_test_model(py::module &m) {
-    m.def("runInfoGAN", &runInfoGAN);
-    m.def("getInfoGAN", &getInfoGAN);
-    m.def("getConvtransposedNHWC", &getConvtransposedNHWC);
-    m.def("optimizeGraph", &optimizeGraph, "graph"_a, "runtime"_a,
-          "tuning"_a = false);
-}
+#ifdef USE_CUDA
+    m.def("runInfoGAN", &runInfoGAN)
+        .def("getInfoGAN", &getInfoGAN)
+        .def("getConvtransposedNHWC", &getConvtransposedNHWC)
+        .def("optimizeGraph", &optimizeGraph, "graph"_a, "runtime"_a,
+             "tuning"_a = false)
+        .def("initializeGraphTensors", &initializeGraphTensors, "g"_a,
+             "l"_a = -0.1, "r"_a = 0.1, "useInt"_a = false);
 #endif
+}
 } // namespace infini
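Since `m.def` returns the module itself, the chained `.def` style above is equivalent to repeated `m.def` calls. With these bindings in place, the whole verification flow is scriptable from Python; a short sketch, assuming `ft` is the module that exposes them (the same alias the test script below uses):

import pyinfinitensor as ft   # assumed import alias

runtime = ft.cuda_runtime()
g = ft.getInfoGAN(1, runtime, 5)       # batch=1, 5 layers
ft.initializeGraphTensors(g)           # defaults: l=-0.1, r=0.1, useInt=False
opt_g = ft.optimizeGraph(g, runtime)   # "tuning" defaults to False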

View File

@@ -83,6 +83,16 @@ void printGraph(Graph g) {
     }
 }

+void initializeGraphTensors(Graph g, double l, double r, bool useInt) {
+    auto gen = RandomGenerator(l, r, 0, useInt);
+    for (auto t : g->getInputs()) {
+        t->setData(gen);
+    }
+    for (auto t : g->getOutputs()) {
+        t->setData(ZeroGenerator());
+    }
+}
+
 Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning) {
     auto runtime = as<CudaRuntimeObj>(_runtime);
     Runtime cpu = NativeCpuRuntimeObj::getInstance();
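`initializeGraphTensors` fills every graph input with random values and zeroes every output before a run. The same idea in numpy, assuming RandomGenerator draws uniformly from [l, r] with seed 0 and rounds to integers when useInt is set (those semantics are inferred, not confirmed by the diff):

import numpy as np

def initialize_tensors(inputs, outputs, l=-0.1, r=0.1, use_int=False):
    rng = np.random.default_rng(0)       # fixed seed, as in the C++ call
    for t in inputs:                     # t: numpy array backing a tensor
        data = rng.uniform(l, r, t.shape)
        t[...] = np.rint(data) if use_int else data
    for t in outputs:
        t[...] = 0.0                     # outputs start zeroed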

View File

@@ -71,21 +71,14 @@ def run_e2e_InfoGAN():
     df.to_csv('a.csv')

-def runSingleConvT():
-    runtime = ft.cuda_runtime()
-    g = ft.getConvtransposedNHWC(runtime, [1, 2, 2, 448], 1)
-    opt_g = ft.optimizeGraph(g, runtime)
-    ft.if_onnx.export_onnx(opt_g, 'convtransposed.onnx')
+def getSingleConvT(runtime):
+    return ft.getConvtransposedNHWC(runtime, [1, 2, 2, 448], 1)

-def run_InfoGAN_without_tuning(runtime, tuning: bool):
-    g = ft.getInfoGAN(1, runtime, 5)
-    # g = ft.getInfoGAN(1, runtime, 1)
-    opt_g = ft.optimizeGraph(g, runtime, tuning)
+def save_onnx(opt_g: ft.Graph, filename: str):
     stub = OnnxStub.from_graph(opt_g)
-    with open("optimized.onnx", "wb") as f:
+    with open(filename, "wb") as f:
         f.write(stub.to_onnx("optimized").SerializeToString())
-    return opt_g

 def load_onnx(runtime) -> ft.Graph:
@@ -100,14 +93,51 @@ def run_and_evaluate(runtime, g):
     print(f'Cuda graph time = {runtime.timeWithCudaGraph(g)}')

+def run_graph_get_output(runtime, g):
+    ft.initializeGraphTensors(g)
+    runtime.run(g, True)
+    runtime.run(g, False)
+    tensors = [to_pytorch_tensor(t) for t in g.outputs()]
+    assert len(tensors) == 1
+    return tensors[0]
+
+def compare_tensors(ans, x):
+    assert ans.shape == x.shape
+    print(f'Allclose {torch.allclose(ans, x)}')
+    # Count how many elements fall within each tolerance level
+    tot = np.prod(ans.shape)
+    data = []
+    for i in range(0, 10):
+        tol = 10**(-i)
+        clo = torch.isclose(ans, x, atol=tol, rtol=tol).sum().item()
+        print(f'0.1^{i} close: {clo}/{tot} = {clo/tot}')
+        data.append(clo/tot)
+    # rel_err = torch.abs((ans-x)/ans)
+    # print(f'rel_err = {rel_err}')
+    # print(f'max rel err = {rel_err.max()}')
+    print(f'ans = {ans}')
+    print(f'x = {x}')
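For reference, `torch.isclose(ans, x, atol=tol, rtol=tol)` marks element i as close when |ans[i] - x[i]| <= atol + rtol * |x[i]|, so the loop above sweeps that threshold from 1 down to 1e-9 to profile the error distribution. A toy example of the sweep:

import torch

ans = torch.tensor([1.0001, 2.0, 2.9])
x = torch.tensor([1.0, 2.0, 3.0])
for i in range(4):
    tol = 10 ** (-i)
    n = torch.isclose(ans, x, atol=tol, rtol=tol).sum().item()
    print(f'0.1^{i} close: {n}/3')   # 3/3, 3/3, 2/3, 2/3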
 if __name__ == "__main__":
-    runtime = ft.cuda_runtime()
     # run_e2e_InfoGAN()
     # runSingleConvT()
     # read_and_check()
+    runtime = ft.cuda_runtime()
     if True:
-        g = run_InfoGAN_without_tuning(runtime, False)
+        original_g = ft.getInfoGAN(16, runtime, 5)
+        # original_g = ft.getConvtransposedNHWC(runtime, [1, 1, 1, 228], 0)  # ConvTranspose 2x2
+        # original_g = ft.getConvtransposedNHWC(runtime, [16, 2, 2, 448], 1)  # ConvTranspose 4x4
+        g = ft.optimizeGraph(original_g, runtime, tuning=False)
     else:
         g = load_onnx(runtime)
-    run_and_evaluate(runtime, g)
+    save_onnx(g, "optimized.onnx")
+    ans = run_graph_get_output(runtime, original_g)
+    x = run_graph_get_output(runtime, g)
+    print('=== 138')
+    compare_tensors(ans, x)
+    # run_and_evaluate(runtime, g)