Add: C++ callback to export ONNX

This commit is contained in:
Liyan Zheng 2023-04-17 11:48:26 +08:00
parent 872f3504a9
commit edf4e33353
4 changed files with 40 additions and 6 deletions

View File

@ -0,0 +1,9 @@
#include "core/graph_handler.h"
#include "core/mutator.h"
#include "core/search_engine.h"
namespace infini {
namespace callback {
void exportONNX(const Graph &graph, const string &path);
}
} // namespace infini

View File

@ -1,8 +1,7 @@
import sys
import backend
from backend import *
import sys
sys.path.extend(__path__)
import backend
print("import backend: {}".format(backend))
from backend import *

View File

@ -541,7 +541,7 @@ class OnnxStub:
# for output in model.graph.output:
# self.outputs[output.name] = tensors[output.name]
def to_onnx(self, g: backend.Graph, name: str) -> ModelProto:
def to_onnx(self, g: backend.Graph, path: str, name: str = 'my_onnx') -> ModelProto:
class Context:
# saves object names, including tensors and operators
names: Dict[Union[backend.Tensor, backend.Operator], str] = dict()
@ -860,7 +860,7 @@ class OnnxStub:
raise Exception("Unsupported OpType", ty)
model = ctx.build(name)
onnx.save(model, '/home/zly/InfiniTensor_merge/build/a.onnx')
onnx.save(model, path)
return model
# def init(self) -> None:
@ -935,3 +935,7 @@ class OnnxStub:
# def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
# return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
def export_onnx(g: backend.Graph, path: str) -> None:
stub = OnnxStub()
stub.to_onnx(g, path)

22
src/ffi/ffi_callback.cc Normal file
View File

@ -0,0 +1,22 @@
#include "core/graph.h"
#include <pybind11/stl.h>
namespace py = pybind11;
namespace infini {
namespace callback {
using namespace py::literals;
static std::function<void(const Graph &, string)> exportONNXImpl;
void exportONNX(const Graph &graph, const string &path) {
IT_ASSERT(Py_IsInitialized(), "Python interpreter is not running.");
static auto exportONNXImpl =
py::module_::import("infinitensor.if_onnx").attr("export_onnx");
exportONNXImpl(graph, path);
}
} // namespace callback
} // namespace infini