feat: 前端支持 slice 及单元测试

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-14 17:35:18 +08:00
parent f9d0076a86
commit 8fae67b4b4
5 changed files with 69 additions and 2 deletions

View File

@ -67,6 +67,9 @@ class GraphHandlerObj {
Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);
Tensor reduceMean(Tensor data, Tensor reduced,
const optional<vector<int>> &axes, bool keepdims);
// Append an ONNX-style Slice operator to the graph: per-axis ranges given
// by `starts`/`ends`; `axes` and `steps` are optional, mirroring the
// optional inputs of the ONNX Slice op.
// NOTE(review): presumably a null `output` asks the handler to allocate
// the result tensor — confirm against the definition in the .cc file.
Tensor slice(Tensor input, Tensor output, const vector<int> &starts,
const vector<int> &ends, const optional<vector<int>> &axes,
const optional<vector<int>> &steps);
};
} // namespace infini

View File

@ -1,5 +1,5 @@
import onnx, backend
from typing import Dict
from typing import Dict, List, Any
runtime = backend.cpu_runtime()
@ -193,6 +193,15 @@ def from_onnx(model: onnx.ModelProto):
next((attr.i for attr in node.attribute if attr.name == "keepdims"))
!= 0,
)
elif node.op_type == "Slice":
tensors[node.output[0]] = handler.slice(
tensors[node.input[0]],
tensors.get(node.output[0]),
_parse_data(data[node.input[1]]),
_parse_data(data[node.input[2]]),
_parse_data(data[node.input[3]]) if len(node.input) > 3 else None,
_parse_data(data[node.input[4]]) if len(node.input) > 4 else None,
)
else:
raise Exception('Unsupported operator "{}"'.format(node.op_type))
@ -233,7 +242,9 @@ def parse_onnx(model: onnx.ModelProto):
print(" {}".format(node.name))
def _parse_attribute(node: onnx.NodeProto, attrs: dict = dict()):
def _parse_attribute(
node: onnx.NodeProto, attrs: Dict[str, Any] = dict()
) -> Dict[str, Any]:
for attr in node.attribute:
if attr.name in attrs:
if attr.type == onnx.AttributeProto.INT:
@ -249,3 +260,12 @@ def _parse_attribute(node: onnx.NodeProto, attrs: dict = dict()):
else:
assert False, "Unsupported Attribute Type: {}".format(attr.type)
return attrs
def _parse_data(tensor: onnx.TensorProto) -> List[int]:
    """Return the elements of an integer ONNX tensor as a plain Python list.

    Only INT32 and INT64 tensors are supported; any other data type raises.

    NOTE(review): tensors serialized with ``raw=True`` keep their payload in
    ``tensor.raw_data`` instead of the typed ``int32_data``/``int64_data``
    fields and would come back as an empty list here — confirm callers never
    pass such tensors.
    """
    if tensor.data_type == onnx.TensorProto.INT32:
        return [int(i) for i in tensor.int32_data]
    elif tensor.data_type == onnx.TensorProto.INT64:
        return [int(i) for i in tensor.int64_data]
    else:
        # Raise instead of `assert False`: asserts are stripped under `-O`,
        # which would silently fall through and return None. Matches the
        # `raise Exception(...)` style used for unsupported operators above.
        raise Exception("Unsupported Tensor Type: {}".format(tensor.data_type))

View File

@ -208,6 +208,26 @@ class TestStringMethods(unittest.TestCase):
)
make_and_import_model(make_graph([reduceMean], "reduceMean", [data], [reduced]))
def test_slice(self):
    """Build a minimal ONNX model containing a Slice node and import it."""
    data = make_tensor_value_info("data", TensorProto.UINT32, [10, 64, 162, 162])
    # NOTE(review): the expected output shape treats `ends` as inclusive
    # (axis 0 is 2..3 -> 2 elements, axis 1 is 10..10 -> 1, ...); the ONNX
    # spec defines `ends` as exclusive — confirm this matches the backend's
    # Slice semantics.
    output = make_tensor_value_info("output", TensorProto.UINT32, [2, 1, 100, 96])
    starts = make_tensor_value_info("starts", TensorProto.INT64, [4])
    starts_data = make_tensor("starts", TensorProto.INT64, [4], [2, 10, 1, 5])
    ends = make_tensor_value_info("ends", TensorProto.INT64, [4])
    ends_data = make_tensor("ends", TensorProto.INT64, [4], [3, 10, 100, 100])
    # `slice_node` avoids shadowing the builtin `slice`; the node name is
    # "slice" (it was "gather", a copy-paste slip from the gather test).
    slice_node = make_node(
        "Slice", ["data", "starts", "ends"], ["output"], name="slice"
    )
    make_and_import_model(
        make_graph(
            [slice_node],
            "slice",
            [data, starts, ends],
            [output],
            [starts_data, ends_data],
        )
    )
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
def test_linear(self):
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])

View File

@ -7,6 +7,7 @@
#include "operators/pooling.h"
#include "operators/reduce_mean.h"
#include "operators/reshape.h"
#include "operators/slice.h"
#include "operators/unary.h"
namespace infini {
@ -162,6 +163,23 @@ Tensor GraphHandlerObj::reduceMean(Tensor data, Tensor reduced,
}
}
// Append a Slice op to the graph. If the caller supplied an output tensor,
// the op is wired to it and that tensor is returned; otherwise the graph
// allocates the output and the freshly created tensor is returned.
Tensor GraphHandlerObj::slice(Tensor input, Tensor output,
                              const vector<int> &starts,
                              const vector<int> &ends,
                              const optional<vector<int>> &axes,
                              const optional<vector<int>> &steps) {
    // Guard clause: no output given — let addOp allocate one for us.
    if (!output) {
        return g
            ->addOp<SliceObj>(std::move(input), output, starts, ends, axes,
                              steps)
            ->getOutput();
    }
    g->addOpWithOutputs<SliceObj>(std::move(input), output, starts, ends, axes,
                                  steps);
    return output;
}
static DataType dtype_repr_convert(int dtype) {
switch ((OnnxDType)dtype) {
case OnnxDType::FLOAT:

View File

@ -94,6 +94,12 @@ void init_graph_builder(py::module &m) {
.def("reduceMean",
py::overload_cast<Tensor, Tensor, const optional<vector<int>> &,
bool>(&Handler::reduceMean),
policy::move)
.def("slice",
py::overload_cast<
Tensor, Tensor, const vector<int> &, const vector<int> &,
const optional<vector<int>> &, const optional<vector<int>> &>(
&Handler::slice),
policy::move);
}