From d11fb0ad5f43d3e0a96fab2c3ca0f4bb4949a416 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Tue, 14 Feb 2023 14:16:01 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=89=8D=E7=AB=AF=E6=94=AF=E6=8C=81=20?= =?UTF-8?q?gather=20=E5=8F=8A=E5=8D=95=E5=85=83=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- include/core/graph_handler.h | 1 + include/operators/gather.h | 4 ++-- pyinfinitensor/src/pyinfinitensor/onnx.py | 7 +++++++ pyinfinitensor/tests/test_onnx.py | 9 +++++++++ src/core/graph_handler.cc | 15 +++++++++++++++ src/ffi/ffi_infinitensor.cc | 3 +++ src/operators/gather.cc | 6 +++--- 7 files changed, 40 insertions(+), 5 deletions(-) diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index 41d01c1b..b1331757 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -59,6 +59,7 @@ class GraphHandlerObj { Tensor flatten(Tensor s, Tensor y); Tensor reshape(Tensor data, Tensor reshaped, Shape shape); Tensor concat(TensorVec inputs, Tensor output, int dim); + Tensor gather(Tensor input, Tensor indices, Tensor output, int axis); }; } // namespace infini diff --git a/include/operators/gather.h b/include/operators/gather.h index da86c502..d5d07a69 100644 --- a/include/operators/gather.h +++ b/include/operators/gather.h @@ -17,11 +17,11 @@ class GatherObj : public OperatorObj { * * @param graph The computation graph that this operator belongs to. * @param input The input tensor. - * @param index The index tensor. + * @param indices The index tensor. * @param output The output tensor. * @param axis The axis to gather on. */ - GatherObj(GraphObj *graph, Tensor input, Tensor index, Tensor output, + GatherObj(GraphObj *graph, Tensor input, Tensor indices, Tensor output, int axis); OP_CLONE(GatherObj); std::string toString() const override; diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 87c3899c..0b9f43b9 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -128,6 +128,13 @@ def from_onnx(model: onnx.ModelProto): tensors.get(node.output[0]), next((attr.i for attr in node.attribute if attr.name == "axis")), ) + elif node.op_type == "Gather": + tensors[node.output[0]] = handler.gather( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + next((attr.i for attr in node.attribute if attr.name == "axis")), + ) else: raise Exception('Unsupported operator "{}"'.format(node.op_type)) diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index 605fb6bf..3e9810da 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -162,6 +162,15 @@ class TestStringMethods(unittest.TestCase): make_graph([concat], "concat", [input1, input2], [output]) ) + def test_gather(self): + data = make_tensor_value_info("data", TensorProto.FLOAT, [1, 3, 4, 4]) + indices = make_tensor_value_info("indices", TensorProto.FLOAT, [2, 1, 2]) + output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 2, 1, 2, 4, 4]) + gather = make_node( + "Gather", ["data", "indices"], ["output"], axis=1, name="gather" + ) + make_and_import_model(make_graph([gather], "gather", [data, indices], [output])) + # see def test_linear(self): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3]) diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index f9aa095e..5461e923 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -2,6 +2,7 @@ #include "operators/batch_norm.h" #include "operators/concat.h" #include "operators/element_wise.h" +#include "operators/gather.h" #include "operators/matmul.h" #include "operators/reshape.h" #include "operators/unary.h" @@ -103,6 +104,20 @@ Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) { } } +Tensor GraphHandlerObj::gather(Tensor input, Tensor indices, Tensor output, + int axis) { + if (output) { + g->addOpWithOutputs(std::move(input), std::move(indices), + output, axis); + return output; + } else { + return g + ->addOp(std::move(input), std::move(indices), output, + axis) + ->getOutput(); + } +} + static DataType dtype_repr_convert(int dtype) { switch ((OnnxDType)dtype) { case OnnxDType::FLOAT: diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 3ad18699..7fa127c9 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -79,6 +79,9 @@ void init_graph_builder(py::module &m) { policy::move) .def("concat", py::overload_cast(&Handler::concat), + policy::move) + .def("gather", + py::overload_cast(&Handler::gather), policy::move); } diff --git a/src/operators/gather.cc b/src/operators/gather.cc index a5bf9d1c..2c1cd57f 100644 --- a/src/operators/gather.cc +++ b/src/operators/gather.cc @@ -1,9 +1,9 @@ #include "operators/gather.h" namespace infini { -GatherObj::GatherObj(GraphObj *graph, Tensor input, Tensor index, Tensor output, - int axis) - : OperatorObj(OpType::Gather, {input, index}, {output}), axis(axis) { +GatherObj::GatherObj(GraphObj *graph, Tensor input, Tensor indices, + Tensor output, int axis) + : OperatorObj(OpType::Gather, {input, indices}, {output}), axis(axis) { IT_ASSERT(checkValid(graph)); }