From 4c7fdf44c59b90a2d282657f831435a09c823479 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Wed, 22 Feb 2023 15:05:44 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=89=8D=E7=AB=AF=E6=94=AF=E6=8C=81=20?= =?UTF-8?q?Conv=20=E5=8F=8A=E5=8D=95=E5=85=83=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- include/core/graph_handler.h | 3 +++ pyinfinitensor/src/pyinfinitensor/onnx.py | 23 ++++++++++++++++++++++- pyinfinitensor/tests/test_onnx.py | 15 +++++++++++++++ src/core/graph_handler.cc | 15 +++++++++++++++ src/ffi/ffi_infinitensor.cc | 1 + 5 files changed, 56 insertions(+), 1 deletion(-) diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index 121028ca..22897c4c 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -38,6 +38,9 @@ class GraphHandlerObj { Tensor tensor(Shape dims, int dtype); + Tensor conv(Tensor input, Tensor weight, Tensor output, int ph, int pw, + int sh, int sw, int dh, int dw); + Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB, Tensor bias, ActType act); diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 17b77519..81a2db64 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -24,7 +24,28 @@ def from_onnx(model: onnx.ModelProto): data[initializer.name] = initializer for node in model.graph.node: - if node.op_type == "MatMul": + if node.op_type == "Conv": + attributes = _parse_attribute( + node, + { + "dilations": [1, 1], + "pads": [0, 0], + "strides": [1, 1], + }, + ) + (d, p, s) = (attributes[name] for name in ["dilations", "pads", "strides"]) + tensors[node.output[0]] = handler.conv( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + p[0], + p[1], + s[0], + s[1], + d[0], + d[1], + ) + elif node.op_type == "MatMul": tensors[node.output[0]] = handler.matmul( tensors[node.input[0]], tensors[node.input[1]], diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index 9dfaa6e4..46328f76 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -34,6 +34,21 @@ class TestStringMethods(unittest.TestCase): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3]) make_and_import_model(make_graph([], "tensor", [x], [x])) + def test_conv(self): + i = make_tensor_value_info("i", TensorProto.FLOAT, [1, 3, 4, 4]) + w = make_tensor_value_info("w", TensorProto.FLOAT, [2, 3, 3, 3]) + o = make_tensor_value_info("o", TensorProto.FLOAT, [1, 2, 2, 2]) + conv = make_node( + "Conv", + ["i", "w"], + ["o"], + "conv", + pads=[1, 1], + strides=[2, 1], + dilations=[1, 2], + ) + make_and_import_model(make_graph([conv], "conv", [i, w], [o])) + def test_matmul(self): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3]) a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4]) diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index 3b408ff4..ac8bfb21 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -1,6 +1,7 @@ #include "core/graph_handler.h" #include "operators/batch_norm.h" #include "operators/concat.h" +#include "operators/conv.h" #include "operators/element_wise.h" #include "operators/gather.h" #include "operators/matmul.h" @@ -19,6 +20,20 @@ Tensor GraphHandlerObj::tensor(Shape dims, int dtype) { return g->addTensor(std::move(dims), dtype_repr_convert(dtype)); } +Tensor GraphHandlerObj::conv(Tensor input, Tensor weight, Tensor output, int ph, + int pw, int sh, int sw, int dh, int dw) { + if (output) { + g->addOpWithOutputs(std::move(input), std::move(weight), + output, ph, pw, sh, sw, dh, dw); + return output; + } else { + return g + ->addOp(std::move(input), std::move(weight), output, ph, + pw, sh, sw, dh, dw) + ->getOutput(); + } +} + Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB, Tensor bias, ActType act) { if (y) { diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 75a5e724..b48a9e1c 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -39,6 +39,7 @@ void init_graph_builder(py::module &m) { .def(py::init()) .def("tensor", py::overload_cast(&Handler::tensor), policy::move) + .def("conv", &Handler::conv, policy::move) .def("matmul", py::overload_cast(&Handler::matmul),