feat: export pool to onnx

Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster 2023-02-27 09:32:08 +08:00
parent 8a871c3773
commit e294e46436
3 changed files with 51 additions and 14 deletions


@@ -39,20 +39,22 @@ class PoolingObj : public OperatorObj {
     optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
     std::string toString() const override;
-    int numInputs() const override { return 1; }
-    int numOutputs() const override { return 1; }
-    int getKh() const { return kh; }
-    int getKw() const { return kw; }
-    int getDh() const { return dh; }
-    int getDw() const { return dw; }
-    int getPh() const { return ph; }
-    int getPw() const { return pw; }
-    int getSh() const { return sh; }
-    int getSw() const { return sw; }
-    auto getPadStrideDilation() const { return tuple(ph, pw, sh, sw, dh, dw); }
-    auto getNCHWRS() const { return tuple(n, c, h, w, kh, kw); }
+    inline int numInputs() const override { return 1; }
+    inline int numOutputs() const override { return 1; }
+    inline int getKh() const { return kh; }
+    inline int getKw() const { return kw; }
+    inline int getDh() const { return dh; }
+    inline int getDw() const { return dw; }
+    inline int getPh() const { return ph; }
+    inline int getPw() const { return pw; }
+    inline int getSh() const { return sh; }
+    inline int getSw() const { return sw; }
+    inline auto getPadStrideDilation() const {
+        return tuple(ph, pw, sh, sw, dh, dw);
+    }
+    inline auto getNCHWRS() const { return tuple(n, c, h, w, kh, kw); }
   private:
     vector<int> getWorkloadVector() const override;


@@ -504,9 +504,32 @@ class OnnxStub:
                 )
             )
         elif ty == backend.OpType.MaxPool:
-            raise Exception("TODO")
+            kh, kw, dh, dw, ph, pw, sh, sw = backend.pool_attrs_of(op)
+            ctx.push_node(
+                make_node(
+                    ty.name,
+                    inputs,
+                    outputs,
+                    name,
+                    kernel_shape=[kh, kw],
+                    pads=[ph, pw],
+                    dilations=[dh, dw],
+                    strides=[sh, sw],
+                )
+            )
         elif ty == backend.OpType.AvgPool:
-            raise Exception("TODO")
+            kh, kw, dh, dw, ph, pw, sh, sw = backend.pool_attrs_of(op)
+            ctx.push_node(
+                make_node(
+                    "AveragePool",
+                    inputs,
+                    outputs,
+                    name,
+                    kernel_shape=[kh, kw],
+                    pads=[ph, pw],
+                    strides=[sh, sw],
+                )
+            )
         elif ty in [
             backend.OpType.Add,
             backend.OpType.Sub,
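
For reference, a minimal standalone sketch (not part of this commit) of the MaxPool node the branch above emits, rebuilt with onnx.helper and run through the ONNX checker. One caveat worth flagging: the ONNX spec defines pads as begin/end pairs (2*N values for N spatial axes), so a 2-D pool wants [ph, pw, ph, pw], while the branch above passes only [ph, pw]; the sketch below uses the 4-value spec form. All attribute values are made up for illustration.

from onnx import TensorProto, checker
from onnx.helper import (
    make_graph,
    make_model,
    make_node,
    make_tensor_value_info,
)

# Sample pooling attributes: 3x3 kernel, dilation 1, pad 1, stride 2.
kh, kw, dh, dw, ph, pw, sh, sw = 3, 3, 1, 1, 1, 1, 2, 2

node = make_node(
    "MaxPool",
    ["x"],
    ["y"],
    "maxpool_0",
    kernel_shape=[kh, kw],
    pads=[ph, pw, ph, pw],  # spec form: [h_begin, w_begin, h_end, w_end]
    dilations=[dh, dw],
    strides=[sh, sw],
)

# floor((32 + 2*1 - 3) / 2) + 1 = 16, so 1x3x32x32 -> 1x3x16x16.
graph = make_graph(
    [node],
    "pool_export_check",
    [make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 32, 32])],
    [make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 16, 16])],
)
checker.check_model(make_model(graph))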


@@ -3,6 +3,7 @@
 #include "operators/concat.h"
 #include "operators/conv.h"
 #include "operators/gather.h"
+#include "operators/pooling.h"
 #include "operators/reduce_mean.h"
 #include "operators/reshape.h"
 #include <pybind11/stl.h>
@@ -115,6 +116,16 @@ static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) {
                            batchnorm->getTraining());
 }
 
+static std::tuple<int, int, int, int, int, int, int, int>
+pool_attrs_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::MaxPool ||
+              op->getOpType() == OpType::AvgPool);
+    auto pool = dynamic_cast<const PoolingObj *>(op.get());
+    return std::make_tuple(pool->getKh(), pool->getKw(), pool->getDh(),
+                           pool->getDw(), pool->getPh(), pool->getPw(),
+                           pool->getSh(), pool->getSw());
+}
+
 static int concat_axis_of(Operator op) {
     IT_ASSERT(op->getOpType() == OpType::Concat);
     return dynamic_cast<const ConcatObj *>(op.get())->getDim();
@@ -144,6 +155,7 @@ void export_functions(py::module &m) {
 #endif
         .FUNCTION(conv_attrs_of)
         .FUNCTION(batch_norm_attrs_of)
+        .FUNCTION(pool_attrs_of)
         .FUNCTION(tensor_dtype)
         .FUNCTION(reshape_shape_of)
         .FUNCTION(concat_axis_of)
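
On the Python side the bound std::tuple arrives as a plain tuple, so call sites unpack it positionally; the order is fixed by the make_tuple call in pool_attrs_of: kernel, then dilation, then pad, then stride, h before w in each pair. A short usage sketch, with backend and op assumed to come from the project's pybind module as in the export code above:

# kh, kw (kernel), dh, dw (dilation), ph, pw (pad), sh, sw (stride)
kh, kw, dh, dw, ph, pw, sh, sw = backend.pool_attrs_of(op)
assert all(isinstance(v, int) for v in (kh, kw, dh, dw, ph, pw, sh, sw))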