feat: 前端支持 Conv 及单元测试

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-22 15:05:44 +08:00
parent ce04177585
commit 4c7fdf44c5
5 changed files with 56 additions and 1 deletions

View File

@ -38,6 +38,9 @@ class GraphHandlerObj {
Tensor tensor(Shape dims, int dtype);
Tensor conv(Tensor input, Tensor weight, Tensor output, int ph, int pw,
int sh, int sw, int dh, int dw);
Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
Tensor bias, ActType act);

View File

@ -24,7 +24,28 @@ def from_onnx(model: onnx.ModelProto):
data[initializer.name] = initializer
for node in model.graph.node:
if node.op_type == "MatMul":
if node.op_type == "Conv":
attributes = _parse_attribute(
node,
{
"dilations": [1, 1],
"pads": [0, 0],
"strides": [1, 1],
},
)
(d, p, s) = (attributes[name] for name in ["dilations", "pads", "strides"])
tensors[node.output[0]] = handler.conv(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
p[0],
p[1],
s[0],
s[1],
d[0],
d[1],
)
elif node.op_type == "MatMul":
tensors[node.output[0]] = handler.matmul(
tensors[node.input[0]],
tensors[node.input[1]],

View File

@ -34,6 +34,21 @@ class TestStringMethods(unittest.TestCase):
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
make_and_import_model(make_graph([], "tensor", [x], [x]))
def test_conv(self):
i = make_tensor_value_info("i", TensorProto.FLOAT, [1, 3, 4, 4])
w = make_tensor_value_info("w", TensorProto.FLOAT, [2, 3, 3, 3])
o = make_tensor_value_info("o", TensorProto.FLOAT, [1, 2, 2, 2])
conv = make_node(
"Conv",
["i", "w"],
["o"],
"conv",
pads=[1, 1],
strides=[2, 1],
dilations=[1, 2],
)
make_and_import_model(make_graph([conv], "conv", [i, w], [o]))
def test_matmul(self):
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4])

View File

@ -1,6 +1,7 @@
#include "core/graph_handler.h"
#include "operators/batch_norm.h"
#include "operators/concat.h"
#include "operators/conv.h"
#include "operators/element_wise.h"
#include "operators/gather.h"
#include "operators/matmul.h"
@ -19,6 +20,20 @@ Tensor GraphHandlerObj::tensor(Shape dims, int dtype) {
return g->addTensor(std::move(dims), dtype_repr_convert(dtype));
}
Tensor GraphHandlerObj::conv(Tensor input, Tensor weight, Tensor output, int ph,
int pw, int sh, int sw, int dh, int dw) {
if (output) {
g->addOpWithOutputs<ConvObj>(std::move(input), std::move(weight),
output, ph, pw, sh, sw, dh, dw);
return output;
} else {
return g
->addOp<ConvObj>(std::move(input), std::move(weight), output, ph,
pw, sh, sw, dh, dw)
->getOutput();
}
}
Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
bool transB, Tensor bias, ActType act) {
if (y) {

View File

@ -39,6 +39,7 @@ void init_graph_builder(py::module &m) {
.def(py::init<Runtime>())
.def("tensor", py::overload_cast<Shape, int>(&Handler::tensor),
policy::move)
.def("conv", &Handler::conv, policy::move)
.def("matmul",
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
ActType>(&Handler::matmul),