feat: 导出 Gather Concat 到 onnx

- 并优化 python 代码

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-21 13:10:54 +08:00
parent 9d9fbd44af
commit 6b7af7077b
2 changed files with 29 additions and 32 deletions

View File

@ -336,9 +336,10 @@ def to_onnx(graph: backend.GraphHandler):
self.count_op[ty] += 1
return ty, name
def push_output(self, name: str, tensor: backend.Tensor) -> None:
def push_output(self, name: str, tensor: backend.Tensor) -> str:
self.names[tensor] = name
# TODO 需要判断全图输出并保存到 outputs
return name
def push_input(self, tensor: backend.Tensor) -> str:
name = self.names.get(tensor)
@ -375,17 +376,17 @@ def to_onnx(graph: backend.GraphHandler):
ops = graph.operators() # 图中所有算子(节点)
context = Context()
ctx = Context()
for op in ops:
ty, name = context.name_op(op)
inputs = op.inputs()
outputs = op.outputs()
ty, name = ctx.name_op(op)
inputs = [ctx.push_input(it) for it in op.inputs()]
outputs = [
ctx.push_output("{}_{}".format(name, i), it)
for (i, it) in enumerate(op.outputs())
]
if ty == backend.OpType.Matmul:
context.push_output(name, outputs[0])
a = context.push_input(inputs[0])
b = context.push_input(inputs[1])
context.push_node(make_node("MatMul", [a, b], [name], name))
ctx.push_node(make_node("MatMul", inputs, outputs, name))
elif ty == backend.OpType.BatchNorm:
raise Exception("TODO")
elif ty == backend.OpType.MaxPool:
@ -398,12 +399,6 @@ def to_onnx(graph: backend.GraphHandler):
backend.OpType.Mul,
backend.OpType.Div,
backend.OpType.Pow,
]:
context.push_output(name, outputs[0])
a = context.push_input(inputs[0])
b = context.push_input(inputs[1])
context.push_node(make_node(ty.name, [a, b], [name], name))
elif ty in [
backend.OpType.Relu,
backend.OpType.Sigmoid,
backend.OpType.Tanh,
@ -411,14 +406,11 @@ def to_onnx(graph: backend.GraphHandler):
backend.OpType.Abs,
backend.OpType.Identity,
]:
context.push_output(name, outputs[0])
x = context.push_input(inputs[0])
context.push_node(make_node(ty.name, [x], [name], name))
ctx.push_node(make_node(ty.name, inputs, outputs, name))
elif ty == backend.OpType.Flatten:
raise Exception("TODO")
elif ty == backend.OpType.Reshape:
context.push_output(name, outputs[0])
data = context.push_input(inputs[0])
data = ctx.push_input(inputs[0])
# shape = context.push_data_input(
# name,
# "shape",
@ -429,13 +421,11 @@ def to_onnx(graph: backend.GraphHandler):
# context.push_node(make_node(ty.name, [data, shape], [name], name))
raise Exception("TODO")
elif ty == backend.OpType.Concat:
context.push_output(name, outputs[0])
a = context.push_input(inputs[0])
b = context.push_input(inputs[1])
axis = backend.concat_dim_of(op)
context.push_node(make_node("Concat", [a, b], [name], name, axis=axis))
axis = backend.concat_axis_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
elif ty == backend.OpType.Gather:
raise Exception("TODO")
axis = backend.gather_axis_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
elif ty == backend.OpType.ReduceMean:
raise Exception("TODO")
elif ty == backend.OpType.Slice:
@ -446,13 +436,13 @@ def to_onnx(graph: backend.GraphHandler):
raise Exception("Unsupported OpType {}".format(ty.name))
print()
print(context.names)
print(ctx.names)
print()
print(context.inputs)
print(ctx.inputs)
print()
print(context.outputs)
print(ctx.outputs)
print()
print(context.nodes)
print(ctx.nodes)
def parse_onnx(model: ModelProto):

View File

@ -1,5 +1,6 @@
#include "core/graph_handler.h"
#include "operators/concat.h"
#include "operators/gather.h"
#include <pybind11/stl.h>
#ifdef USE_CUDA
@ -91,18 +92,24 @@ static int tensor_dtype(Tensor t) {
IT_ASSERT(false, "Unsupported data type");
}
static int concat_dim_of(Operator op) {
static int concat_axis_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::Concat);
return reinterpret_cast<const ConcatObj *>(op.get())->getDim();
}
static int gather_axis_of(Operator op) {
IT_ASSERT(op->getOpType() == OpType::Gather);
return reinterpret_cast<const GatherObj *>(op.get())->getAxis();
}
void init_graph_builder(py::module &m) {
using Handler = GraphHandlerObj;
m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
.def("tensor_dtype", &tensor_dtype)
.def("concat_dim_of", &concat_dim_of);
.def("concat_axis_of", &concat_axis_of)
.def("gather_axis_of", &gather_axis_of);
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
m, "CpuRuntime");