forked from jiuyuan/InfiniTensor
feat: 前端支持 Conv 及单元测试
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
ce04177585
commit
4c7fdf44c5
|
@ -38,6 +38,9 @@ class GraphHandlerObj {
|
||||||
|
|
||||||
Tensor tensor(Shape dims, int dtype);
|
Tensor tensor(Shape dims, int dtype);
|
||||||
|
|
||||||
|
Tensor conv(Tensor input, Tensor weight, Tensor output, int ph, int pw,
|
||||||
|
int sh, int sw, int dh, int dw);
|
||||||
|
|
||||||
Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
|
Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB,
|
||||||
Tensor bias, ActType act);
|
Tensor bias, ActType act);
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,28 @@ def from_onnx(model: onnx.ModelProto):
|
||||||
data[initializer.name] = initializer
|
data[initializer.name] = initializer
|
||||||
|
|
||||||
for node in model.graph.node:
|
for node in model.graph.node:
|
||||||
if node.op_type == "MatMul":
|
if node.op_type == "Conv":
|
||||||
|
attributes = _parse_attribute(
|
||||||
|
node,
|
||||||
|
{
|
||||||
|
"dilations": [1, 1],
|
||||||
|
"pads": [0, 0],
|
||||||
|
"strides": [1, 1],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
(d, p, s) = (attributes[name] for name in ["dilations", "pads", "strides"])
|
||||||
|
tensors[node.output[0]] = handler.conv(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors[node.input[1]],
|
||||||
|
tensors.get(node.output[0]),
|
||||||
|
p[0],
|
||||||
|
p[1],
|
||||||
|
s[0],
|
||||||
|
s[1],
|
||||||
|
d[0],
|
||||||
|
d[1],
|
||||||
|
)
|
||||||
|
elif node.op_type == "MatMul":
|
||||||
tensors[node.output[0]] = handler.matmul(
|
tensors[node.output[0]] = handler.matmul(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors[node.input[1]],
|
tensors[node.input[1]],
|
||||||
|
|
|
@ -34,6 +34,21 @@ class TestStringMethods(unittest.TestCase):
|
||||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||||
make_and_import_model(make_graph([], "tensor", [x], [x]))
|
make_and_import_model(make_graph([], "tensor", [x], [x]))
|
||||||
|
|
||||||
|
def test_conv(self):
|
||||||
|
i = make_tensor_value_info("i", TensorProto.FLOAT, [1, 3, 4, 4])
|
||||||
|
w = make_tensor_value_info("w", TensorProto.FLOAT, [2, 3, 3, 3])
|
||||||
|
o = make_tensor_value_info("o", TensorProto.FLOAT, [1, 2, 2, 2])
|
||||||
|
conv = make_node(
|
||||||
|
"Conv",
|
||||||
|
["i", "w"],
|
||||||
|
["o"],
|
||||||
|
"conv",
|
||||||
|
pads=[1, 1],
|
||||||
|
strides=[2, 1],
|
||||||
|
dilations=[1, 2],
|
||||||
|
)
|
||||||
|
make_and_import_model(make_graph([conv], "conv", [i, w], [o]))
|
||||||
|
|
||||||
def test_matmul(self):
|
def test_matmul(self):
|
||||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||||
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4])
|
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4])
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
#include "core/graph_handler.h"
|
#include "core/graph_handler.h"
|
||||||
#include "operators/batch_norm.h"
|
#include "operators/batch_norm.h"
|
||||||
#include "operators/concat.h"
|
#include "operators/concat.h"
|
||||||
|
#include "operators/conv.h"
|
||||||
#include "operators/element_wise.h"
|
#include "operators/element_wise.h"
|
||||||
#include "operators/gather.h"
|
#include "operators/gather.h"
|
||||||
#include "operators/matmul.h"
|
#include "operators/matmul.h"
|
||||||
|
@ -19,6 +20,20 @@ Tensor GraphHandlerObj::tensor(Shape dims, int dtype) {
|
||||||
return g->addTensor(std::move(dims), dtype_repr_convert(dtype));
|
return g->addTensor(std::move(dims), dtype_repr_convert(dtype));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Tensor GraphHandlerObj::conv(Tensor input, Tensor weight, Tensor output, int ph,
|
||||||
|
int pw, int sh, int sw, int dh, int dw) {
|
||||||
|
if (output) {
|
||||||
|
g->addOpWithOutputs<ConvObj>(std::move(input), std::move(weight),
|
||||||
|
output, ph, pw, sh, sw, dh, dw);
|
||||||
|
return output;
|
||||||
|
} else {
|
||||||
|
return g
|
||||||
|
->addOp<ConvObj>(std::move(input), std::move(weight), output, ph,
|
||||||
|
pw, sh, sw, dh, dw)
|
||||||
|
->getOutput();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
|
Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
|
||||||
bool transB, Tensor bias, ActType act) {
|
bool transB, Tensor bias, ActType act) {
|
||||||
if (y) {
|
if (y) {
|
||||||
|
|
|
@ -39,6 +39,7 @@ void init_graph_builder(py::module &m) {
|
||||||
.def(py::init<Runtime>())
|
.def(py::init<Runtime>())
|
||||||
.def("tensor", py::overload_cast<Shape, int>(&Handler::tensor),
|
.def("tensor", py::overload_cast<Shape, int>(&Handler::tensor),
|
||||||
policy::move)
|
policy::move)
|
||||||
|
.def("conv", &Handler::conv, policy::move)
|
||||||
.def("matmul",
|
.def("matmul",
|
||||||
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
|
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
|
||||||
ActType>(&Handler::matmul),
|
ActType>(&Handler::matmul),
|
||||||
|
|
Loading…
Reference in New Issue