forked from jiuyuan/InfiniTensor

commit 4c7fdf44c5 (parent ce04177585)
feat: frontend support for Conv, with unit tests
Signed-off-by: YdrMaster <ydrml@hotmail.com>
@@ -38,6 +38,9 @@ class GraphHandlerObj {
     Tensor tensor(Shape dims, int dtype);

+    Tensor conv(Tensor input, Tensor weight, Tensor output, int ph, int pw,
+                int sh, int sw, int dh, int dw);
+
     Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
                   Tensor bias, ActType act);
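A note on the new declaration (not part of the diff): the three integer pairs presumably carry padding (ph, pw), stride (sh, sw), and dilation (dh, dw) along the height and width axes, and output may be a null tensor, in which case the handler creates it (see the implementation hunk further down). A hypothetical Python-side stub of the binding this maps to, with illustrative names only:

    # Hypothetical stub of the binding behind GraphHandlerObj::conv; output may
    # be None, in which case the inferred output tensor is returned instead.
    def conv(input, weight, output, ph: int, pw: int, sh: int, sw: int,
             dh: int, dw: int):
        """ph/pw: padding, sh/sw: stride, dh/dw: dilation, height then width."""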
@@ -24,7 +24,28 @@ def from_onnx(model: onnx.ModelProto):
         data[initializer.name] = initializer

     for node in model.graph.node:
-        if node.op_type == "MatMul":
+        if node.op_type == "Conv":
+            attributes = _parse_attribute(
+                node,
+                {
+                    "dilations": [1, 1],
+                    "pads": [0, 0],
+                    "strides": [1, 1],
+                },
+            )
+            (d, p, s) = (attributes[name] for name in ["dilations", "pads", "strides"])
+            tensors[node.output[0]] = handler.conv(
+                tensors[node.input[0]],
+                tensors[node.input[1]],
+                tensors.get(node.output[0]),
+                p[0],
+                p[1],
+                s[0],
+                s[1],
+                d[0],
+                d[1],
+            )
+        elif node.op_type == "MatMul":
             tensors[node.output[0]] = handler.matmul(
                 tensors[node.input[0]],
                 tensors[node.input[1]],
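The Conv branch leans on _parse_attribute to overlay a node's ONNX attributes on a dictionary of defaults; the helper itself is outside this diff. A minimal sketch of such a helper, assuming it does nothing more than that overlay (the real implementation may differ):

    from onnx import NodeProto
    from onnx.helper import get_attribute_value

    def _parse_attribute(node: NodeProto, defaults: dict) -> dict:
        # Start from the defaults, then overwrite every attribute the node
        # actually carries, so absent attributes keep their default values.
        attrs = dict(defaults)
        for attr in node.attribute:
            attrs[attr.name] = get_attribute_value(attr)
        return attrs

With the defaults above, a Conv node that omits its attributes is imported with zero padding, unit strides, and unit dilations.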
@@ -34,6 +34,21 @@ class TestStringMethods(unittest.TestCase):
         x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
         make_and_import_model(make_graph([], "tensor", [x], [x]))

+    def test_conv(self):
+        i = make_tensor_value_info("i", TensorProto.FLOAT, [1, 3, 4, 4])
+        w = make_tensor_value_info("w", TensorProto.FLOAT, [2, 3, 3, 3])
+        o = make_tensor_value_info("o", TensorProto.FLOAT, [1, 2, 2, 2])
+        conv = make_node(
+            "Conv",
+            ["i", "w"],
+            ["o"],
+            "conv",
+            pads=[1, 1],
+            strides=[2, 1],
+            dilations=[1, 2],
+        )
+        make_and_import_model(make_graph([conv], "conv", [i, w], [o]))
+
     def test_matmul(self):
         x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
         a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4])
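The [1, 2, 2, 2] output shape declared in test_conv checks out against the usual convolution output-size rule, assuming the importer follows it: out = floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1. A quick verification:

    def conv_out_dim(size, kernel, pad, stride, dilation):
        # floor((size + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
        return (size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1

    # test_conv: input 4x4, kernel 3x3, pads=[1, 1], strides=[2, 1], dilations=[1, 2]
    assert conv_out_dim(4, 3, 1, 2, 1) == 2  # height
    assert conv_out_dim(4, 3, 1, 1, 2) == 2  # width
    # with batch 1 and the 2 output channels of w, that gives [1, 2, 2, 2]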
@@ -1,6 +1,7 @@
 #include "core/graph_handler.h"
 #include "operators/batch_norm.h"
 #include "operators/concat.h"
+#include "operators/conv.h"
 #include "operators/element_wise.h"
 #include "operators/gather.h"
 #include "operators/matmul.h"
@@ -19,6 +20,20 @@ Tensor GraphHandlerObj::tensor(Shape dims, int dtype) {
     return g->addTensor(std::move(dims), dtype_repr_convert(dtype));
 }

+Tensor GraphHandlerObj::conv(Tensor input, Tensor weight, Tensor output, int ph,
+                             int pw, int sh, int sw, int dh, int dw) {
+    if (output) {
+        g->addOpWithOutputs<ConvObj>(std::move(input), std::move(weight),
+                                     output, ph, pw, sh, sw, dh, dw);
+        return output;
+    } else {
+        return g
+            ->addOp<ConvObj>(std::move(input), std::move(weight), output, ph,
+                             pw, sh, sw, dh, dw)
+            ->getOutput();
+    }
+}
+
 Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
                                bool transB, Tensor bias, ActType act) {
     if (y) {
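The new conv handler mirrors the matmul pattern next to it: when the caller supplies an output tensor it is wired in through addOpWithOutputs and handed straight back, otherwise addOp builds the operator and its inferred output is returned. Seen from the Python frontend above, that is why tensors.get(node.output[0]) may legitimately be None. A usage sketch (handler and the tensors are placeholders, not part of this diff):

    # Output tensor provided up front: the handler reuses it and returns it.
    y = handler.conv(i, w, o, 1, 1, 2, 1, 1, 2)    # y is o

    # No output tensor yet: the handler runs shape inference and returns the
    # freshly created output of the ConvObj operator.
    y = handler.conv(i, w, None, 1, 1, 2, 1, 1, 2)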
@@ -39,6 +39,7 @@ void init_graph_builder(py::module &m) {
         .def(py::init<Runtime>())
         .def("tensor", py::overload_cast<Shape, int>(&Handler::tensor),
             policy::move)
+        .def("conv", &Handler::conv, policy::move)
         .def("matmul",
             py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
                               ActType>(&Handler::matmul),
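For completeness, a rough end-to-end sketch of how the new binding gets exercised, reusing the graph from test_conv. It assumes make_and_import_model wraps onnx.helper.make_model plus the from_onnx shown above; the actual import path of from_onnx is not part of this diff:

    from onnx import TensorProto
    from onnx.helper import (make_graph, make_model, make_node,
                             make_tensor_value_info)

    i = make_tensor_value_info("i", TensorProto.FLOAT, [1, 3, 4, 4])
    w = make_tensor_value_info("w", TensorProto.FLOAT, [2, 3, 3, 3])
    o = make_tensor_value_info("o", TensorProto.FLOAT, [1, 2, 2, 2])
    conv = make_node("Conv", ["i", "w"], ["o"], "conv",
                     pads=[1, 1], strides=[2, 1], dilations=[1, 2])

    # from_onnx walks model.graph.node and reaches handler.conv(...) through
    # the "conv" binding registered above.
    from_onnx(make_model(make_graph([conv], "conv", [i, w], [o])))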