forked from jiuyuan/InfiniTensor
feat: 支持导出 membound
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
0edd138919
commit
725f9260cf
|
@ -512,6 +512,8 @@ class OnnxStub:
|
|||
),
|
||||
):
|
||||
tensors[name] = tensor
|
||||
elif node.op_type == "MemBound":
|
||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||
else:
|
||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||
|
||||
|
@ -812,6 +814,17 @@ class OnnxStub:
|
|||
ctx.push_data_input(name, "max", TensorProto.FLOAT, [], [])
|
||||
)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||
elif ty == backend.OpType.MemBound:
|
||||
ctx.push_node(
|
||||
make_node(
|
||||
ty.name,
|
||||
inputs,
|
||||
outputs,
|
||||
name,
|
||||
domain="nnet",
|
||||
expr=backend.membound_expr_of(op),
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise Exception("Unsupported OpType", ty)
|
||||
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
#include "core/graph_handler.h"
|
||||
#include "core/mutator.h"
|
||||
#include "core/search_engine.h"
|
||||
#include "nnet/Visitor/Serializer.h"
|
||||
#include "nnet/nmutator.h"
|
||||
#include "operators/batch_norm.h"
|
||||
#include "operators/concat.h"
|
||||
#include "operators/conv.h"
|
||||
#include "operators/gather.h"
|
||||
#include "operators/matmul.h"
|
||||
#include "operators/membound.h"
|
||||
#include "operators/pad.h"
|
||||
#include "operators/pooling.h"
|
||||
#include "operators/reduce_mean.h"
|
||||
|
@ -229,6 +231,12 @@ static vector<int> transpose_permute_of(Operator op) {
|
|||
return dynamic_cast<const TransposeObj *>(op.get())->getPermute();
|
||||
}
|
||||
|
||||
static string membound_expr_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::MemBound);
|
||||
return *nnet::Serializer().toString(
|
||||
dynamic_cast<const MemBoundObj *>(op.get())->getNnetExpr());
|
||||
}
|
||||
|
||||
void export_functions(py::module &m) {
|
||||
#define FUNCTION(NAME) def(#NAME, &NAME)
|
||||
m.def("cpu_runtime", &NativeCpuRuntimeObj::getInstance)
|
||||
|
@ -257,7 +265,8 @@ void export_functions(py::module &m) {
|
|||
.FUNCTION(transpose_permute_of)
|
||||
.FUNCTION(concat_axis_of)
|
||||
.FUNCTION(split_axis_of)
|
||||
.FUNCTION(gather_axis_of);
|
||||
.FUNCTION(gather_axis_of)
|
||||
.FUNCTION(membound_expr_of);
|
||||
#undef FUNCTION
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue