feat: 前端支持 pool 及单元测试

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-14 16:26:47 +08:00
parent 62ceb78ae3
commit 341cf1f943
6 changed files with 129 additions and 6 deletions

View File

@ -44,6 +44,11 @@ class GraphHandlerObj {
Tensor scale, Tensor bias, float momentum, float eps,
bool training);
Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
int ph, int pw, int sh, int sw);
Tensor avgPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
int ph, int pw, int sh, int sw);
Tensor add(Tensor a, Tensor b, Tensor c);
Tensor sub(Tensor a, Tensor b, Tensor c);
Tensor mul(Tensor a, Tensor b, Tensor c);

View File

@ -46,6 +46,56 @@ def from_onnx(model: onnx.ModelProto):
tensors[node.output[0]] = handler.batchNorm(
input, output, mean, var, scale, bias, momentum, eps, training != 0
)
elif node.op_type == "MaxPool":
attributes = _parse_attribute(
node,
{
"kernel_shape": None,
"dilations": [1, 1],
"pads": [0, 0],
"strides": [1, 1],
},
)
(k, d, p, s) = (
attributes[name]
for name in ["kernel_shape", "dilations", "pads", "strides"]
)
tensors[node.output[0]] = handler.maxPool(
tensors[node.input[0]],
tensors.get(node.output[0]),
k[0],
k[1],
d[0],
d[1],
p[0],
p[1],
s[0],
s[1],
)
elif node.op_type == "AveragePool":
attributes = _parse_attribute(
node,
{
"kernel_shape": None,
"pads": [0, 0],
"strides": [1, 1],
},
)
(k, p, s) = (
attributes[name] for name in ["kernel_shape", "pads", "strides"]
)
tensors[node.output[0]] = handler.maxPool(
tensors[node.input[0]],
tensors.get(node.output[0]),
k[0],
k[1],
1,
1,
p[0],
p[1],
s[0],
s[1],
)
elif node.op_type == "Add":
tensors[node.output[0]] = handler.add(
tensors[node.input[0]],

View File

@ -55,9 +55,38 @@ class TestStringMethods(unittest.TestCase):
name="batchNormalization",
)
make_and_import_model(
make_graph([batch_norm], "batch_norm", [x, scale, b, mean, var], [y])
make_graph([batch_norm], "batchNorm", [x, scale, b, mean, var], [y])
)
def test_max_pool(self):
x = make_tensor_value_info("x", TensorProto.UINT32, [1, 64, 162, 162])
y = make_tensor_value_info("y", TensorProto.UINT32, [1, 64, 80, 80])
pool = make_node(
"MaxPool",
["x"],
["y"],
kernel_shape=[3, 3],
dilations=[1, 1],
pads=[0, 0],
strides=[2, 2],
name="maxPool",
)
make_and_import_model(make_graph([pool], "maxPool", [x], [y]))
def test_avg_pool(self):
x = make_tensor_value_info("x", TensorProto.UINT32, [1, 64, 162, 162])
y = make_tensor_value_info("y", TensorProto.UINT32, [1, 64, 80, 80])
pool = make_node(
"AveragePool",
["x"],
["y"],
kernel_shape=[3, 3],
pads=[0, 0],
strides=[2, 2],
name="avgPool",
)
make_and_import_model(make_graph([pool], "avgPool", [x], [y]))
def test_add(self):
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])

View File

@ -4,9 +4,11 @@
#include "operators/element_wise.h"
#include "operators/gather.h"
#include "operators/matmul.h"
#include "operators/pooling.h"
#include "operators/reduce_mean.h"
#include "operators/reshape.h"
#include "operators/unary.h"
namespace infini {
static DataType dtype_repr_convert(int);
@ -46,6 +48,35 @@ Tensor GraphHandlerObj::batchNorm(Tensor input, Tensor output, Tensor mean,
}
}
Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw,
int dh, int dw, int ph, int pw, int sh,
int sw) {
if (output) {
g->addOpWithOutputs<MaxPoolObj>(std::move(input), output, kh, kw, dh,
dw, ph, pw, sh, sw);
return output;
} else {
return g
->addOp<MaxPoolObj>(std::move(input), output, kh, kw, dh, dw, ph,
pw, sh, sw)
->getOutput();
}
}
Tensor GraphHandlerObj::avgPool(Tensor input, Tensor output, int kh, int kw,
int dh, int dw, int ph, int pw, int sh,
int sw) {
if (output) {
g->addOpWithOutputs<AvgPoolObj>(std::move(input), output, kh, kw, dh,
dw, ph, pw, sh, sw);
return output;
} else {
return g
->addOp<AvgPoolObj>(std::move(input), output, kh, kw, dh, dw, ph,
pw, sh, sw)
->getOutput();
}
}
// see operators/element_wise.h
#define DEFINE_ELEMENT_WISE_METHOD(name, obj) \
Tensor GraphHandlerObj::name(Tensor a, Tensor b, Tensor c) { \

View File

@ -50,6 +50,14 @@ void init_graph_builder(py::module &m) {
py::overload_cast<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,
float, float, bool>(&Handler::batchNorm),
policy::move)
.def("maxPool",
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
int, int>(&Handler::maxPool),
policy::move)
.def("avgPool",
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
int, int>(&Handler::avgPool),
policy::move)
.def("add", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::add),
policy::move)
.def("sub", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::sub),

View File

@ -5,12 +5,12 @@ namespace infini {
PoolingObj::PoolingObj(GraphObj *graph, OpType optype, Tensor input,
Tensor output, int kh, int kw, int dh, int dw, int ph,
int pw, int sh, int sw)
: OperatorObj(optype, {input}, {output}), kh(kh), kw(kw), dh(dh), dw(dw),
ph(ph), pw(pw), sh(sh), sw(sw) {
n = input->getDims()[0];
c = input->getDims()[1];
h = input->getDims()[2], w = input->getDims()[3];
: OperatorObj(optype, {input}, {output}),
kh(kh), kw(kw), dh(dh), dw(dw), ph(ph), pw(pw), sh(sh), sw(sw),
n(input->getDims()[0]), c(input->getDims()[1]), h(input->getDims()[2]),
w(input->getDims()[3]) {
IT_ASSERT(checkValid(graph));
}