forked from jiuyuan/InfiniTensor
feat: 导出 pool 到 onnx
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
8a871c3773
commit
e294e46436
|
@ -39,20 +39,22 @@ class PoolingObj : public OperatorObj {
|
||||||
|
|
||||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||||
std::string toString() const override;
|
std::string toString() const override;
|
||||||
int numInputs() const override { return 1; }
|
inline int numInputs() const override { return 1; }
|
||||||
int numOutputs() const override { return 1; }
|
inline int numOutputs() const override { return 1; }
|
||||||
|
|
||||||
int getKh() const { return kh; }
|
inline int getKh() const { return kh; }
|
||||||
int getKw() const { return kw; }
|
inline int getKw() const { return kw; }
|
||||||
int getDh() const { return dh; }
|
inline int getDh() const { return dh; }
|
||||||
int getDw() const { return dw; }
|
inline int getDw() const { return dw; }
|
||||||
int getPh() const { return ph; }
|
inline int getPh() const { return ph; }
|
||||||
int getPw() const { return pw; }
|
inline int getPw() const { return pw; }
|
||||||
int getSh() const { return sh; }
|
inline int getSh() const { return sh; }
|
||||||
int getSw() const { return sw; }
|
inline int getSw() const { return sw; }
|
||||||
|
|
||||||
auto getPadStrideDilation() const { return tuple(ph, pw, sh, sw, dh, dw); }
|
inline auto getPadStrideDilation() const {
|
||||||
auto getNCHWRS() const { return tuple(n, c, h, w, kh, kw); }
|
return tuple(ph, pw, sh, sw, dh, dw);
|
||||||
|
}
|
||||||
|
inline auto getNCHWRS() const { return tuple(n, c, h, w, kh, kw); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
vector<int> getWorkloadVector() const override;
|
vector<int> getWorkloadVector() const override;
|
||||||
|
|
|
@ -504,9 +504,32 @@ class OnnxStub:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif ty == backend.OpType.MaxPool:
|
elif ty == backend.OpType.MaxPool:
|
||||||
raise Exception("TODO")
|
kh, kw, dh, dw, ph, pw, sh, sw = backend.pool_attrs_of(op)
|
||||||
|
ctx.push_node(
|
||||||
|
make_node(
|
||||||
|
ty.name,
|
||||||
|
inputs,
|
||||||
|
outputs,
|
||||||
|
name,
|
||||||
|
kernel_shape=[kh, kw],
|
||||||
|
pads=[ph, pw],
|
||||||
|
dilations=[dh, dw],
|
||||||
|
strides=[sh, sw],
|
||||||
|
)
|
||||||
|
)
|
||||||
elif ty == backend.OpType.AvgPool:
|
elif ty == backend.OpType.AvgPool:
|
||||||
raise Exception("TODO")
|
kh, kw, dh, dw, ph, pw, sh, sw = backend.pool_attrs_of(op)
|
||||||
|
ctx.push_node(
|
||||||
|
make_node(
|
||||||
|
"AveragePool",
|
||||||
|
inputs,
|
||||||
|
outputs,
|
||||||
|
name,
|
||||||
|
kernel_shape=[kh, kw],
|
||||||
|
pads=[ph, pw],
|
||||||
|
strides=[sh, sw],
|
||||||
|
)
|
||||||
|
)
|
||||||
elif ty in [
|
elif ty in [
|
||||||
backend.OpType.Add,
|
backend.OpType.Add,
|
||||||
backend.OpType.Sub,
|
backend.OpType.Sub,
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
#include "operators/concat.h"
|
#include "operators/concat.h"
|
||||||
#include "operators/conv.h"
|
#include "operators/conv.h"
|
||||||
#include "operators/gather.h"
|
#include "operators/gather.h"
|
||||||
|
#include "operators/pooling.h"
|
||||||
#include "operators/reduce_mean.h"
|
#include "operators/reduce_mean.h"
|
||||||
#include "operators/reshape.h"
|
#include "operators/reshape.h"
|
||||||
#include <pybind11/stl.h>
|
#include <pybind11/stl.h>
|
||||||
|
@ -115,6 +116,16 @@ static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) {
|
||||||
batchnorm->getTraining());
|
batchnorm->getTraining());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static std::tuple<int, int, int, int, int, int, int, int>
|
||||||
|
pool_attrs_of(Operator op) {
|
||||||
|
IT_ASSERT(op->getOpType() == OpType::MaxPool ||
|
||||||
|
op->getOpType() == OpType::AvgPool);
|
||||||
|
auto pool = dynamic_cast<const PoolingObj *>(op.get());
|
||||||
|
return std::make_tuple(pool->getKh(), pool->getKw(), pool->getDh(),
|
||||||
|
pool->getDw(), pool->getPh(), pool->getPw(),
|
||||||
|
pool->getSh(), pool->getSw());
|
||||||
|
}
|
||||||
|
|
||||||
static int concat_axis_of(Operator op) {
|
static int concat_axis_of(Operator op) {
|
||||||
IT_ASSERT(op->getOpType() == OpType::Concat);
|
IT_ASSERT(op->getOpType() == OpType::Concat);
|
||||||
return dynamic_cast<const ConcatObj *>(op.get())->getDim();
|
return dynamic_cast<const ConcatObj *>(op.get())->getDim();
|
||||||
|
@ -144,6 +155,7 @@ void export_functions(py::module &m) {
|
||||||
#endif
|
#endif
|
||||||
.FUNCTION(conv_attrs_of)
|
.FUNCTION(conv_attrs_of)
|
||||||
.FUNCTION(batch_norm_attrs_of)
|
.FUNCTION(batch_norm_attrs_of)
|
||||||
|
.FUNCTION(pool_attrs_of)
|
||||||
.FUNCTION(tensor_dtype)
|
.FUNCTION(tensor_dtype)
|
||||||
.FUNCTION(reshape_shape_of)
|
.FUNCTION(reshape_shape_of)
|
||||||
.FUNCTION(concat_axis_of)
|
.FUNCTION(concat_axis_of)
|
||||||
|
|
Loading…
Reference in New Issue