feat: 前端支持 pool 及单元测试

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-14 16:26:47 +08:00
parent 62ceb78ae3
commit 341cf1f943
6 changed files with 129 additions and 6 deletions

View File

@ -44,6 +44,11 @@ class GraphHandlerObj {
Tensor scale, Tensor bias, float momentum, float eps, Tensor scale, Tensor bias, float momentum, float eps,
bool training); bool training);
Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
int ph, int pw, int sh, int sw);
Tensor avgPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
int ph, int pw, int sh, int sw);
Tensor add(Tensor a, Tensor b, Tensor c); Tensor add(Tensor a, Tensor b, Tensor c);
Tensor sub(Tensor a, Tensor b, Tensor c); Tensor sub(Tensor a, Tensor b, Tensor c);
Tensor mul(Tensor a, Tensor b, Tensor c); Tensor mul(Tensor a, Tensor b, Tensor c);

View File

@ -46,6 +46,56 @@ def from_onnx(model: onnx.ModelProto):
tensors[node.output[0]] = handler.batchNorm( tensors[node.output[0]] = handler.batchNorm(
input, output, mean, var, scale, bias, momentum, eps, training != 0 input, output, mean, var, scale, bias, momentum, eps, training != 0
) )
elif node.op_type == "MaxPool":
attributes = _parse_attribute(
node,
{
"kernel_shape": None,
"dilations": [1, 1],
"pads": [0, 0],
"strides": [1, 1],
},
)
(k, d, p, s) = (
attributes[name]
for name in ["kernel_shape", "dilations", "pads", "strides"]
)
tensors[node.output[0]] = handler.maxPool(
tensors[node.input[0]],
tensors.get(node.output[0]),
k[0],
k[1],
d[0],
d[1],
p[0],
p[1],
s[0],
s[1],
)
elif node.op_type == "AveragePool":
attributes = _parse_attribute(
node,
{
"kernel_shape": None,
"pads": [0, 0],
"strides": [1, 1],
},
)
(k, p, s) = (
attributes[name] for name in ["kernel_shape", "pads", "strides"]
)
tensors[node.output[0]] = handler.maxPool(
tensors[node.input[0]],
tensors.get(node.output[0]),
k[0],
k[1],
1,
1,
p[0],
p[1],
s[0],
s[1],
)
elif node.op_type == "Add": elif node.op_type == "Add":
tensors[node.output[0]] = handler.add( tensors[node.output[0]] = handler.add(
tensors[node.input[0]], tensors[node.input[0]],

View File

@ -55,9 +55,38 @@ class TestStringMethods(unittest.TestCase):
name="batchNormalization", name="batchNormalization",
) )
make_and_import_model( make_and_import_model(
make_graph([batch_norm], "batch_norm", [x, scale, b, mean, var], [y]) make_graph([batch_norm], "batchNorm", [x, scale, b, mean, var], [y])
) )
def test_max_pool(self):
x = make_tensor_value_info("x", TensorProto.UINT32, [1, 64, 162, 162])
y = make_tensor_value_info("y", TensorProto.UINT32, [1, 64, 80, 80])
pool = make_node(
"MaxPool",
["x"],
["y"],
kernel_shape=[3, 3],
dilations=[1, 1],
pads=[0, 0],
strides=[2, 2],
name="maxPool",
)
make_and_import_model(make_graph([pool], "maxPool", [x], [y]))
def test_avg_pool(self):
x = make_tensor_value_info("x", TensorProto.UINT32, [1, 64, 162, 162])
y = make_tensor_value_info("y", TensorProto.UINT32, [1, 64, 80, 80])
pool = make_node(
"AveragePool",
["x"],
["y"],
kernel_shape=[3, 3],
pads=[0, 0],
strides=[2, 2],
name="avgPool",
)
make_and_import_model(make_graph([pool], "avgPool", [x], [y]))
def test_add(self): def test_add(self):
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7]) a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7]) b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])

View File

@ -4,9 +4,11 @@
#include "operators/element_wise.h" #include "operators/element_wise.h"
#include "operators/gather.h" #include "operators/gather.h"
#include "operators/matmul.h" #include "operators/matmul.h"
#include "operators/pooling.h"
#include "operators/reduce_mean.h" #include "operators/reduce_mean.h"
#include "operators/reshape.h" #include "operators/reshape.h"
#include "operators/unary.h" #include "operators/unary.h"
namespace infini { namespace infini {
static DataType dtype_repr_convert(int); static DataType dtype_repr_convert(int);
@ -46,6 +48,35 @@ Tensor GraphHandlerObj::batchNorm(Tensor input, Tensor output, Tensor mean,
} }
} }
Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw,
int dh, int dw, int ph, int pw, int sh,
int sw) {
if (output) {
g->addOpWithOutputs<MaxPoolObj>(std::move(input), output, kh, kw, dh,
dw, ph, pw, sh, sw);
return output;
} else {
return g
->addOp<MaxPoolObj>(std::move(input), output, kh, kw, dh, dw, ph,
pw, sh, sw)
->getOutput();
}
}
Tensor GraphHandlerObj::avgPool(Tensor input, Tensor output, int kh, int kw,
int dh, int dw, int ph, int pw, int sh,
int sw) {
if (output) {
g->addOpWithOutputs<AvgPoolObj>(std::move(input), output, kh, kw, dh,
dw, ph, pw, sh, sw);
return output;
} else {
return g
->addOp<AvgPoolObj>(std::move(input), output, kh, kw, dh, dw, ph,
pw, sh, sw)
->getOutput();
}
}
// see operators/element_wise.h // see operators/element_wise.h
#define DEFINE_ELEMENT_WISE_METHOD(name, obj) \ #define DEFINE_ELEMENT_WISE_METHOD(name, obj) \
Tensor GraphHandlerObj::name(Tensor a, Tensor b, Tensor c) { \ Tensor GraphHandlerObj::name(Tensor a, Tensor b, Tensor c) { \

View File

@ -50,6 +50,14 @@ void init_graph_builder(py::module &m) {
py::overload_cast<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, py::overload_cast<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,
float, float, bool>(&Handler::batchNorm), float, float, bool>(&Handler::batchNorm),
policy::move) policy::move)
.def("maxPool",
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
int, int>(&Handler::maxPool),
policy::move)
.def("avgPool",
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
int, int>(&Handler::avgPool),
policy::move)
.def("add", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::add), .def("add", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::add),
policy::move) policy::move)
.def("sub", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::sub), .def("sub", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::sub),

View File

@ -5,12 +5,12 @@ namespace infini {
PoolingObj::PoolingObj(GraphObj *graph, OpType optype, Tensor input, PoolingObj::PoolingObj(GraphObj *graph, OpType optype, Tensor input,
Tensor output, int kh, int kw, int dh, int dw, int ph, Tensor output, int kh, int kw, int dh, int dw, int ph,
int pw, int sh, int sw) int pw, int sh, int sw)
: OperatorObj(optype, {input}, {output}), kh(kh), kw(kw), dh(dh), dw(dw), : OperatorObj(optype, {input}, {output}),
ph(ph), pw(pw), sh(sh), sw(sw) {
n = input->getDims()[0];
c = input->getDims()[1];
h = input->getDims()[2], w = input->getDims()[3];
kh(kh), kw(kw), dh(dh), dw(dw), ph(ph), pw(pw), sh(sh), sw(sw),
n(input->getDims()[0]), c(input->getDims()[1]), h(input->getDims()[2]),
w(input->getDims()[3]) {
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }