forked from jiuyuan/InfiniTensor
feat: frontend support for pool operators and unit tests
Signed-off-by: YdrMaster <ydrml@hotmail.com>
parent 62ceb78ae3
commit 341cf1f943
@@ -44,6 +44,11 @@ class GraphHandlerObj {
                      Tensor scale, Tensor bias, float momentum, float eps,
                      bool training);
 
+    Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
+                   int ph, int pw, int sh, int sw);
+    Tensor avgPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
+                   int ph, int pw, int sh, int sw);
+
     Tensor add(Tensor a, Tensor b, Tensor c);
     Tensor sub(Tensor a, Tensor b, Tensor c);
     Tensor mul(Tensor a, Tensor b, Tensor c);
@@ -46,6 +46,56 @@ def from_onnx(model: onnx.ModelProto):
             tensors[node.output[0]] = handler.batchNorm(
                 input, output, mean, var, scale, bias, momentum, eps, training != 0
             )
+        elif node.op_type == "MaxPool":
+            attributes = _parse_attribute(
+                node,
+                {
+                    "kernel_shape": None,
+                    "dilations": [1, 1],
+                    "pads": [0, 0],
+                    "strides": [1, 1],
+                },
+            )
+            (k, d, p, s) = (
+                attributes[name]
+                for name in ["kernel_shape", "dilations", "pads", "strides"]
+            )
+            tensors[node.output[0]] = handler.maxPool(
+                tensors[node.input[0]],
+                tensors.get(node.output[0]),
+                k[0],
+                k[1],
+                d[0],
+                d[1],
+                p[0],
+                p[1],
+                s[0],
+                s[1],
+            )
+        elif node.op_type == "AveragePool":
+            attributes = _parse_attribute(
+                node,
+                {
+                    "kernel_shape": None,
+                    "pads": [0, 0],
+                    "strides": [1, 1],
+                },
+            )
+            (k, p, s) = (
+                attributes[name] for name in ["kernel_shape", "pads", "strides"]
+            )
+            tensors[node.output[0]] = handler.avgPool(
+                tensors[node.input[0]],
+                tensors.get(node.output[0]),
+                k[0],
+                k[1],
+                1,
+                1,
+                p[0],
+                p[1],
+                s[0],
+                s[1],
+            )
         elif node.op_type == "Add":
             tensors[node.output[0]] = handler.add(
                 tensors[node.input[0]],
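For reference, the attribute handling above amounts to the following standalone sketch; parse_pool_attributes is an illustrative stand-in for the project's _parse_attribute, using the same defaults the MaxPool branch passes in:

from onnx import helper

def parse_pool_attributes(node, defaults):
    # Attributes present on the node win; anything missing falls back to the
    # caller-supplied default (e.g. dilations=[1, 1], pads=[0, 0]).
    present = {a.name: helper.get_attribute_value(a) for a in node.attribute}
    return {name: present.get(name, value) for name, value in defaults.items()}

pool = helper.make_node("MaxPool", ["x"], ["y"], kernel_shape=[3, 3], strides=[2, 2])
print(parse_pool_attributes(
    pool,
    {"kernel_shape": None, "dilations": [1, 1], "pads": [0, 0], "strides": [1, 1]},
))
# kernel_shape and strides come from the node; dilations and pads from the defaults.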
@@ -55,9 +55,38 @@ class TestStringMethods(unittest.TestCase):
             name="batchNormalization",
         )
         make_and_import_model(
-            make_graph([batch_norm], "batch_norm", [x, scale, b, mean, var], [y])
+            make_graph([batch_norm], "batchNorm", [x, scale, b, mean, var], [y])
         )
 
+    def test_max_pool(self):
+        x = make_tensor_value_info("x", TensorProto.UINT32, [1, 64, 162, 162])
+        y = make_tensor_value_info("y", TensorProto.UINT32, [1, 64, 80, 80])
+        pool = make_node(
+            "MaxPool",
+            ["x"],
+            ["y"],
+            kernel_shape=[3, 3],
+            dilations=[1, 1],
+            pads=[0, 0],
+            strides=[2, 2],
+            name="maxPool",
+        )
+        make_and_import_model(make_graph([pool], "maxPool", [x], [y]))
+
+    def test_avg_pool(self):
+        x = make_tensor_value_info("x", TensorProto.UINT32, [1, 64, 162, 162])
+        y = make_tensor_value_info("y", TensorProto.UINT32, [1, 64, 80, 80])
+        pool = make_node(
+            "AveragePool",
+            ["x"],
+            ["y"],
+            kernel_shape=[3, 3],
+            pads=[0, 0],
+            strides=[2, 2],
+            name="avgPool",
+        )
+        make_and_import_model(make_graph([pool], "avgPool", [x], [y]))
+
     def test_add(self):
         a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 5, 7])
         b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 5, 7])
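The [1, 64, 80, 80] output value_info in both tests follows from the usual pooling output-size rule; a quick check of the 162 -> 80 spatial dims (pool_out_dim is an illustrative helper, not part of the test file):

import math

def pool_out_dim(size, kernel, dilation, pad, stride):
    # floor((size + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1,
    # assuming symmetric padding and floor rounding.
    return math.floor((size + 2 * pad - dilation * (kernel - 1) - 1) / stride) + 1

# 162x162 input, 3x3 kernel, dilation 1, no padding, stride 2 -> 80x80.
assert pool_out_dim(162, 3, 1, 0, 2) == 80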
@@ -4,9 +4,11 @@
 #include "operators/element_wise.h"
 #include "operators/gather.h"
 #include "operators/matmul.h"
+#include "operators/pooling.h"
 #include "operators/reduce_mean.h"
 #include "operators/reshape.h"
 #include "operators/unary.h"
 
 namespace infini {
 
 static DataType dtype_repr_convert(int);
@@ -46,6 +48,35 @@ Tensor GraphHandlerObj::batchNorm(Tensor input, Tensor output, Tensor mean,
     }
 }
 
+Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw,
+                                int dh, int dw, int ph, int pw, int sh,
+                                int sw) {
+    if (output) {
+        g->addOpWithOutputs<MaxPoolObj>(std::move(input), output, kh, kw, dh,
+                                        dw, ph, pw, sh, sw);
+        return output;
+    } else {
+        return g
+            ->addOp<MaxPoolObj>(std::move(input), output, kh, kw, dh, dw, ph,
+                                pw, sh, sw)
+            ->getOutput();
+    }
+}
+Tensor GraphHandlerObj::avgPool(Tensor input, Tensor output, int kh, int kw,
+                                int dh, int dw, int ph, int pw, int sh,
+                                int sw) {
+    if (output) {
+        g->addOpWithOutputs<AvgPoolObj>(std::move(input), output, kh, kw, dh,
+                                        dw, ph, pw, sh, sw);
+        return output;
+    } else {
+        return g
+            ->addOp<AvgPoolObj>(std::move(input), output, kh, kw, dh, dw, ph,
+                                pw, sh, sw)
+            ->getOutput();
+    }
+}
+
 // see operators/element_wise.h
 #define DEFINE_ELEMENT_WISE_METHOD(name, obj)                                 \
     Tensor GraphHandlerObj::name(Tensor a, Tensor b, Tensor c) {              \
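Both new handlers follow the pattern of the existing ops: if the caller already has the output tensor, the op is attached to it via addOpWithOutputs and that same tensor is returned; if the output is null, addOp lets the graph infer the output and getOutput() hands it back. A rough Python mirror of that control flow (GraphStub is illustrative only, not InfiniTensor's API):

class GraphStub:
    # Stand-in for the C++ graph; only the output wiring is modeled here.
    def add_op_with_outputs(self, op_type, inputs, output):
        return output  # caller-supplied output tensor is reused

    def add_op(self, op_type, inputs):
        return {"op": op_type, "output": f"{op_type}_out"}  # graph infers the output

def max_pool(g, x, output, kh, kw, dh, dw, ph, pw, sh, sw):
    if output is not None:
        g.add_op_with_outputs("MaxPool", [x], output)
        return output
    return g.add_op("MaxPool", [x])["output"]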
@@ -50,6 +50,14 @@ void init_graph_builder(py::module &m) {
              py::overload_cast<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor,
                                float, float, bool>(&Handler::batchNorm),
              policy::move)
+        .def("maxPool",
+             py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
+                               int, int>(&Handler::maxPool),
+             policy::move)
+        .def("avgPool",
+             py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
+                               int, int>(&Handler::avgPool),
+             policy::move)
         .def("add", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::add),
              policy::move)
         .def("sub", py::overload_cast<Tensor, Tensor, Tensor>(&Handler::sub),
@@ -5,12 +5,12 @@ namespace infini {
 PoolingObj::PoolingObj(GraphObj *graph, OpType optype, Tensor input,
                        Tensor output, int kh, int kw, int dh, int dw, int ph,
                        int pw, int sh, int sw)
-    : OperatorObj(optype, {input}, {output}), kh(kh), kw(kw), dh(dh), dw(dw),
-      ph(ph), pw(pw), sh(sh), sw(sw) {
-    n = input->getDims()[0];
-    c = input->getDims()[1];
-    h = input->getDims()[2], w = input->getDims()[3];
+    : OperatorObj(optype, {input}, {output}),
+      kh(kh), kw(kw), dh(dh), dw(dw), ph(ph), pw(pw), sh(sh), sw(sw),
+      n(input->getDims()[0]), c(input->getDims()[1]), h(input->getDims()[2]),
+      w(input->getDims()[3]) {
     IT_ASSERT(checkValid(graph));
 }