From 62ceb78ae3f2287fb8f776850b47ac4781cd2aa8 Mon Sep 17 00:00:00 2001 From: YdrMaster <ydrml@hotmail.com> Date: Tue, 14 Feb 2023 15:35:01 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=89=8D=E7=AB=AF=E6=94=AF=E6=8C=81=20?= =?UTF-8?q?reduceMean=20=E5=8F=8A=E5=8D=95=E5=85=83=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster <ydrml@hotmail.com> --- include/core/graph_handler.h | 4 +++- include/operators/reduce_mean.h | 3 +-- pyinfinitensor/src/pyinfinitensor/onnx.py | 8 ++++++++ pyinfinitensor/tests/test_onnx.py | 8 ++++++++ src/core/graph_handler.cc | 21 +++++++++++++++++---- src/ffi/ffi_infinitensor.cc | 4 ++++ src/operators/reduce_mean.cc | 3 +-- 7 files changed, 42 insertions(+), 9 deletions(-) diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index b1331757..7cc9351d 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -59,7 +59,9 @@ class GraphHandlerObj { Tensor flatten(Tensor s, Tensor y); Tensor reshape(Tensor data, Tensor reshaped, Shape shape); Tensor concat(TensorVec inputs, Tensor output, int dim); - Tensor gather(Tensor input, Tensor indices, Tensor output, int axis); + Tensor gather(Tensor data, Tensor indices, Tensor output, int axis); + Tensor reduceMean(Tensor data, Tensor reduced, + const optional<vector<int>> &axes, bool keepdims); }; } // namespace infini diff --git a/include/operators/reduce_mean.h b/include/operators/reduce_mean.h index 23ea1432..76d3454b 100644 --- a/include/operators/reduce_mean.h +++ b/include/operators/reduce_mean.h @@ -21,8 +21,7 @@ class ReduceMeanObj : public OperatorObj { * @param keepDims Keep the reduced dimensions or not. 
*/ ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output, - const optional<vector<int>> &axes, - bool keepDims = true); + const optional<vector<int>> &axes, bool keepDims = true); OP_CLONE(ReduceMeanObj); optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 0b9f43b9..6f72dc94 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -135,6 +135,14 @@ def from_onnx(model: onnx.ModelProto): tensors.get(node.output[0]), next((attr.i for attr in node.attribute if attr.name == "axis")), ) + elif node.op_type == "ReduceMean": + tensors[node.output[0]] = handler.reduceMean( + tensors[node.input[0]], + tensors.get(node.output[0]), + tensors[node.input[1]] if len(node.input) > 1 else None, + next((attr.i for attr in node.attribute if attr.name == "keepdims")) + != 0, + ) else: raise Exception('Unsupported operator "{}"'.format(node.op_type)) diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index 3e9810da..804540d7 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -171,6 +171,14 @@ class TestStringMethods(unittest.TestCase): ) make_and_import_model(make_graph([gather], "gather", [data, indices], [output])) + def test_reduce_mean(self): + data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 3, 4]) + reduced = make_tensor_value_info("reduced", TensorProto.FLOAT, [1, 1, 1, 1]) + reduceMean = make_node( + "ReduceMean", ["data"], ["reduced"], keepdims=1, name="reduceMean" + ) + make_and_import_model(make_graph([reduceMean], "reduceMean", [data], [reduced])) + # see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression> def test_linear(self): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3]) diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index 5461e923..3840ca16 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -4,9 +4,9 @@ #include 
"operators/element_wise.h" #include "operators/gather.h" #include "operators/matmul.h" +#include "operators/reduce_mean.h" #include "operators/reshape.h" #include "operators/unary.h" - namespace infini { static DataType dtype_repr_convert(int); @@ -104,20 +104,33 @@ Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) { } } -Tensor GraphHandlerObj::gather(Tensor input, Tensor indices, Tensor output, +Tensor GraphHandlerObj::gather(Tensor data, Tensor indices, Tensor output, int axis) { if (output) { - g->addOpWithOutputs(std::move(input), std::move(indices), + g->addOpWithOutputs(std::move(data), std::move(indices), output, axis); return output; } else { return g - ->addOp(std::move(input), std::move(indices), output, + ->addOp(std::move(data), std::move(indices), output, axis) ->getOutput(); } } +Tensor GraphHandlerObj::reduceMean(Tensor data, Tensor reduced, + const optional> &axes, + bool keepdims) { + if (reduced) { + g->addOpWithOutputs(std::move(data), reduced, axes, + keepdims); + return reduced; + } else { + return g->addOp(std::move(data), reduced, axes, keepdims) + ->getOutput(); + } +} + static DataType dtype_repr_convert(int dtype) { switch ((OnnxDType)dtype) { case OnnxDType::FLOAT: diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 7fa127c9..d1998d6c 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -82,6 +82,10 @@ void init_graph_builder(py::module &m) { policy::move) .def("gather", py::overload_cast(&Handler::gather), + policy::move) + .def("reduceMean", + py::overload_cast> &, + bool>(&Handler::reduceMean), policy::move); } diff --git a/src/operators/reduce_mean.cc b/src/operators/reduce_mean.cc index b59cc828..3a562abb 100644 --- a/src/operators/reduce_mean.cc +++ b/src/operators/reduce_mean.cc @@ -2,8 +2,7 @@ namespace infini { ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output, - const optional> &_axes, - bool keepDims) + const optional> &_axes, bool 
keepDims) : OperatorObj(OpType::ReduceMean, {input}, {output}), keepDims(keepDims) { const auto size = input->getDims().size(); if (_axes) {