diff --git a/include/operators/pooling.h b/include/operators/pooling.h
index c14bb8ad..1a1a6edf 100644
--- a/include/operators/pooling.h
+++ b/include/operators/pooling.h
@@ -39,20 +39,22 @@ class PoolingObj : public OperatorObj {
     optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
 
     std::string toString() const override;
-    int numInputs() const override { return 1; }
-    int numOutputs() const override { return 1; }
+    inline int numInputs() const override { return 1; }
+    inline int numOutputs() const override { return 1; }
 
-    int getKh() const { return kh; }
-    int getKw() const { return kw; }
-    int getDh() const { return dh; }
-    int getDw() const { return dw; }
-    int getPh() const { return ph; }
-    int getPw() const { return pw; }
-    int getSh() const { return sh; }
-    int getSw() const { return sw; }
+    inline int getKh() const { return kh; }
+    inline int getKw() const { return kw; }
+    inline int getDh() const { return dh; }
+    inline int getDw() const { return dw; }
+    inline int getPh() const { return ph; }
+    inline int getPw() const { return pw; }
+    inline int getSh() const { return sh; }
+    inline int getSw() const { return sw; }
 
-    auto getPadStrideDilation() const { return tuple(ph, pw, sh, sw, dh, dw); }
-    auto getNCHWRS() const { return tuple(n, c, h, w, kh, kw); }
+    inline auto getPadStrideDilation() const {
+        return tuple(ph, pw, sh, sw, dh, dw);
+    }
+    inline auto getNCHWRS() const { return tuple(n, c, h, w, kh, kw); }
 
   private:
     vector<int> getWorkloadVector() const override;
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index 9af10f6e..0915b9bc 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -504,9 +504,32 @@ class OnnxStub:
                     )
                 )
             elif ty == backend.OpType.MaxPool:
-                raise Exception("TODO")
+                kh, kw, dh, dw, ph, pw, sh, sw = backend.pool_attrs_of(op)
+                ctx.push_node(
+                    make_node(
+                        ty.name,
+                        inputs,
+                        outputs,
+                        name,
+                        kernel_shape=[kh, kw],
+                        pads=[ph, pw, ph, pw],
+                        dilations=[dh, dw],
+                        strides=[sh, sw],
+                    )
+                )
             elif ty == backend.OpType.AvgPool:
-                raise Exception("TODO")
+                kh, kw, dh, dw, ph, pw, sh, sw = backend.pool_attrs_of(op)
+                ctx.push_node(
+                    make_node(
+                        "AveragePool",
+                        inputs,
+                        outputs,
+                        name,
+                        kernel_shape=[kh, kw],
+                        pads=[ph, pw, ph, pw],
+                        strides=[sh, sw],
+                    )
+                )
             elif ty in [
                 backend.OpType.Add,
                 backend.OpType.Sub,
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index 49164e07..c9547f1c 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -3,6 +3,7 @@
 #include "operators/concat.h"
 #include "operators/conv.h"
 #include "operators/gather.h"
+#include "operators/pooling.h"
 #include "operators/reduce_mean.h"
 #include "operators/reshape.h"
 #include <pybind11/stl.h>
@@ -115,6 +116,16 @@ static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) {
                            batchnorm->getTraining());
 }
 
+static std::tuple<int, int, int, int, int, int, int, int>
+pool_attrs_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::MaxPool ||
+              op->getOpType() == OpType::AvgPool);
+    auto pool = dynamic_cast<const PoolingObj *>(op.get());
+    return std::make_tuple(pool->getKh(), pool->getKw(), pool->getDh(),
+                           pool->getDw(), pool->getPh(), pool->getPw(),
+                           pool->getSh(), pool->getSw());
+}
+
 static int concat_axis_of(Operator op) {
     IT_ASSERT(op->getOpType() == OpType::Concat);
     return dynamic_cast<const ConcatObj *>(op.get())->getDim();
@@ -144,6 +155,7 @@ void export_functions(py::module &m) {
 #endif
     .FUNCTION(conv_attrs_of)
     .FUNCTION(batch_norm_attrs_of)
+    .FUNCTION(pool_attrs_of)
     .FUNCTION(tensor_dtype)
     .FUNCTION(reshape_shape_of)
     .FUNCTION(concat_axis_of)
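
Note on the exporter hunks above: pool_attrs_of returns the eight pooling
attributes in the fixed order kernel (kh, kw), dilation (dh, dw), padding
(ph, pw), stride (sh, sw), and ONNX spells pads out per spatial axis as
[h_begin, w_begin, h_end, w_end]. The sketch below replays that mapping
standalone so it can be checked without building the backend; the literal
attribute values and the tensor/node names are made up for illustration,
and only the tuple order comes from this patch.

    from onnx import checker
    from onnx.helper import make_node

    # Same 8-tuple layout that backend.pool_attrs_of(op) returns:
    # kernel, dilation, padding, stride (height first, then width).
    kh, kw, dh, dw, ph, pw, sh, sw = 3, 3, 1, 1, 1, 1, 2, 2

    node = make_node(
        "MaxPool",
        ["x"],            # hypothetical input tensor name
        ["y"],            # hypothetical output tensor name
        "maxpool_demo",   # hypothetical node name
        kernel_shape=[kh, kw],
        pads=[ph, pw, ph, pw],  # ONNX pads: [h_begin, w_begin, h_end, w_end]
        dilations=[dh, dw],
        strides=[sh, sw],
    )
    checker.check_node(node)  # validates the node against the MaxPool schema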