forked from jiuyuan/InfiniTensor
feat: 前端支持 slice 及单元测试
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
f9d0076a86
commit
8fae67b4b4
|
@ -67,6 +67,9 @@ class GraphHandlerObj {
|
||||||
Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);
|
Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);
|
||||||
Tensor reduceMean(Tensor data, Tensor reduced,
|
Tensor reduceMean(Tensor data, Tensor reduced,
|
||||||
const optional<vector<int>> &axes, bool keepdims);
|
const optional<vector<int>> &axes, bool keepdims);
|
||||||
|
Tensor slice(Tensor input, Tensor output, const vector<int> &starts,
|
||||||
|
const vector<int> &ends, const optional<vector<int>> &axes,
|
||||||
|
const optional<vector<int>> &steps);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import onnx, backend
|
import onnx, backend
|
||||||
from typing import Dict
|
from typing import Dict, List, Any
|
||||||
|
|
||||||
runtime = backend.cpu_runtime()
|
runtime = backend.cpu_runtime()
|
||||||
|
|
||||||
|
@ -193,6 +193,15 @@ def from_onnx(model: onnx.ModelProto):
|
||||||
next((attr.i for attr in node.attribute if attr.name == "keepdims"))
|
next((attr.i for attr in node.attribute if attr.name == "keepdims"))
|
||||||
!= 0,
|
!= 0,
|
||||||
)
|
)
|
||||||
|
elif node.op_type == "Slice":
|
||||||
|
tensors[node.output[0]] = handler.slice(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors.get(node.output[0]),
|
||||||
|
_parse_data(data[node.input[1]]),
|
||||||
|
_parse_data(data[node.input[2]]),
|
||||||
|
_parse_data(data[node.input[3]]) if len(node.input) > 3 else None,
|
||||||
|
_parse_data(data[node.input[4]]) if len(node.input) > 4 else None,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||||
|
|
||||||
|
@ -233,7 +242,9 @@ def parse_onnx(model: onnx.ModelProto):
|
||||||
print(" {}".format(node.name))
|
print(" {}".format(node.name))
|
||||||
|
|
||||||
|
|
||||||
def _parse_attribute(node: onnx.NodeProto, attrs: dict = dict()):
|
def _parse_attribute(
|
||||||
|
node: onnx.NodeProto, attrs: Dict[str, Any] = dict()
|
||||||
|
) -> Dict[str, Any]:
|
||||||
for attr in node.attribute:
|
for attr in node.attribute:
|
||||||
if attr.name in attrs:
|
if attr.name in attrs:
|
||||||
if attr.type == onnx.AttributeProto.INT:
|
if attr.type == onnx.AttributeProto.INT:
|
||||||
|
@ -249,3 +260,12 @@ def _parse_attribute(node: onnx.NodeProto, attrs: dict = dict()):
|
||||||
else:
|
else:
|
||||||
assert False, "Unsupported Attribute Type: {}".format(attr.type)
|
assert False, "Unsupported Attribute Type: {}".format(attr.type)
|
||||||
return attrs
|
return attrs
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_data(tensor: onnx.TensorProto) -> List[int]:
|
||||||
|
if tensor.data_type == onnx.TensorProto.INT32:
|
||||||
|
return [int(i) for i in tensor.int32_data]
|
||||||
|
elif tensor.data_type == onnx.TensorProto.INT64:
|
||||||
|
return [int(i) for i in tensor.int64_data]
|
||||||
|
else:
|
||||||
|
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
|
||||||
|
|
|
@ -208,6 +208,26 @@ class TestStringMethods(unittest.TestCase):
|
||||||
)
|
)
|
||||||
make_and_import_model(make_graph([reduceMean], "reduceMean", [data], [reduced]))
|
make_and_import_model(make_graph([reduceMean], "reduceMean", [data], [reduced]))
|
||||||
|
|
||||||
|
def test_slice(self):
|
||||||
|
data = make_tensor_value_info("data", TensorProto.UINT32, [10, 64, 162, 162])
|
||||||
|
output = make_tensor_value_info("output", TensorProto.UINT32, [2, 1, 100, 96])
|
||||||
|
starts = make_tensor_value_info("starts", TensorProto.INT64, [4])
|
||||||
|
starts_data = make_tensor("starts", TensorProto.INT64, [4], [2, 10, 1, 5])
|
||||||
|
ends = make_tensor_value_info("ends", TensorProto.INT64, [4])
|
||||||
|
ends_data = make_tensor("ends", TensorProto.INT64, [4], [3, 10, 100, 100])
|
||||||
|
slice = make_node(
|
||||||
|
"Slice", ["data", "starts", "ends"], ["output"], name="gather"
|
||||||
|
)
|
||||||
|
make_and_import_model(
|
||||||
|
make_graph(
|
||||||
|
[slice],
|
||||||
|
"slice",
|
||||||
|
[data, starts, ends],
|
||||||
|
[output],
|
||||||
|
[starts_data, ends_data],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
|
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
|
||||||
def test_linear(self):
|
def test_linear(self):
|
||||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
#include "operators/pooling.h"
|
#include "operators/pooling.h"
|
||||||
#include "operators/reduce_mean.h"
|
#include "operators/reduce_mean.h"
|
||||||
#include "operators/reshape.h"
|
#include "operators/reshape.h"
|
||||||
|
#include "operators/slice.h"
|
||||||
#include "operators/unary.h"
|
#include "operators/unary.h"
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
@ -162,6 +163,23 @@ Tensor GraphHandlerObj::reduceMean(Tensor data, Tensor reduced,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Tensor GraphHandlerObj::slice(Tensor input, Tensor output,
|
||||||
|
const vector<int> &starts,
|
||||||
|
const vector<int> &ends,
|
||||||
|
const optional<vector<int>> &axes,
|
||||||
|
const optional<vector<int>> &steps) {
|
||||||
|
if (output) {
|
||||||
|
g->addOpWithOutputs<SliceObj>(std::move(input), output, starts, ends,
|
||||||
|
axes, steps);
|
||||||
|
return output;
|
||||||
|
} else {
|
||||||
|
return g
|
||||||
|
->addOp<SliceObj>(std::move(input), output, starts, ends, axes,
|
||||||
|
steps)
|
||||||
|
->getOutput();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static DataType dtype_repr_convert(int dtype) {
|
static DataType dtype_repr_convert(int dtype) {
|
||||||
switch ((OnnxDType)dtype) {
|
switch ((OnnxDType)dtype) {
|
||||||
case OnnxDType::FLOAT:
|
case OnnxDType::FLOAT:
|
||||||
|
|
|
@ -94,6 +94,12 @@ void init_graph_builder(py::module &m) {
|
||||||
.def("reduceMean",
|
.def("reduceMean",
|
||||||
py::overload_cast<Tensor, Tensor, const optional<vector<int>> &,
|
py::overload_cast<Tensor, Tensor, const optional<vector<int>> &,
|
||||||
bool>(&Handler::reduceMean),
|
bool>(&Handler::reduceMean),
|
||||||
|
policy::move)
|
||||||
|
.def("slice",
|
||||||
|
py::overload_cast<
|
||||||
|
Tensor, Tensor, const vector<int> &, const vector<int> &,
|
||||||
|
const optional<vector<int>> &, const optional<vector<int>> &>(
|
||||||
|
&Handler::slice),
|
||||||
policy::move);
|
policy::move);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue