forked from jiuyuan/InfiniTensor
feat: 导出 Gather Concat 到 onnx
- 并优化 python 代码 Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
9d9fbd44af
commit
6b7af7077b
|
@ -336,9 +336,10 @@ def to_onnx(graph: backend.GraphHandler):
|
|||
self.count_op[ty] += 1
|
||||
return ty, name
|
||||
|
||||
def push_output(self, name: str, tensor: backend.Tensor) -> None:
|
||||
def push_output(self, name: str, tensor: backend.Tensor) -> str:
|
||||
self.names[tensor] = name
|
||||
# TODO 需要判断全图输出并保存到 outputs
|
||||
return name
|
||||
|
||||
def push_input(self, tensor: backend.Tensor) -> str:
|
||||
name = self.names.get(tensor)
|
||||
|
@ -375,17 +376,17 @@ def to_onnx(graph: backend.GraphHandler):
|
|||
|
||||
ops = graph.operators() # 图中所有算子(节点)
|
||||
|
||||
context = Context()
|
||||
ctx = Context()
|
||||
|
||||
for op in ops:
|
||||
ty, name = context.name_op(op)
|
||||
inputs = op.inputs()
|
||||
outputs = op.outputs()
|
||||
ty, name = ctx.name_op(op)
|
||||
inputs = [ctx.push_input(it) for it in op.inputs()]
|
||||
outputs = [
|
||||
ctx.push_output("{}_{}".format(name, i), it)
|
||||
for (i, it) in enumerate(op.outputs())
|
||||
]
|
||||
if ty == backend.OpType.Matmul:
|
||||
context.push_output(name, outputs[0])
|
||||
a = context.push_input(inputs[0])
|
||||
b = context.push_input(inputs[1])
|
||||
context.push_node(make_node("MatMul", [a, b], [name], name))
|
||||
ctx.push_node(make_node("MatMul", inputs, outputs, name))
|
||||
elif ty == backend.OpType.BatchNorm:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.MaxPool:
|
||||
|
@ -398,12 +399,6 @@ def to_onnx(graph: backend.GraphHandler):
|
|||
backend.OpType.Mul,
|
||||
backend.OpType.Div,
|
||||
backend.OpType.Pow,
|
||||
]:
|
||||
context.push_output(name, outputs[0])
|
||||
a = context.push_input(inputs[0])
|
||||
b = context.push_input(inputs[1])
|
||||
context.push_node(make_node(ty.name, [a, b], [name], name))
|
||||
elif ty in [
|
||||
backend.OpType.Relu,
|
||||
backend.OpType.Sigmoid,
|
||||
backend.OpType.Tanh,
|
||||
|
@ -411,14 +406,11 @@ def to_onnx(graph: backend.GraphHandler):
|
|||
backend.OpType.Abs,
|
||||
backend.OpType.Identity,
|
||||
]:
|
||||
context.push_output(name, outputs[0])
|
||||
x = context.push_input(inputs[0])
|
||||
context.push_node(make_node(ty.name, [x], [name], name))
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||
elif ty == backend.OpType.Flatten:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Reshape:
|
||||
context.push_output(name, outputs[0])
|
||||
data = context.push_input(inputs[0])
|
||||
data = ctx.push_input(inputs[0])
|
||||
# shape = context.push_data_input(
|
||||
# name,
|
||||
# "shape",
|
||||
|
@ -429,13 +421,11 @@ def to_onnx(graph: backend.GraphHandler):
|
|||
# context.push_node(make_node(ty.name, [data, shape], [name], name))
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Concat:
|
||||
context.push_output(name, outputs[0])
|
||||
a = context.push_input(inputs[0])
|
||||
b = context.push_input(inputs[1])
|
||||
axis = backend.concat_dim_of(op)
|
||||
context.push_node(make_node("Concat", [a, b], [name], name, axis=axis))
|
||||
axis = backend.concat_axis_of(op)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
||||
elif ty == backend.OpType.Gather:
|
||||
raise Exception("TODO")
|
||||
axis = backend.gather_axis_of(op)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
|
||||
elif ty == backend.OpType.ReduceMean:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Slice:
|
||||
|
@ -446,13 +436,13 @@ def to_onnx(graph: backend.GraphHandler):
|
|||
raise Exception("Unsupported OpType {}".format(ty.name))
|
||||
|
||||
print()
|
||||
print(context.names)
|
||||
print(ctx.names)
|
||||
print()
|
||||
print(context.inputs)
|
||||
print(ctx.inputs)
|
||||
print()
|
||||
print(context.outputs)
|
||||
print(ctx.outputs)
|
||||
print()
|
||||
print(context.nodes)
|
||||
print(ctx.nodes)
|
||||
|
||||
|
||||
def parse_onnx(model: ModelProto):
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#include "core/graph_handler.h"
|
||||
#include "operators/concat.h"
|
||||
#include "operators/gather.h"
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#ifdef USE_CUDA
|
||||
|
@ -91,18 +92,24 @@ static int tensor_dtype(Tensor t) {
|
|||
IT_ASSERT(false, "Unsupported data type");
|
||||
}
|
||||
|
||||
static int concat_dim_of(Operator op) {
|
||||
static int concat_axis_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Concat);
|
||||
return reinterpret_cast<const ConcatObj *>(op.get())->getDim();
|
||||
}
|
||||
|
||||
static int gather_axis_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Gather);
|
||||
return reinterpret_cast<const GatherObj *>(op.get())->getAxis();
|
||||
}
|
||||
|
||||
void init_graph_builder(py::module &m) {
|
||||
|
||||
using Handler = GraphHandlerObj;
|
||||
|
||||
m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
|
||||
.def("tensor_dtype", &tensor_dtype)
|
||||
.def("concat_dim_of", &concat_dim_of);
|
||||
.def("concat_axis_of", &concat_axis_of)
|
||||
.def("gather_axis_of", &gather_axis_of);
|
||||
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
|
||||
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
||||
m, "CpuRuntime");
|
||||
|
|
Loading…
Reference in New Issue