forked from jiuyuan/InfiniTensor
feat: 前端支持 slice 及单元测试
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
f9d0076a86
commit
8fae67b4b4
|
@ -67,6 +67,9 @@ class GraphHandlerObj {
|
|||
Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);
|
||||
Tensor reduceMean(Tensor data, Tensor reduced,
|
||||
const optional<vector<int>> &axes, bool keepdims);
|
||||
Tensor slice(Tensor input, Tensor output, const vector<int> &starts,
|
||||
const vector<int> &ends, const optional<vector<int>> &axes,
|
||||
const optional<vector<int>> &steps);
|
||||
};
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import onnx, backend
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Any
|
||||
|
||||
runtime = backend.cpu_runtime()
|
||||
|
||||
|
@ -193,6 +193,15 @@ def from_onnx(model: onnx.ModelProto):
|
|||
next((attr.i for attr in node.attribute if attr.name == "keepdims"))
|
||||
!= 0,
|
||||
)
|
||||
elif node.op_type == "Slice":
|
||||
tensors[node.output[0]] = handler.slice(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
_parse_data(data[node.input[1]]),
|
||||
_parse_data(data[node.input[2]]),
|
||||
_parse_data(data[node.input[3]]) if len(node.input) > 3 else None,
|
||||
_parse_data(data[node.input[4]]) if len(node.input) > 4 else None,
|
||||
)
|
||||
else:
|
||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||
|
||||
|
@ -233,7 +242,9 @@ def parse_onnx(model: onnx.ModelProto):
|
|||
print(" {}".format(node.name))
|
||||
|
||||
|
||||
def _parse_attribute(node: onnx.NodeProto, attrs: dict = dict()):
|
||||
def _parse_attribute(
|
||||
node: onnx.NodeProto, attrs: Dict[str, Any] = dict()
|
||||
) -> Dict[str, Any]:
|
||||
for attr in node.attribute:
|
||||
if attr.name in attrs:
|
||||
if attr.type == onnx.AttributeProto.INT:
|
||||
|
@ -249,3 +260,12 @@ def _parse_attribute(node: onnx.NodeProto, attrs: dict = dict()):
|
|||
else:
|
||||
assert False, "Unsupported Attribute Type: {}".format(attr.type)
|
||||
return attrs
|
||||
|
||||
|
||||
def _parse_data(tensor: onnx.TensorProto) -> List[int]:
|
||||
if tensor.data_type == onnx.TensorProto.INT32:
|
||||
return [int(i) for i in tensor.int32_data]
|
||||
elif tensor.data_type == onnx.TensorProto.INT64:
|
||||
return [int(i) for i in tensor.int64_data]
|
||||
else:
|
||||
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
|
||||
|
|
|
@ -208,6 +208,26 @@ class TestStringMethods(unittest.TestCase):
|
|||
)
|
||||
make_and_import_model(make_graph([reduceMean], "reduceMean", [data], [reduced]))
|
||||
|
||||
def test_slice(self):
|
||||
data = make_tensor_value_info("data", TensorProto.UINT32, [10, 64, 162, 162])
|
||||
output = make_tensor_value_info("output", TensorProto.UINT32, [2, 1, 100, 96])
|
||||
starts = make_tensor_value_info("starts", TensorProto.INT64, [4])
|
||||
starts_data = make_tensor("starts", TensorProto.INT64, [4], [2, 10, 1, 5])
|
||||
ends = make_tensor_value_info("ends", TensorProto.INT64, [4])
|
||||
ends_data = make_tensor("ends", TensorProto.INT64, [4], [3, 10, 100, 100])
|
||||
slice = make_node(
|
||||
"Slice", ["data", "starts", "ends"], ["output"], name="gather"
|
||||
)
|
||||
make_and_import_model(
|
||||
make_graph(
|
||||
[slice],
|
||||
"slice",
|
||||
[data, starts, ends],
|
||||
[output],
|
||||
[starts_data, ends_data],
|
||||
)
|
||||
)
|
||||
|
||||
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
|
||||
def test_linear(self):
|
||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
#include "operators/pooling.h"
|
||||
#include "operators/reduce_mean.h"
|
||||
#include "operators/reshape.h"
|
||||
#include "operators/slice.h"
|
||||
#include "operators/unary.h"
|
||||
|
||||
namespace infini {
|
||||
|
@ -162,6 +163,23 @@ Tensor GraphHandlerObj::reduceMean(Tensor data, Tensor reduced,
|
|||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::slice(Tensor input, Tensor output,
|
||||
const vector<int> &starts,
|
||||
const vector<int> &ends,
|
||||
const optional<vector<int>> &axes,
|
||||
const optional<vector<int>> &steps) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<SliceObj>(std::move(input), output, starts, ends,
|
||||
axes, steps);
|
||||
return output;
|
||||
} else {
|
||||
return g
|
||||
->addOp<SliceObj>(std::move(input), output, starts, ends, axes,
|
||||
steps)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
static DataType dtype_repr_convert(int dtype) {
|
||||
switch ((OnnxDType)dtype) {
|
||||
case OnnxDType::FLOAT:
|
||||
|
|
|
@ -94,6 +94,12 @@ void init_graph_builder(py::module &m) {
|
|||
.def("reduceMean",
|
||||
py::overload_cast<Tensor, Tensor, const optional<vector<int>> &,
|
||||
bool>(&Handler::reduceMean),
|
||||
policy::move)
|
||||
.def("slice",
|
||||
py::overload_cast<
|
||||
Tensor, Tensor, const vector<int> &, const vector<int> &,
|
||||
const optional<vector<int>> &, const optional<vector<int>> &>(
|
||||
&Handler::slice),
|
||||
policy::move);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue