forked from jiuyuan/InfiniTensor
feat: 导出 pool 到 onnx
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
8a871c3773
commit
e294e46436
|
@ -39,20 +39,22 @@ class PoolingObj : public OperatorObj {
|
|||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
inline int numInputs() const override { return 1; }
|
||||
inline int numOutputs() const override { return 1; }
|
||||
|
||||
int getKh() const { return kh; }
|
||||
int getKw() const { return kw; }
|
||||
int getDh() const { return dh; }
|
||||
int getDw() const { return dw; }
|
||||
int getPh() const { return ph; }
|
||||
int getPw() const { return pw; }
|
||||
int getSh() const { return sh; }
|
||||
int getSw() const { return sw; }
|
||||
inline int getKh() const { return kh; }
|
||||
inline int getKw() const { return kw; }
|
||||
inline int getDh() const { return dh; }
|
||||
inline int getDw() const { return dw; }
|
||||
inline int getPh() const { return ph; }
|
||||
inline int getPw() const { return pw; }
|
||||
inline int getSh() const { return sh; }
|
||||
inline int getSw() const { return sw; }
|
||||
|
||||
auto getPadStrideDilation() const { return tuple(ph, pw, sh, sw, dh, dw); }
|
||||
auto getNCHWRS() const { return tuple(n, c, h, w, kh, kw); }
|
||||
inline auto getPadStrideDilation() const {
|
||||
return tuple(ph, pw, sh, sw, dh, dw);
|
||||
}
|
||||
inline auto getNCHWRS() const { return tuple(n, c, h, w, kh, kw); }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
|
|
|
@ -504,9 +504,32 @@ class OnnxStub:
|
|||
)
|
||||
)
|
||||
elif ty == backend.OpType.MaxPool:
|
||||
raise Exception("TODO")
|
||||
kh, kw, dh, dw, ph, pw, sh, sw = backend.pool_attrs_of(op)
|
||||
ctx.push_node(
|
||||
make_node(
|
||||
ty.name,
|
||||
inputs,
|
||||
outputs,
|
||||
name,
|
||||
kernel_shape=[kh, kw],
|
||||
pads=[ph, pw],
|
||||
dilations=[dh, dw],
|
||||
strides=[sh, sw],
|
||||
)
|
||||
)
|
||||
elif ty == backend.OpType.AvgPool:
|
||||
raise Exception("TODO")
|
||||
kh, kw, dh, dw, ph, pw, sh, sw = backend.pool_attrs_of(op)
|
||||
ctx.push_node(
|
||||
make_node(
|
||||
"AveragePool",
|
||||
inputs,
|
||||
outputs,
|
||||
name,
|
||||
kernel_shape=[kh, kw],
|
||||
pads=[ph, pw],
|
||||
strides=[sh, sw],
|
||||
)
|
||||
)
|
||||
elif ty in [
|
||||
backend.OpType.Add,
|
||||
backend.OpType.Sub,
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
#include "operators/concat.h"
|
||||
#include "operators/conv.h"
|
||||
#include "operators/gather.h"
|
||||
#include "operators/pooling.h"
|
||||
#include "operators/reduce_mean.h"
|
||||
#include "operators/reshape.h"
|
||||
#include <pybind11/stl.h>
|
||||
|
@ -115,6 +116,16 @@ static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) {
|
|||
batchnorm->getTraining());
|
||||
}
|
||||
|
||||
static std::tuple<int, int, int, int, int, int, int, int>
|
||||
pool_attrs_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::MaxPool ||
|
||||
op->getOpType() == OpType::AvgPool);
|
||||
auto pool = dynamic_cast<const PoolingObj *>(op.get());
|
||||
return std::make_tuple(pool->getKh(), pool->getKw(), pool->getDh(),
|
||||
pool->getDw(), pool->getPh(), pool->getPw(),
|
||||
pool->getSh(), pool->getSw());
|
||||
}
|
||||
|
||||
static int concat_axis_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Concat);
|
||||
return dynamic_cast<const ConcatObj *>(op.get())->getDim();
|
||||
|
@ -144,6 +155,7 @@ void export_functions(py::module &m) {
|
|||
#endif
|
||||
.FUNCTION(conv_attrs_of)
|
||||
.FUNCTION(batch_norm_attrs_of)
|
||||
.FUNCTION(pool_attrs_of)
|
||||
.FUNCTION(tensor_dtype)
|
||||
.FUNCTION(reshape_shape_of)
|
||||
.FUNCTION(concat_axis_of)
|
||||
|
|
Loading…
Reference in New Issue