diff --git a/include/operators/reduce_mean.h b/include/operators/reduce_mean.h
index 76d3454b..ef74cd2e 100644
--- a/include/operators/reduce_mean.h
+++ b/include/operators/reduce_mean.h
@@ -30,6 +30,7 @@ class ReduceMeanObj : public OperatorObj {
     int numOutputs() const override { return 1; }
     bool isReduced(int idx) const;
+    const set<int> &getAxes() const { return axes; }
     bool getKeepDims() const { return keepDims; }
 
   private:
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index bb155a6c..bd0b926a 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -456,8 +456,8 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
                     name,
                     "shape",
                     TensorProto.INT32,
-                    shape=[len(shape)],
-                    vals=shape,
+                    [len(shape)],
+                    shape,
                 )
             )
             ctx.push_node(make_node(ty.name, inputs, outputs, name))
@@ -468,7 +468,11 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
             axis = backend.gather_axis_of(op)
             ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
         elif ty == backend.OpType.ReduceMean:
-            raise Exception("TODO")
+            axes = backend.reduce_mean_axes_of(op)
+            inputs.append(
+                ctx.push_data_input(name, "axes", TensorProto.INT32, [len(axes)], axes)
+            )
+            ctx.push_node(make_node(ty.name, inputs, outputs, name, keepdims=1))
         elif ty == backend.OpType.Slice:
             raise Exception("TODO")
         elif ty == backend.OpType.Pad:
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index c2d35051..09fa361c 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -1,6 +1,7 @@
 #include "core/graph_handler.h"
 #include "operators/concat.h"
 #include "operators/gather.h"
+#include "operators/reduce_mean.h"
 #include "operators/reshape.h"
 #include <pybind11/stl.h>
@@ -24,7 +25,7 @@ void register_operator_timer(py::module &m) {
 #endif
 }
 
-void init_values(py::module &m) {
+void export_values(py::module &m) {
 #define VALUE(TYPE, NAME) value(#NAME, TYPE::NAME)
     py::enum_<ActType>(m, "ActType")
@@ -103,20 +104,31 @@ static int gather_axis_of(Operator op) {
     return dynamic_cast<GatherObj *>(op.get())->getAxis();
 }
 
+static vector<int> reduce_mean_axes_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::ReduceMean);
+    auto &set = dynamic_cast<ReduceMeanObj *>(op.get())->getAxes();
+    return vector<int>(set.begin(), set.end());
+}
+
 static Shape reshape_shape_of(Operator op) {
     IT_ASSERT(op->getOpType() == OpType::Reshape);
     return dynamic_cast<ReshapeObj *>(op.get())->getShape();
 }
 
-void init_graph_builder(py::module &m) {
+void export_functions(py::module &m) {
+#define FUNCTION(NAME) def(#NAME, &NAME)
+    m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
+        .FUNCTION(tensor_dtype)
+        .FUNCTION(reshape_shape_of)
+        .FUNCTION(concat_axis_of)
+        .FUNCTION(gather_axis_of)
+        .FUNCTION(reduce_mean_axes_of);
+#undef FUNCTION
+}
+void init_graph_builder(py::module &m) {
     using Handler = GraphHandlerObj;
-    m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
-        .def("tensor_dtype", &tensor_dtype)
-        .def("reshape_shape_of", &reshape_shape_of)
-        .def("concat_axis_of", &concat_axis_of)
-        .def("gather_axis_of", &gather_axis_of);
     py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
     py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
         m, "CpuRuntime");
@@ -166,6 +178,7 @@ void init_graph_builder(py::module &m) {
 
 PYBIND11_MODULE(backend, m) {
     infini::register_operator_timer(m);
-    infini::init_values(m);
+    infini::export_values(m);
+    infini::export_functions(m);
     infini::init_graph_builder(m);
 }