forked from jiuyuan/InfiniTensor
Add: C++ callback to export ONNX
This commit is contained in:
parent
872f3504a9
commit
edf4e33353
|
@ -0,0 +1,9 @@
|
||||||
|
#include "core/graph_handler.h"
|
||||||
|
#include "core/mutator.h"
|
||||||
|
#include "core/search_engine.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
namespace callback {
|
||||||
|
void exportONNX(const Graph &graph, const string &path);
|
||||||
|
}
|
||||||
|
} // namespace infini
|
|
@ -1,8 +1,7 @@
|
||||||
import sys
|
import backend
|
||||||
|
from backend import *
|
||||||
|
import sys
|
||||||
|
|
||||||
sys.path.extend(__path__)
|
sys.path.extend(__path__)
|
||||||
|
|
||||||
import backend
|
|
||||||
print("import backend: {}".format(backend))
|
print("import backend: {}".format(backend))
|
||||||
from backend import *
|
|
||||||
|
|
||||||
|
|
|
@ -541,7 +541,7 @@ class OnnxStub:
|
||||||
# for output in model.graph.output:
|
# for output in model.graph.output:
|
||||||
# self.outputs[output.name] = tensors[output.name]
|
# self.outputs[output.name] = tensors[output.name]
|
||||||
|
|
||||||
def to_onnx(self, g: backend.Graph, name: str) -> ModelProto:
|
def to_onnx(self, g: backend.Graph, path: str, name: str = 'my_onnx') -> ModelProto:
|
||||||
class Context:
|
class Context:
|
||||||
# saves object names, including tensors and operators
|
# saves object names, including tensors and operators
|
||||||
names: Dict[Union[backend.Tensor, backend.Operator], str] = dict()
|
names: Dict[Union[backend.Tensor, backend.Operator], str] = dict()
|
||||||
|
@ -860,7 +860,7 @@ class OnnxStub:
|
||||||
raise Exception("Unsupported OpType", ty)
|
raise Exception("Unsupported OpType", ty)
|
||||||
|
|
||||||
model = ctx.build(name)
|
model = ctx.build(name)
|
||||||
onnx.save(model, '/home/zly/InfiniTensor_merge/build/a.onnx')
|
onnx.save(model, path)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
# def init(self) -> None:
|
# def init(self) -> None:
|
||||||
|
@ -935,3 +935,7 @@ class OnnxStub:
|
||||||
|
|
||||||
# def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
|
# def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
|
||||||
# return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
|
# return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
|
||||||
|
|
||||||
|
def export_onnx(g: backend.Graph, path: str) -> None:
|
||||||
|
stub = OnnxStub()
|
||||||
|
stub.to_onnx(g, path)
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include <pybind11/stl.h>
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
namespace callback {
|
||||||
|
|
||||||
|
using namespace py::literals;
|
||||||
|
|
||||||
|
static std::function<void(const Graph &, string)> exportONNXImpl;
|
||||||
|
void exportONNX(const Graph &graph, const string &path) {
|
||||||
|
IT_ASSERT(Py_IsInitialized(), "Python interpreter is not running.");
|
||||||
|
static auto exportONNXImpl =
|
||||||
|
py::module_::import("infinitensor.if_onnx").attr("export_onnx");
|
||||||
|
exportONNXImpl(graph, path);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace callback
|
||||||
|
|
||||||
|
} // namespace infini
|
Loading…
Reference in New Issue