feat: 导出 ReduceMean 到 ONNX

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-22 11:26:48 +08:00
parent bb9b62b169
commit 71a87c27d1
3 changed files with 29 additions and 11 deletions

View File

@ -30,6 +30,7 @@ class ReduceMeanObj : public OperatorObj {
int numOutputs() const override { return 1; }
bool isReduced(int idx) const;
const set<int> &getAxes() const { return axes; }
bool getKeepDims() const { return keepDims; }
private:

View File

@ -456,8 +456,8 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
name,
"shape",
TensorProto.INT32,
shape=[len(shape)],
vals=shape,
[len(shape)],
shape,
)
)
ctx.push_node(make_node(ty.name, inputs, outputs, name))
@ -468,7 +468,11 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
axis = backend.gather_axis_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
elif ty == backend.OpType.ReduceMean:
raise Exception("TODO")
axes = backend.reduce_mean_axes_of(op)
inputs.append(
ctx.push_data_input(name, "axes", TensorProto.INT32, [len(axes)], axes)
)
ctx.push_node(make_node(ty.name, inputs, outputs, name, keepdims=1))
elif ty == backend.OpType.Slice:
raise Exception("TODO")
elif ty == backend.OpType.Pad:

View File

@ -1,6 +1,7 @@
#include "core/graph_handler.h"
#include "operators/concat.h"
#include "operators/gather.h"
#include "operators/reduce_mean.h"
#include "operators/reshape.h"
#include <pybind11/stl.h>
@ -24,7 +25,7 @@ void register_operator_timer(py::module &m) {
#endif
}
void init_values(py::module &m) {
void export_values(py::module &m) {
#define VALUE(TYPE, NAME) value(#NAME, TYPE::NAME)
py::enum_<ActType>(m, "ActType")
@ -103,20 +104,31 @@ static int gather_axis_of(Operator op) {
return dynamic_cast<const GatherObj *>(op.get())->getAxis();
}
// Returns the reduction axes of a ReduceMean operator, converted from the
// operator's internal std::set<int> into a sorted vector (set iteration
// order is ascending), which is what the Python ONNX exporter consumes.
static vector<int> reduce_mean_axes_of(Operator op) {
    IT_ASSERT(op->getOpType() == OpType::ReduceMean);
    // Renamed from `set`: that name shadowed the unqualified `set` type
    // alias used elsewhere in this codebase.
    const auto *reduceMean = dynamic_cast<const ReduceMeanObj *>(op.get());
    IT_ASSERT(reduceMean != nullptr); // guard against a bad cast
    const auto &axes = reduceMean->getAxes();
    return vector<int>(axes.begin(), axes.end());
}
// Returns the target shape attribute of a Reshape operator for the exporter.
static Shape reshape_shape_of(Operator op) {
    IT_ASSERT(op->getOpType() == OpType::Reshape);
    const auto *reshape = dynamic_cast<const ReshapeObj *>(op.get());
    return reshape->getShape();
}
void init_graph_builder(py::module &m) {
void export_functions(py::module &m) {
#define FUNCTION(NAME) def(#NAME, &NAME)
m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
.FUNCTION(tensor_dtype)
.FUNCTION(reshape_shape_of)
.FUNCTION(concat_axis_of)
.FUNCTION(gather_axis_of)
.FUNCTION(reduce_mean_axes_of);
#undef FUNCTION
}
void init_graph_builder(py::module &m) {
using Handler = GraphHandlerObj;
m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
.def("tensor_dtype", &tensor_dtype)
.def("reshape_shape_of", &reshape_shape_of)
.def("concat_axis_of", &concat_axis_of)
.def("gather_axis_of", &gather_axis_of);
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
m, "CpuRuntime");
@ -166,6 +178,7 @@ void init_graph_builder(py::module &m) {
// Python extension entry point: assembles the `backend` module from the
// timer, enum-value, function, and graph-builder registration helpers.
PYBIND11_MODULE(backend, m) {
    infini::register_operator_timer(m);
    // `init_values` was renamed to `export_values`; the old call was a
    // stale duplicate referencing a symbol that no longer exists.
    infini::export_values(m);
    infini::export_functions(m);
    infini::init_graph_builder(m);
}