From 8fae67b4b4d7f40d0c6760f47b88dc11ffd55556 Mon Sep 17 00:00:00 2001
From: YdrMaster
Date: Tue, 14 Feb 2023 17:35:18 +0800
Subject: [PATCH] feat: frontend support for slice with unit tests
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: YdrMaster
---
 include/core/graph_handler.h              |  3 +++
 pyinfinitensor/src/pyinfinitensor/onnx.py | 24 +++++++++++++++++++++--
 pyinfinitensor/tests/test_onnx.py         | 20 +++++++++++++++++++
 src/core/graph_handler.cc                 | 18 +++++++++++++++++
 src/ffi/ffi_infinitensor.cc               |  6 ++++++
 5 files changed, 69 insertions(+), 2 deletions(-)

diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index 4f9ca929..9ee2ed7e 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -67,6 +67,9 @@ class GraphHandlerObj {
     Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);
     Tensor reduceMean(Tensor data, Tensor reduced,
                       const optional<vector<int>> &axes, bool keepdims);
+    Tensor slice(Tensor input, Tensor output, const vector<int> &starts,
+                 const vector<int> &ends, const optional<vector<int>> &axes,
+                 const optional<vector<int>> &steps);
 };
 
 } // namespace infini
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index 32c946d0..7f60bf95 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -1,5 +1,5 @@
 import onnx, backend
-from typing import Dict
+from typing import Dict, List, Any
 
 runtime = backend.cpu_runtime()
 
@@ -193,6 +193,15 @@ def from_onnx(model: onnx.ModelProto):
                 next((attr.i for attr in node.attribute if attr.name == "keepdims"))
                 != 0,
             )
+        elif node.op_type == "Slice":
+            tensors[node.output[0]] = handler.slice(
+                tensors[node.input[0]],
+                tensors.get(node.output[0]),
+                _parse_data(data[node.input[1]]),
+                _parse_data(data[node.input[2]]),
+                _parse_data(data[node.input[3]]) if len(node.input) > 3 else None,
+                _parse_data(data[node.input[4]]) if len(node.input) > 4 else None,
+            )
         else:
             raise Exception('Unsupported operator "{}"'.format(node.op_type))
 
@@ -233,7 +242,9 @@ def parse_onnx(model: onnx.ModelProto):
         print(" {}".format(node.name))
 
 
-def _parse_attribute(node: onnx.NodeProto, attrs: dict = dict()):
+def _parse_attribute(
+    node: onnx.NodeProto, attrs: Dict[str, Any] = dict()
+) -> Dict[str, Any]:
     for attr in node.attribute:
         if attr.name in attrs:
             if attr.type == onnx.AttributeProto.INT:
@@ -249,3 +260,12 @@ def _parse_attribute(node: onnx.NodeProto, attrs: dict = dict()):
         else:
             assert False, "Unsupported Attribute Type: {}".format(attr.type)
     return attrs
+
+
+def _parse_data(tensor: onnx.TensorProto) -> List[int]:
+    if tensor.data_type == onnx.TensorProto.INT32:
+        return [int(i) for i in tensor.int32_data]
+    elif tensor.data_type == onnx.TensorProto.INT64:
+        return [int(i) for i in tensor.int64_data]
+    else:
+        assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py
index f296e0fa..5fb375a1 100644
--- a/pyinfinitensor/tests/test_onnx.py
+++ b/pyinfinitensor/tests/test_onnx.py
@@ -208,6 +208,26 @@ class TestStringMethods(unittest.TestCase):
         )
         make_and_import_model(make_graph([reduceMean], "reduceMean", [data], [reduced]))
 
+    def test_slice(self):
+        data = make_tensor_value_info("data", TensorProto.UINT32, [10, 64, 162, 162])
+        output = make_tensor_value_info("output", TensorProto.UINT32, [2, 1, 100, 96])
+        starts = make_tensor_value_info("starts", TensorProto.INT64, [4])
+        starts_data = make_tensor("starts", TensorProto.INT64, [4], [2, 10, 1, 5])
+        ends = make_tensor_value_info("ends", TensorProto.INT64, [4])
+        ends_data = make_tensor("ends", TensorProto.INT64, [4], [3, 10, 100, 100])
+        slice = make_node(
+            "Slice", ["data", "starts", "ends"], ["output"], name="slice"
+        )
+        make_and_import_model(
+            make_graph(
+                [slice],
+                "slice",
+                [data, starts, ends],
+                [output],
+                [starts_data, ends_data],
+            )
+        )
+
     # see
     def test_linear(self):
         x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc
index afde6690..ec6175e1 100644
--- a/src/core/graph_handler.cc
+++ b/src/core/graph_handler.cc
@@ -7,6 +7,7 @@
 #include "operators/pooling.h"
 #include "operators/reduce_mean.h"
 #include "operators/reshape.h"
+#include "operators/slice.h"
 #include "operators/unary.h"
 
 namespace infini {
@@ -162,6 +163,23 @@ Tensor GraphHandlerObj::reduceMean(Tensor data, Tensor reduced,
     }
 }
 
+Tensor GraphHandlerObj::slice(Tensor input, Tensor output,
+                              const vector<int> &starts,
+                              const vector<int> &ends,
+                              const optional<vector<int>> &axes,
+                              const optional<vector<int>> &steps) {
+    if (output) {
+        g->addOpWithOutputs<SliceObj>(std::move(input), output, starts, ends,
+                                      axes, steps);
+        return output;
+    } else {
+        return g
+            ->addOp<SliceObj>(std::move(input), output, starts, ends, axes,
+                              steps)
+            ->getOutput();
+    }
+}
+
 static DataType dtype_repr_convert(int dtype) {
     switch ((OnnxDType)dtype) {
     case OnnxDType::FLOAT:
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index 8d8f9ac3..ffe38b39 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -94,6 +94,12 @@ void init_graph_builder(py::module &m) {
         .def("reduceMean",
              py::overload_cast<Tensor, Tensor, const optional<vector<int>> &,
                                bool>(&Handler::reduceMean),
+             policy::move)
+        .def("slice",
+             py::overload_cast<
+                 Tensor, Tensor, const vector<int> &, const vector<int> &,
+                 const optional<vector<int>> &, const optional<vector<int>> &>(
+                 &Handler::slice),
              policy::move);
 }
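
Usage sketch (not part of the patch): the snippet below mirrors the new test_slice
case to show how the added Slice path is exercised end to end. An ONNX model
containing a Slice node with "starts"/"ends" initializers is imported through
pyinfinitensor.onnx.from_onnx, which routes the node to handler.slice with axes
and steps left as None. It assumes the package and its compiled backend module
are built and importable; the shapes and values here are illustrative only.

    import onnx
    from onnx.helper import (
        make_graph,
        make_model,
        make_node,
        make_tensor,
        make_tensor_value_info,
    )
    from pyinfinitensor.onnx import from_onnx

    # A 1-D Slice: keep elements [1, 3) of a length-5 vector.
    data = make_tensor_value_info("data", onnx.TensorProto.FLOAT, [5])
    output = make_tensor_value_info("output", onnx.TensorProto.FLOAT, [2])
    starts = make_tensor_value_info("starts", onnx.TensorProto.INT64, [1])
    ends = make_tensor_value_info("ends", onnx.TensorProto.INT64, [1])
    starts_data = make_tensor("starts", onnx.TensorProto.INT64, [1], [1])
    ends_data = make_tensor("ends", onnx.TensorProto.INT64, [1], [3])

    node = make_node("Slice", ["data", "starts", "ends"], ["output"], name="slice")
    graph = make_graph(
        [node], "slice_example", [data, starts, ends], [output],
        [starts_data, ends_data],
    )
    model = make_model(graph)
    onnx.checker.check_model(model)

    # Only three inputs are given, so the frontend passes axes=None and
    # steps=None; default handling of those arguments is left to SliceObj.
    from_onnx(model)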