forked from jiuyuan/InfiniTensor
feat: 前端支持 pool 及单元测试
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
62ceb78ae3
commit
341cf1f943
|
@ -44,6 +44,11 @@ class GraphHandlerObj {
|
|||
Tensor scale, Tensor bias, float momentum, float eps,
|
||||
bool training);
|
||||
|
||||
Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
|
||||
int ph, int pw, int sh, int sw);
|
||||
Tensor avgPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
|
||||
int ph, int pw, int sh, int sw);
|
||||
|
||||
Tensor add(Tensor a, Tensor b, Tensor c);
|
||||
Tensor sub(Tensor a, Tensor b, Tensor c);
|
||||
Tensor mul(Tensor a, Tensor b, Tensor c);
|
||||
|
|
|
@ -46,6 +46,56 @@ def from_onnx(model: onnx.ModelProto):
|
|||
tensors[node.output[0]] = handler.batchNorm(
|
||||
input, output, mean, var, scale, bias, momentum, eps, training != 0
|
||||
)
|
||||
elif node.op_type == "MaxPool":
|
||||
attributes = _parse_attribute(
|
||||
node,
|
||||
{
|
||||
"kernel_shape": None,
|
||||
"dilations": [1, 1],
|
||||
"pads": [0, 0],
|
||||
"strides": [1, 1],
|
||||
},
|
||||
)
|
||||
(k, d, p, s) = (
|
||||
attributes[name]
|
||||
for name in ["kernel_shape", "dilations", "pads", "strides"]
|
||||
)
|
||||
tensors[node.output[0]] = handler.maxPool(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
k[0],
|
||||
k[1],
|
||||
d[0],
|
||||
d[1],
|
||||
p[0],
|
||||
p[1],
|
||||
s[0],
|
||||
s[1],
|
||||
)
|
||||
elif node.op_type == "AveragePool":
|
||||
attributes = _parse_attribute(
|
||||
node,
|
||||
{
|
||||
"kernel_shape": None,
|
||||
"pads": [0, 0],
|
||||
"strides": [1, 1],
|
||||
},
|
||||
)
|
||||
(k, p, s) = (
|
||||
attributes[name] for name in ["kernel_shape", "pads", "strides"]
|
||||
)
|
||||
tensors[node.output[0]] = handler.maxPool(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
k[0],
|
||||
k[1],
|
||||
1,
|
||||
1,
|
||||
p[0],
|
||||
p[1],
|
||||
s[0],
|
||||
s[1],
|
||||
)
|
||||
elif node.op_type == "Add":
|
||||
tensors[node.output[0]] = handler.add(
|
||||
tensors[node.input[0]],
|
||||
|
|
|
@ -55,9 +55,38 @@ class TestStringMethods(unittest.TestCase):
|
|||
name="batchNormalization",
|
||||
)
|
||||
make_and_import_model(
|
||||
make_graph([batch_norm], "batch_norm", [x, scale, b, mean, var], [y])
|
||||
make_graph([batch_norm], "batchNorm", [x, scale, b, mean, var], [y])
|
||||
)
|
||||
|
||||
def test_max_pool(self):
|
||||
x = make_tensor_value_info("x", TensorProto.UINT32, [1, 64, 162, 162])
|
||||
y = make_tensor_value_info("y", TensorProto.UINT32, [1, 64, 80, 80])
|
||||
pool = make_node(
|
||||
"MaxPool",
|
||||
["x"],
|
||||
["y"],
|
||||
kernel_shape=[3, 3],
|
||||
dilations=[1, 1],
|
||||
pads=[0, 0],
|
||||
strides=[2, 2],
|
||||
name="maxPool",
|
||||
)
|
||||
make_and_import_model(make_graph([pool], "maxPool", [x], [y]))
|
||||
|
||||
def test_avg_pool(self):
|
||||
x = make_tensor_value_info("x", TensorProto.UINT32, [1, 64, 162, 162])
|
||||
y = make_tensor_value_info("y", TensorProto.UINT32, [1, 64, 80, 80])
|
||||
pool = make_node(
|
||||
"AveragePool",
|
||||
["x"],
|
||||
["y"],
|
||||
kernel_shape=[3, 3],
|
||||
pads=[0, 0],
|
||||
strides=[2, 2],
|
||||
name="avgPool",
|
||||
)
|
||||
make_and_import_model(make_graph([pool], "avgPool", [x], [y]))
|
||||
|
||||
def test_add(self):
|
||||
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
|
|
|
@ -4,9 +4,11 @@
|
|||
#include "operators/element_wise.h"
|
||||
#include "operators/gather.h"
|
||||
#include "operators/matmul.h"
|
||||
#include "operators/pooling.h"
|
||||
#include "operators/reduce_mean.h"
|
||||
#include "operators/reshape.h"
|
||||
#include "operators/unary.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
static DataType dtype_repr_convert(int);
|
||||
|
@ -46,6 +48,35 @@ Tensor GraphHandlerObj::batchNorm(Tensor input, Tensor output, Tensor mean,
|
|||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw,
|
||||
int dh, int dw, int ph, int pw, int sh,
|
||||
int sw) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<MaxPoolObj>(std::move(input), output, kh, kw, dh,
|
||||
dw, ph, pw, sh, sw);
|
||||
return output;
|
||||
} else {
|
||||
return g
|
||||
->addOp<MaxPoolObj>(std::move(input), output, kh, kw, dh, dw, ph,
|
||||
pw, sh, sw)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
Tensor GraphHandlerObj::avgPool(Tensor input, Tensor output, int kh, int kw,
|
||||
int dh, int dw, int ph, int pw, int sh,
|
||||
int sw) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<AvgPoolObj>(std::move(input), output, kh, kw, dh,
|
||||
dw, ph, pw, sh, sw);
|
||||
return output;
|
||||
} else {
|
||||
return g
|
||||
->addOp<AvgPoolObj>(std::move(input), output, kh, kw, dh, dw, ph,
|
||||
pw, sh, sw)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
// see operators/element_wise.h
|
||||
#define DEFINE_ELEMENT_WISE_METHOD(name, obj) \
|
||||
Tensor GraphHandlerObj::name(Tensor a, Tensor b, Tensor c) { \
|
||||
|
|
|
@ -50,6 +50,14 @@ void init_graph_builder(py::module &m) {
|
|||
py::overload_cast<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,
|
||||
float, float, bool>(&Handler::batchNorm),
|
||||
policy::move)
|
||||
.def("maxPool",
|
||||
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
|
||||
int, int>(&Handler::maxPool),
|
||||
policy::move)
|
||||
.def("avgPool",
|
||||
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
|
||||
int, int>(&Handler::avgPool),
|
||||
policy::move)
|
||||
.def("add", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::add),
|
||||
policy::move)
|
||||
.def("sub", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::sub),
|
||||
|
|
|
@ -5,12 +5,12 @@ namespace infini {
|
|||
PoolingObj::PoolingObj(GraphObj *graph, OpType optype, Tensor input,
|
||||
Tensor output, int kh, int kw, int dh, int dw, int ph,
|
||||
int pw, int sh, int sw)
|
||||
: OperatorObj(optype, {input}, {output}), kh(kh), kw(kw), dh(dh), dw(dw),
|
||||
ph(ph), pw(pw), sh(sh), sw(sw) {
|
||||
n = input->getDims()[0];
|
||||
c = input->getDims()[1];
|
||||
h = input->getDims()[2], w = input->getDims()[3];
|
||||
: OperatorObj(optype, {input}, {output}),
|
||||
|
||||
kh(kh), kw(kw), dh(dh), dw(dw), ph(ph), pw(pw), sh(sh), sw(sw),
|
||||
|
||||
n(input->getDims()[0]), c(input->getDims()[1]), h(input->getDims()[2]),
|
||||
w(input->getDims()[3]) {
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue