feat: 支持导出 membound

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-04-19 10:33:40 +08:00 committed by Liyan Zheng
parent 0edd138919
commit 725f9260cf
2 changed files with 23 additions and 1 deletions

View File

@ -512,6 +512,8 @@ class OnnxStub:
),
):
tensors[name] = tensor
elif node.op_type == "MemBound":
raise Exception('Unsupported operator "{}"'.format(node.op_type))
else:
raise Exception('Unsupported operator "{}"'.format(node.op_type))
@ -812,6 +814,17 @@ class OnnxStub:
ctx.push_data_input(name, "max", TensorProto.FLOAT, [], [])
)
ctx.push_node(make_node(ty.name, inputs, outputs, name))
elif ty == backend.OpType.MemBound:
ctx.push_node(
make_node(
ty.name,
inputs,
outputs,
name,
domain="nnet",
expr=backend.membound_expr_of(op),
)
)
else:
raise Exception("Unsupported OpType", ty)

View File

@ -1,12 +1,14 @@
#include "core/graph_handler.h"
#include "core/mutator.h"
#include "core/search_engine.h"
#include "nnet/Visitor/Serializer.h"
#include "nnet/nmutator.h"
#include "operators/batch_norm.h"
#include "operators/concat.h"
#include "operators/conv.h"
#include "operators/gather.h"
#include "operators/matmul.h"
#include "operators/membound.h"
#include "operators/pad.h"
#include "operators/pooling.h"
#include "operators/reduce_mean.h"
@ -229,6 +231,12 @@ static vector<int> transpose_permute_of(Operator op) {
return dynamic_cast<const TransposeObj *>(op.get())->getPermute();
}
static string membound_expr_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::MemBound);
return *nnet::Serializer().toString(
dynamic_cast<const MemBoundObj *>(op.get())->getNnetExpr());
}
void export_functions(py::module &m) {
#define FUNCTION(NAME) def(#NAME, &NAME)
m.def("cpu_runtime", &NativeCpuRuntimeObj::getInstance)
@ -257,7 +265,8 @@ void export_functions(py::module &m) {
.FUNCTION(transpose_permute_of)
.FUNCTION(concat_axis_of)
.FUNCTION(split_axis_of)
.FUNCTION(gather_axis_of);
.FUNCTION(gather_axis_of)
.FUNCTION(membound_expr_of);
#undef FUNCTION
}