forked from jiuyuan/InfiniTensor
feat: 导出 ReduceMean 到 onnx
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
bb9b62b169
commit
71a87c27d1
|
@ -30,6 +30,7 @@ class ReduceMeanObj : public OperatorObj {
|
|||
int numOutputs() const override { return 1; }
|
||||
|
||||
bool isReduced(int idx) const;
|
||||
const set<int> &getAxes() const { return axes; }
|
||||
bool getKeepDims() const { return keepDims; }
|
||||
|
||||
private:
|
||||
|
|
|
@ -456,8 +456,8 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
|
|||
name,
|
||||
"shape",
|
||||
TensorProto.INT32,
|
||||
shape=[len(shape)],
|
||||
vals=shape,
|
||||
[len(shape)],
|
||||
shape,
|
||||
)
|
||||
)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||
|
@ -468,7 +468,11 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
|
|||
axis = backend.gather_axis_of(op)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
||||
elif ty == backend.OpType.ReduceMean:
|
||||
raise Exception("TODO")
|
||||
axes = backend.reduce_mean_axes_of(op)
|
||||
inputs.append(
|
||||
ctx.push_data_input(name, "axes", TensorProto.INT32, [len(axes)], axes)
|
||||
)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, keepdims=1))
|
||||
elif ty == backend.OpType.Slice:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Pad:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#include "core/graph_handler.h"
|
||||
#include "operators/concat.h"
|
||||
#include "operators/gather.h"
|
||||
#include "operators/reduce_mean.h"
|
||||
#include "operators/reshape.h"
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
|
@ -24,7 +25,7 @@ void register_operator_timer(py::module &m) {
|
|||
#endif
|
||||
}
|
||||
|
||||
void init_values(py::module &m) {
|
||||
void export_values(py::module &m) {
|
||||
#define VALUE(TYPE, NAME) value(#NAME, TYPE::NAME)
|
||||
|
||||
py::enum_<ActType>(m, "ActType")
|
||||
|
@ -103,20 +104,31 @@ static int gather_axis_of(Operator op) {
|
|||
return dynamic_cast<const GatherObj *>(op.get())->getAxis();
|
||||
}
|
||||
|
||||
static vector<int> reduce_mean_axes_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::ReduceMean);
|
||||
auto &set = dynamic_cast<const ReduceMeanObj *>(op.get())->getAxes();
|
||||
return vector(set.begin(), set.end());
|
||||
}
|
||||
|
||||
static Shape reshape_shape_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Reshape);
|
||||
return dynamic_cast<const ReshapeObj *>(op.get())->getShape();
|
||||
}
|
||||
|
||||
void init_graph_builder(py::module &m) {
|
||||
void export_functions(py::module &m) {
|
||||
#define FUNCTION(NAME) def(#NAME, &NAME)
|
||||
m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
|
||||
.FUNCTION(tensor_dtype)
|
||||
.FUNCTION(reshape_shape_of)
|
||||
.FUNCTION(concat_axis_of)
|
||||
.FUNCTION(gather_axis_of)
|
||||
.FUNCTION(reduce_mean_axes_of);
|
||||
#undef FUNCTION
|
||||
}
|
||||
|
||||
void init_graph_builder(py::module &m) {
|
||||
using Handler = GraphHandlerObj;
|
||||
|
||||
m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
|
||||
.def("tensor_dtype", &tensor_dtype)
|
||||
.def("reshape_shape_of", &reshape_shape_of)
|
||||
.def("concat_axis_of", &concat_axis_of)
|
||||
.def("gather_axis_of", &gather_axis_of);
|
||||
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
|
||||
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
||||
m, "CpuRuntime");
|
||||
|
@ -166,6 +178,7 @@ void init_graph_builder(py::module &m) {
|
|||
|
||||
PYBIND11_MODULE(backend, m) {
|
||||
infini::register_operator_timer(m);
|
||||
infini::init_values(m);
|
||||
infini::export_values(m);
|
||||
infini::export_functions(m);
|
||||
infini::init_graph_builder(m);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue