forked from jiuyuan/InfiniTensor
Fix: skip check when Graph is exported to ONNX
This commit is contained in:
parent
a6019e79e3
commit
34ca6bf149
|
@ -37,6 +37,7 @@ class OnnxStub:
|
|||
outputs: Dict[str, backend.Tensor] = {}
|
||||
initializer: Dict[int, TensorProto] = {}
|
||||
handler: backend.GraphHandler
|
||||
disable_check: bool
|
||||
|
||||
@classmethod
|
||||
def from_onnx(cls, model: ModelProto, runtime):
|
||||
|
@ -551,6 +552,7 @@ class OnnxStub:
|
|||
for i, tensor in enumerate(handler.outputs()):
|
||||
ans.inputs["output{}".format(i)] = tensor
|
||||
ans.handler = handler
|
||||
ans.disable_check = True
|
||||
return ans
|
||||
|
||||
def to_onnx(self, name: str) -> ModelProto:
|
||||
|
@ -570,6 +572,11 @@ class OnnxStub:
|
|||
# saves global input tensors
|
||||
initializers: List[TensorProto] = []
|
||||
|
||||
enable_check = False
|
||||
def __init__(self, enable_check):
|
||||
self.enable_check = enable_check
|
||||
|
||||
|
||||
def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:
|
||||
ty = op.op_type()
|
||||
name = "{}{}".format(ty.name, self.count_op.setdefault(ty, 0) + 1)
|
||||
|
@ -622,17 +629,20 @@ class OnnxStub:
|
|||
return name
|
||||
|
||||
def push_node(self, node: NodeProto) -> None:
|
||||
check_node(node)
|
||||
if self.enable_check:
|
||||
check_node(node)
|
||||
self.nodes.append(node)
|
||||
|
||||
def build(self, name: str) -> ModelProto:
|
||||
graph = make_graph(
|
||||
self.nodes, name, self.inputs, self.outputs, self.initializers
|
||||
)
|
||||
check_graph(graph)
|
||||
if self.enable_check:
|
||||
check_graph(graph)
|
||||
|
||||
model = make_model(graph)
|
||||
check_model(model)
|
||||
if self.enable_check:
|
||||
check_model(model)
|
||||
|
||||
return model
|
||||
|
||||
|
@ -642,7 +652,7 @@ class OnnxStub:
|
|||
|
||||
ops = self.handler.operators() # 图中所有算子(节点)
|
||||
|
||||
ctx = Context()
|
||||
ctx = Context(not self.disable_check)
|
||||
|
||||
for op in ops:
|
||||
ty, name = ctx.name_op(op)
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
#ifdef USE_CUDA
|
||||
#include "core/blob.h"
|
||||
#include "core/dummy_mutator.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "core/search_engine.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "ffi/ffi_callback.h"
|
||||
#include "nnet/nmutator.h"
|
||||
#include "operators/conv.h"
|
||||
|
@ -10,10 +12,6 @@
|
|||
#include "test.h"
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#endif
|
||||
|
||||
namespace infini {
|
||||
|
||||
// NHWC format
|
||||
|
@ -149,7 +147,7 @@ Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning) {
|
|||
dbg(go0->equalData(bgo0, 1e-3));
|
||||
dbg(runtime->getPerfTime(bestGraph, true));
|
||||
dbg(runtime->timeNonCtcOperators(bestGraph));
|
||||
dbg(runtime->timeWithCudaGraph(bestGraph));
|
||||
// dbg(runtime->timeWithCudaGraph(bestGraph));
|
||||
}
|
||||
|
||||
dbg("Best graph");
|
||||
|
@ -160,7 +158,6 @@ Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning) {
|
|||
}
|
||||
|
||||
vector<Tensor> runInfoGAN(int nLayers) {
|
||||
#ifdef USE_CUDA
|
||||
auto cuda = make_ref<CudaRuntimeObj>();
|
||||
Runtime cpu = NativeCpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
|
@ -235,10 +232,10 @@ vector<Tensor> runInfoGAN(int nLayers) {
|
|||
callback::exportONNX(bestGraph, "best_graph.onnx"); // Debug
|
||||
return {g->getOutputs()[0], bestGraph->getOutputs()[0]};
|
||||
}
|
||||
#endif
|
||||
return {};
|
||||
}
|
||||
|
||||
// TEST(ModelE2E, InfoGAN) { runInfoGAN(); }
|
||||
|
||||
} // namespace infini
|
||||
#endif
|
||||
|
|
|
@ -3,7 +3,9 @@ import torch
|
|||
from matplotlib import pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import infinitensor as ft
|
||||
import pyinfinitensor as pit
|
||||
from pyinfinitensor import backend as ft
|
||||
from pyinfinitensor.onnx import OnnxStub
|
||||
|
||||
|
||||
def to_pytorch_tensor(tensor) -> torch.Tensor:
|
||||
|
@ -81,8 +83,10 @@ def run_InfoGAN_without_tuning(tuning: bool):
|
|||
g = ft.getInfoGAN(1, runtime, 5)
|
||||
# g = ft.getInfoGAN(1, runtime, 1)
|
||||
opt_g = ft.optimizeGraph(g, runtime, tuning)
|
||||
ft.if_onnx.export_onnx(opt_g, 'infogan_transformed.onnx')
|
||||
ft.NMutator.memboundToJson(opt_g, ".")
|
||||
stub = OnnxStub.from_graph(opt_g)
|
||||
with open("optimized.onnx", "wb") as f:
|
||||
f.write(stub.to_onnx("optimized").SerializeToString())
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue