forked from jiuyuan/InfiniTensor
Add: C++ callback to export ONNX
parent 872f3504a9
commit edf4e33353
@@ -0,0 +1,9 @@
+#include "core/graph_handler.h"
+#include "core/mutator.h"
+#include "core/search_engine.h"
+
+namespace infini {
+namespace callback {
+void exportONNX(const Graph &graph, const string &path);
+}
+} // namespace infini
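Note: the header above only declares the hook. Below is a sketch of how a C++ pass inside InfiniTensor might call it; the include path "core/callback.h" and the helper name dumpCandidate are illustrative assumptions, not part of this commit.

// Hypothetical call site; the include path and helper name are assumptions.
#include "core/callback.h"
#include "core/graph.h"
#include <string>

namespace infini {
// Dump an intermediate graph (e.g. a search/mutation candidate) to an ONNX file.
void dumpCandidate(const Graph &g, const std::string &path) {
    // exportONNX asserts that a Python interpreter is already initialized,
    // so this only works when InfiniTensor is driven from Python.
    callback::exportONNX(g, path);
}
} // namespace infini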
@@ -1,8 +1,7 @@
-import sys
-import backend
-from backend import *
+import sys
 
 sys.path.extend(__path__)
 
+import backend
+print("import backend: {}".format(backend))
+from backend import *
-
-
@@ -541,7 +541,7 @@ class OnnxStub:
         # for output in model.graph.output:
         #     self.outputs[output.name] = tensors[output.name]
 
-    def to_onnx(self, g: backend.Graph, name: str) -> ModelProto:
+    def to_onnx(self, g: backend.Graph, path: str, name: str = 'my_onnx') -> ModelProto:
         class Context:
             # saves object names, including tensors and operators
             names: Dict[Union[backend.Tensor, backend.Operator], str] = dict()
@@ -860,7 +860,7 @@ class OnnxStub:
                 raise Exception("Unsupported OpType", ty)
 
         model = ctx.build(name)
-        onnx.save(model, '/home/zly/InfiniTensor_merge/build/a.onnx')
+        onnx.save(model, path)
         return model
 
     # def init(self) -> None:
@@ -935,3 +935,7 @@ class OnnxStub:
 
     # def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
     #     return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
+
+def export_onnx(g: backend.Graph, path: str) -> None:
+    stub = OnnxStub()
+    stub.to_onnx(g, path)
@@ -0,0 +1,22 @@
+#include "core/graph.h"
+#include <pybind11/stl.h>
+
+namespace py = pybind11;
+
+namespace infini {
+
+namespace callback {
+
+using namespace py::literals;
+
+static std::function<void(const Graph &, string)> exportONNXImpl;
+void exportONNX(const Graph &graph, const string &path) {
+    IT_ASSERT(Py_IsInitialized(), "Python interpreter is not running.");
+    static auto exportONNXImpl =
+        py::module_::import("infinitensor.if_onnx").attr("export_onnx");
+    exportONNXImpl(graph, path);
+}
+
+} // namespace callback
+
+} // namespace infini
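Note: the implementation above resolves the Python-side export_onnx wrapper (added in the earlier hunk) lazily on the first call and caches it in a function-local static, so the module import happens only once and the header stays free of pybind11 types. Below is a minimal standalone sketch of the same import-and-call pattern, assuming an embedded interpreter; the module "math" and attribute "sqrt" are placeholders, not InfiniTensor code.

// Lazy import-and-call of a Python function from C++ via pybind11.
#include <pybind11/embed.h>

namespace py = pybind11;

int main() {
    // Start an interpreter so Py_IsInitialized() holds; InfiniTensor itself
    // expects to be driven from Python and only asserts this precondition.
    py::scoped_interpreter guard{};
    // Resolved once on first execution, then cached, like exportONNXImpl above.
    static auto impl = py::module_::import("math").attr("sqrt");
    double r = impl(2.0).cast<double>();
    py::print("sqrt(2) =", r);
    return 0;
}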