From 0833a2f779bf1baefe91458a16fd499aca4c2382 Mon Sep 17 00:00:00 2001
From: YdrMaster
Date: Fri, 17 Feb 2023 17:15:15 +0800
Subject: [PATCH] feat: export add, sub, mul, div and pow to onnx
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: YdrMaster
---
 pyinfinitensor/src/pyinfinitensor/onnx.py | 145 +++++++++++++++++++---
 pyinfinitensor/tests/test_onnx.py         |  18 ++-
 src/ffi/ffi_infinitensor.cc               |  95 +++++---------
 3 files changed, 169 insertions(+), 89 deletions(-)

diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index 68d0f110..c98fa7cf 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -1,4 +1,6 @@
-import onnx, backend
+import backend
+from onnx import ModelProto, TensorProto, NodeProto, AttributeProto, TensorShapeProto
+from onnx.helper import make_node
 from onnx.shape_inference import infer_shapes
 from typing import Dict, List, Any
 from functools import reduce
@@ -6,12 +8,12 @@ from functools import reduce
 runtime = backend.cpu_runtime()


-def from_onnx(model: onnx.ModelProto) -> backend.GraphHandler:
+def from_onnx(model: ModelProto) -> backend.GraphHandler:
     model = infer_shapes(model)
     handler = backend.GraphHandler(runtime)

     tensors: Dict[str, backend.Tensor] = dict()
-    data: Dict[str, onnx.TensorProto] = dict()
+    data: Dict[str, TensorProto] = dict()

     for input in model.graph.input:
         dims = _take_shape_dim(input.type.tensor_type.shape)
@@ -310,16 +312,123 @@ def to_onnx(graph: backend.GraphHandler):
     ops = graph.operators()

     names: Dict[Any, str] = dict()
-    count: Dict[backend.OpType, int] = dict()
+    nodes: List[NodeProto] = []
+    count_op: Dict[backend.OpType, int] = dict()
+    count_in = 0

     for op in ops:
         ty = op.op_type()
-        names[op] = "{}{}".format(ty.name, count.setdefault(ty, 0) + 1)
-        count[ty] += 1
+        name = "{}{}".format(ty.name, count_op.setdefault(ty, 0) + 1)
+        inputs = op.inputs()
+        outputs = op.outputs()
+        names[op] = name
+        count_op[ty] += 1
+        if ty == backend.OpType.Matmul:
+            raise Exception("TODO")
+        elif ty == backend.OpType.BatchNorm:
+            raise Exception("TODO")
+        elif ty == backend.OpType.MaxPool:
+            raise Exception("TODO")
+        elif ty == backend.OpType.AvgPool:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Add:
+            names[outputs[0]] = name
+            if inputs[0] in names:
+                a = names[inputs[0]]
+            else:
+                count_in += 1
+                a = "input{}".format(count_in)
+            if inputs[1] in names:
+                b = names[inputs[1]]
+            else:
+                count_in += 1
+                b = "input{}".format(count_in)
+            nodes.append(make_node("Add", [a, b], [name], name))
+        elif ty == backend.OpType.Sub:
+            names[outputs[0]] = name
+            if inputs[0] in names:
+                a = names[inputs[0]]
+            else:
+                count_in += 1
+                a = "input{}".format(count_in)
+            if inputs[1] in names:
+                b = names[inputs[1]]
+            else:
+                count_in += 1
+                b = "input{}".format(count_in)
+            nodes.append(make_node("Sub", [a, b], [name], name))
+        elif ty == backend.OpType.Mul:
+            names[outputs[0]] = name
+            if inputs[0] in names:
+                a = names[inputs[0]]
+            else:
+                count_in += 1
+                a = "input{}".format(count_in)
+            if inputs[1] in names:
+                b = names[inputs[1]]
+            else:
+                count_in += 1
+                b = "input{}".format(count_in)
+            nodes.append(make_node("Mul", [a, b], [name], name))
+        elif ty == backend.OpType.Div:
+            names[outputs[0]] = name
+            if inputs[0] in names:
+                a = names[inputs[0]]
+            else:
+                count_in += 1
+                a = "input{}".format(count_in)
+            if inputs[1] in names:
+                b = names[inputs[1]]
+            else:
+                count_in += 1
+                b = "input{}".format(count_in)
+            nodes.append(make_node("Div", [a, b], [name], name))
+        elif ty == backend.OpType.Pow:
+            names[outputs[0]] = name
+            if inputs[0] in names:
+                a = names[inputs[0]]
+            else:
+                count_in += 1
+                a = "input{}".format(count_in)
+            if inputs[1] in names:
+                b = names[inputs[1]]
+            else:
+                count_in += 1
+                b = "input{}".format(count_in)
+            nodes.append(make_node("Pow", [a, b], [name], name))
+        elif ty == backend.OpType.Relu:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Sigmoid:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Tanh:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Softmax:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Abs:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Identity:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Flatten:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Reshape:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Concat:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Gather:
+            raise Exception("TODO")
+        elif ty == backend.OpType.ReduceMean:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Slice:
+            raise Exception("TODO")
+        elif ty == backend.OpType.Pad:
+            raise Exception("TODO")
+        else:
+            raise Exception("Unsupported OpType {}".format(ty.name))
+    print(names)


-def parse_onnx(model: onnx.ModelProto):
+def parse_onnx(model: ModelProto):
     print()

     for field in [
@@ -355,34 +464,32 @@ def parse_onnx(model: ModelProto):
         print("  {}".format(node.name))


-def _parse_attribute(
-    node: onnx.NodeProto, attrs: Dict[str, Any] = dict()
-) -> Dict[str, Any]:
+def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
     for attr in node.attribute:
         if attr.name in attrs:
-            if attr.type == onnx.AttributeProto.INT:
+            if attr.type == AttributeProto.INT:
                 attrs[attr.name] = attr.i
-            elif attr.type == onnx.AttributeProto.INTS:
+            elif attr.type == AttributeProto.INTS:
                 attrs[attr.name] = attr.ints
-            elif attr.type == onnx.AttributeProto.FLOAT:
+            elif attr.type == AttributeProto.FLOAT:
                 attrs[attr.name] = attr.f
-            elif attr.type == onnx.AttributeProto.STRING:
+            elif attr.type == AttributeProto.STRING:
                 attrs[attr.name] = attr.s
-            elif attr.type == onnx.AttributeProto.TENSOR:
+            elif attr.type == AttributeProto.TENSOR:
                 attrs[attr.name] = attr.t
             else:
                 assert False, "Unsupported Attribute Type: {}".format(attr.type)
     return attrs


-def _parse_data(tensor: onnx.TensorProto) -> List[int]:
-    if tensor.data_type == onnx.TensorProto.INT32:
+def _parse_data(tensor: TensorProto) -> List[int]:
+    if tensor.data_type == TensorProto.INT32:
         return [int(i) for i in tensor.int32_data]
-    elif tensor.data_type == onnx.TensorProto.INT64:
+    elif tensor.data_type == TensorProto.INT64:
         return [int(i) for i in tensor.int64_data]
     else:
         assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)


-def _take_shape_dim(shape: onnx.TensorShapeProto) -> List[int]:
+def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
     return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py
index 9547bcac..a28dd5b9 100644
--- a/pyinfinitensor/tests/test_onnx.py
+++ b/pyinfinitensor/tests/test_onnx.py
@@ -294,10 +294,20 @@ class TestStringMethods(unittest.TestCase):

     def test_frontend(self):
         handler = backend.GraphHandler(runtime)
-        i = handler.tensor([1, 2, 3], 12)
-        w = handler.tensor([1, 3, 4], 12)
-        o = handler.tensor([1, 2, 4], 12)
-        handler.matmul(i, w, o, False, False, None, backend.ActType.Relu)
+        a = handler.tensor([1, 2, 3], 12)
+        b = handler.tensor([1, 2, 3], 12)
+        ab = handler.tensor([1, 2, 3], 12)
+        c = handler.tensor([1, 2, 3], 12)
+        abc = handler.tensor([1, 2, 3], 12)
+        d = handler.tensor([1, 2, 3], 12)
+        abcd = handler.tensor([1, 2, 3], 12)
+        e = handler.tensor([1, 2, 3], 12)
+        abcde = handler.tensor([1, 2, 3], 12)
+
+        handler.add(a, b, ab)
+        handler.add(ab, c, abc)
+        handler.add(abc, d, abcd)
+        handler.add(abcd, e, abcde)

         to_onnx(handler)

diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index 5a86b56a..ceb641d7 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -80,75 +80,38 @@ void init_graph_builder(py::module &m) {
     py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj")
         .def("src", &TensorObj::getOutputOf, policy::move);
     py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
-        .def("op_type", &OperatorObj::getOpType, policy::move);
+        .def("op_type", &OperatorObj::getOpType, policy::automatic)
+        .def("inputs", py::overload_cast<>(&OperatorObj::getInputs, py::const_),
+             policy::reference)
+        .def("outputs",
+             py::overload_cast<>(&OperatorObj::getOutputs, py::const_),
+             policy::reference);
     py::class_<Handler>(m, "GraphHandler")
         .def(py::init<Runtime>())
-        .def("tensor", py::overload_cast(&Handler::tensor),
-             policy::move)
+        .def("tensor", &Handler::tensor, policy::move)
         .def("conv", &Handler::conv, policy::move)
-        .def("matmul",
-             py::overload_cast(&Handler::matmul),
-             policy::move)
-        .def("batchNorm",
-             py::overload_cast(&Handler::batchNorm),
-             policy::move)
-        .def("maxPool",
-             py::overload_cast(&Handler::maxPool),
-             policy::move)
-        .def("avgPool",
-             py::overload_cast(&Handler::avgPool),
-             policy::move)
-        .def("add", py::overload_cast(&Handler::add),
-             policy::move)
-        .def("sub", py::overload_cast(&Handler::sub),
-             policy::move)
-        .def("mul", py::overload_cast(&Handler::mul),
-             policy::move)
-        .def("div", py::overload_cast(&Handler::div),
-             policy::move)
-        .def("pow", py::overload_cast(&Handler::pow),
-             policy::move)
-        .def("relu", py::overload_cast(&Handler::relu),
-             policy::move)
-        .def("sigmoid", py::overload_cast(&Handler::sigmoid),
-             policy::move)
-        .def("tanh", py::overload_cast(&Handler::tanh),
-             policy::move)
-        .def("softmax", py::overload_cast(&Handler::softmax),
-             policy::move)
-        .def("abs", py::overload_cast(&Handler::abs),
-             policy::move)
-        .def("identity", py::overload_cast(&Handler::identity),
-             policy::move)
-        .def("flatten", py::overload_cast(&Handler::flatten),
-             policy::move)
-        .def("reshape",
-             py::overload_cast(&Handler::reshape),
-             policy::move)
-        .def("concat",
-             py::overload_cast(&Handler::concat),
-             policy::move)
-        .def("gather",
-             py::overload_cast(&Handler::gather),
-             policy::move)
-        .def("reduceMean",
-             py::overload_cast> &,
-                               bool>(&Handler::reduceMean),
-             policy::move)
-        .def("slice",
-             py::overload_cast<
-                 Tensor, Tensor, const vector &, const vector &,
-                 const optional> &, const optional> &>(
-                 &Handler::slice),
-             policy::move)
-        .def("pad",
-             py::overload_cast &,
-                               const optional> &>(&Handler::pad),
-             policy::move)
+        .def("matmul", &Handler::matmul, policy::move)
+        .def("batchNorm", &Handler::batchNorm, policy::move)
+        .def("maxPool", &Handler::maxPool, policy::move)
+        .def("avgPool", &Handler::avgPool, policy::move)
+        .def("add", &Handler::add, policy::move)
+        .def("sub", &Handler::sub, policy::move)
+        .def("mul", &Handler::mul, policy::move)
+        .def("div", &Handler::div, policy::move)
+        .def("pow", &Handler::pow, policy::move)
+        .def("relu", &Handler::relu, policy::move)
+        .def("sigmoid", &Handler::sigmoid, policy::move)
+        .def("tanh", &Handler::tanh, policy::move)
+        .def("softmax", &Handler::softmax, policy::move)
+        .def("abs", &Handler::abs, policy::move)
+        .def("identity", &Handler::identity, policy::move)
+        .def("flatten", &Handler::flatten, policy::move)
+        .def("reshape", &Handler::reshape, policy::move)
+        .def("concat", &Handler::concat, policy::move)
+        .def("gather", &Handler::gather, policy::move)
+        .def("reduceMean", &Handler::reduceMean, policy::move)
+        .def("slice", &Handler::slice, policy::move)
+        .def("pad", &Handler::pad, policy::move)
         .def("topo_sort", &Handler::topo_sort, policy::automatic)
         .def("operators", &Handler::operators, policy::move)
         .def("data_malloc", &Handler::data_malloc, policy::automatic)
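
Note (editor's sketch, not part of the patch): the Add/Sub/Mul/Div/Pow branches added to to_onnx above repeat the same input-name lookup five times. A possible follow-up, using only the names/count_in state and the make_node call already introduced in this diff, would hoist that lookup into a small helper. The helper name _tensor_name is hypothetical, and the final comment assumes ty.name matches the ONNX operator type, which holds for these five ops.

    def _tensor_name(tensor, names, count_in):
        # Reuse the name recorded for a tensor that an earlier node produced;
        # otherwise allocate a fresh graph-input name ("input1", "input2", ...).
        if tensor in names:
            return names[tensor], count_in
        count_in += 1
        return "input{}".format(count_in), count_in

    # Each binary-operator branch in the loop then reduces to:
    #     a, count_in = _tensor_name(inputs[0], names, count_in)
    #     b, count_in = _tensor_name(inputs[1], names, count_in)
    #     nodes.append(make_node(ty.name, [a, b], [name], name))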