From edf4e3335361b9cde2f4e359df96aea408e459ec Mon Sep 17 00:00:00 2001
From: Liyan Zheng
Date: Mon, 17 Apr 2023 11:48:26 +0800
Subject: [PATCH] Add: C++ callback to export ONNX

---
 include/ffi/ffi_callback.h      |  9 +++++++++
 python/infinitensor/__init__.py |  7 +++----
 python/infinitensor/if_onnx.py  |  8 ++++++--
 src/ffi/ffi_callback.cc         | 22 ++++++++++++++++++++++
 4 files changed, 40 insertions(+), 6 deletions(-)
 create mode 100644 include/ffi/ffi_callback.h
 create mode 100644 src/ffi/ffi_callback.cc

diff --git a/include/ffi/ffi_callback.h b/include/ffi/ffi_callback.h
new file mode 100644
index 00000000..409e3077
--- /dev/null
+++ b/include/ffi/ffi_callback.h
@@ -0,0 +1,9 @@
+#include "core/graph_handler.h"
+#include "core/mutator.h"
+#include "core/search_engine.h"
+
+namespace infini {
+namespace callback {
+void exportONNX(const Graph &graph, const string &path);
+}
+} // namespace infini
\ No newline at end of file
diff --git a/python/infinitensor/__init__.py b/python/infinitensor/__init__.py
index 50fff617..1afc546f 100644
--- a/python/infinitensor/__init__.py
+++ b/python/infinitensor/__init__.py
@@ -1,8 +1,7 @@
-import sys
+import backend
+from backend import *
+import sys
 sys.path.extend(__path__)
 
-import backend
 print("import backend: {}".format(backend))
 
-from backend import *
-
diff --git a/python/infinitensor/if_onnx.py b/python/infinitensor/if_onnx.py
index d05c19fc..1f12db0b 100644
--- a/python/infinitensor/if_onnx.py
+++ b/python/infinitensor/if_onnx.py
@@ -541,7 +541,7 @@ class OnnxStub:
         # for output in model.graph.output:
         #     self.outputs[output.name] = tensors[output.name]
 
-    def to_onnx(self, g: backend.Graph, name: str) -> ModelProto:
+    def to_onnx(self, g: backend.Graph, path: str, name: str = 'my_onnx') -> ModelProto:
         class Context:
             # saves object names, including tensors and operators
             names: Dict[Union[backend.Tensor, backend.Operator], str] = dict()
@@ -860,7 +860,7 @@ class OnnxStub:
                     raise Exception("Unsupported OpType", ty)
 
         model = ctx.build(name)
-        onnx.save(model, '/home/zly/InfiniTensor_merge/build/a.onnx')
+        onnx.save(model, path)
         return model
 
     # def init(self) -> None:
@@ -935,3 +935,7 @@
 
 # def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
 #     return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
+
+def export_onnx(g: backend.Graph, path: str) -> None:
+    stub = OnnxStub()
+    stub.to_onnx(g, path)
diff --git a/src/ffi/ffi_callback.cc b/src/ffi/ffi_callback.cc
new file mode 100644
index 00000000..7b804153
--- /dev/null
+++ b/src/ffi/ffi_callback.cc
@@ -0,0 +1,22 @@
+#include "core/graph.h"
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+
+namespace infini {
+
+namespace callback {
+
+using namespace py::literals;
+
+static std::function<void(const Graph &, const string &)> exportONNXImpl;
+void exportONNX(const Graph &graph, const string &path) {
+    IT_ASSERT(Py_IsInitialized(), "Python interpreter is not running.");
+    static auto exportONNXImpl =
+        py::module_::import("infinitensor.if_onnx").attr("export_onnx");
+    exportONNXImpl(graph, path);
+}
+
+} // namespace callback
+
+} // namespace infini
\ No newline at end of file
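
Note (not part of the patch): a minimal sketch of how the new callback might be invoked from the C++ side. The helper name dumpGraph, the output path, and the use of pybind11::scoped_interpreter are assumptions for illustration; callback::exportONNX itself only requires that a Python interpreter is already initialized (it asserts Py_IsInitialized) and that infinitensor.if_onnx is importable.

    // Hypothetical caller of the new callback (illustration only).
    #include "ffi/ffi_callback.h"
    #include <pybind11/embed.h>

    // Assumes the caller already holds a constructed infini::Graph.
    void dumpGraph(const infini::Graph &graph) {
        // exportONNX asserts Py_IsInitialized(); when running outside the
        // Python bindings, start an embedded interpreter first.
        pybind11::scoped_interpreter guard{};
        // Imports infinitensor.if_onnx.export_onnx and writes the model to the path.
        infini::callback::exportONNX(graph, "exported.onnx");
    }

When the library is driven from Python (via the pyinfinitensor bindings), the interpreter is already running, so the scoped_interpreter guard above would be unnecessary and the callback can be called directly.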