diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index b74f8e13..dc221042 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -38,12 +38,14 @@ class GraphHandlerObj {
     Tensor tensor(Shape dims, int dtype);
 
+    //------ operators
+
+    inline OpVec operators() { return g->getOperators(); }
+
     Tensor conv(Tensor input, Tensor weight, Tensor output, int ph, int pw,
                 int sh, int sw, int dh, int dw);
-
     Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
                   Tensor bias, ActType act);
-
     Tensor batchNorm(Tensor input, Tensor output, Tensor mean, Tensor var,
                      Tensor scale, Tensor bias, float momentum, float eps,
                      bool training);
 
@@ -77,6 +79,10 @@ class GraphHandlerObj {
     Tensor pad(Tensor input, Tensor output, const vector<int> &pads,
                const optional<vector<int>> &axes);
 
+    //------ modifiers
+
+    inline bool topo_sort() { return g->topo_sort(); }
+
     //------ runtime
 
     inline void data_malloc() { g->dataMalloc(); }
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index a429368f..832e398c 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -6,11 +6,11 @@ from functools import reduce
 
 runtime = backend.cpu_runtime()
 
 
-def from_onnx(model: onnx.ModelProto):
+def from_onnx(model: onnx.ModelProto) -> backend.GraphHandler:
     model = infer_shapes(model)
-    handler = backend.GraphHandlerObj(runtime)
+    handler = backend.GraphHandler(runtime)
 
-    tensors: Dict[str, backend.TensorObj] = dict()
+    tensors: Dict[str, backend.Tensor] = dict()
     data: Dict[str, onnx.TensorProto] = dict()
 
     for input in model.graph.input:
@@ -303,6 +303,13 @@ def from_onnx(model: onnx.ModelProto):
             raise Exception('Unsupported operator "{}"'.format(node.op_type))
 
 
+def to_onnx(graph: backend.GraphHandler):
+    if not graph.topo_sort():
+        raise Exception("Sorting fails")
+
+    ops = graph.operators()
+
+
 def parse_onnx(model: onnx.ModelProto):
     print()
 
diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py
index 46328f76..9547bcac 100644
--- a/pyinfinitensor/tests/test_onnx.py
+++ b/pyinfinitensor/tests/test_onnx.py
@@ -8,7 +8,7 @@ from onnx.helper import (
     make_tensor_value_info,
 )
 from onnx.checker import check_model
-from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime
+from pyinfinitensor.onnx import from_onnx, parse_onnx, backend, runtime, to_onnx
 
 
 def make_and_import_model(graph: onnx.GraphProto):
@@ -293,12 +293,14 @@ class TestStringMethods(unittest.TestCase):
         parse_onnx(model)
 
     def test_frontend(self):
-        handler = backend.GraphHandlerObj(runtime)
+        handler = backend.GraphHandler(runtime)
         i = handler.tensor([1, 2, 3], 12)
         w = handler.tensor([1, 3, 4], 12)
         o = handler.tensor([1, 2, 4], 12)
         handler.matmul(i, w, o, False, False, None, backend.ActType.Relu)
 
+        to_onnx(handler)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index 37b7d5da..b0de6d08 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -25,17 +25,18 @@ void init_graph_builder(py::module &m) {
     using Handler = GraphHandlerObj;
 
     m.def("cpu_runtime", &CpuRuntimeObj::getInstance);
-    py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "RuntimeObj");
+    py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
     py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
-        m, "CpuRuntimeObj");
+        m, "CpuRuntime");
     py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj");
+    py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator");
     py::enum_<ActType>(m, "ActType")
         .value("Linear", ActType::None) // `None` is Python keyword
         .value("Relu", ActType::Relu)
         .value("Sigmoid", ActType::Sigmoid)
         .value("Tanh", ActType::Tanh)
         .export_values();
-    py::class_<Handler>(m, "GraphHandlerObj")
+    py::class_<Handler>(m, "GraphHandler")
         .def(py::init<Runtime>())
         .def("tensor", py::overload_cast<Shape, int>(&Handler::tensor),
              policy::move)
@@ -103,6 +104,10 @@ void init_graph_builder(py::module &m) {
              py::overload_cast<Tensor, Tensor, const vector<int> &,
                                const optional<vector<int>> &>(&Handler::pad),
              policy::move)
+        .def("topo_sort", py::overload_cast<>(&Handler::topo_sort),
+             policy::automatic)
+        .def("operators", py::overload_cast<>(&Handler::operators),
+             policy::move)
         .def("data_malloc", &Handler::data_malloc, policy::automatic)
        .def("run", &Handler::run, policy::automatic);
 }