feat: export Gather and Concat to onnx

- also clean up the Python code

Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster 2023-02-21 13:10:54 +08:00
parent 9d9fbd44af
commit 6b7af7077b
2 changed files with 29 additions and 32 deletions


@@ -336,9 +336,10 @@ def to_onnx(graph: backend.GraphHandler):
             self.count_op[ty] += 1
             return ty, name

-        def push_output(self, name: str, tensor: backend.Tensor) -> None:
+        def push_output(self, name: str, tensor: backend.Tensor) -> str:
             self.names[tensor] = name
             # TODO detect whole-graph outputs and save them to outputs
+            return name

         def push_input(self, tensor: backend.Tensor) -> str:
             name = self.names.get(tensor)
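Returning the name from push_output is what enables the inline list building in the next hunk. A minimal sketch of the Context bookkeeping, reconstructed only from the pieces visible in this diff (the real class also counts ops per type via count_op and does more):

    # Sketch only: reconstructed from this diff, not the repo's full class.
    class Context:
        def __init__(self):
            self.names = {}    # tensor -> ONNX value name
            self.inputs = []   # graph-level inputs
            self.outputs = []  # graph-level outputs (see TODO above)
            self.nodes = []    # accumulated NodeProto objects

        def push_output(self, name: str, tensor) -> str:
            self.names[tensor] = name
            return name  # lets callers collect names in a comprehension

        def push_node(self, node) -> None:
            self.nodes.append(node)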
@@ -375,17 +376,17 @@ def to_onnx(graph: backend.GraphHandler):
     ops = graph.operators()  # all operators (nodes) in the graph

-    context = Context()
+    ctx = Context()

     for op in ops:
-        ty, name = context.name_op(op)
-        inputs = op.inputs()
-        outputs = op.outputs()
+        ty, name = ctx.name_op(op)
+        inputs = [ctx.push_input(it) for it in op.inputs()]
+        outputs = [
+            ctx.push_output("{}_{}".format(name, i), it)
+            for (i, it) in enumerate(op.outputs())
+        ]
         if ty == backend.OpType.Matmul:
-            context.push_output(name, outputs[0])
-            a = context.push_input(inputs[0])
-            b = context.push_input(inputs[1])
-            context.push_node(make_node("MatMul", [a, b], [name], name))
+            ctx.push_node(make_node("MatMul", inputs, outputs, name))
         elif ty == backend.OpType.BatchNorm:
             raise Exception("TODO")
         elif ty == backend.OpType.MaxPool:
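With inputs and outputs precomputed as name lists, each supported branch collapses to a single onnx.helper.make_node call. A standalone sketch of what the MatMul branch now emits, with made-up tensor names:

    from onnx.helper import make_node

    # make_node(op_type, inputs, outputs, name, **attrs) -> NodeProto.
    # "matmul_1" stands in for a name produced by Context.name_op.
    node = make_node("MatMul", ["x", "w"], ["matmul_1_0"], "matmul_1")
    print(node)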
@@ -398,12 +399,6 @@ def to_onnx(graph: backend.GraphHandler):
             backend.OpType.Mul,
             backend.OpType.Div,
             backend.OpType.Pow,
-        ]:
-            context.push_output(name, outputs[0])
-            a = context.push_input(inputs[0])
-            b = context.push_input(inputs[1])
-            context.push_node(make_node(ty.name, [a, b], [name], name))
-        elif ty in [
             backend.OpType.Relu,
             backend.OpType.Sigmoid,
             backend.OpType.Tanh,
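Folding the binary-operator list into the unary one works because arity now lives entirely in the precomputed inputs list, so one make_node(ty.name, inputs, outputs, name) call serves both. For illustration (hypothetical names):

    from onnx.helper import make_node

    add = make_node("Add", ["a", "b"], ["add_2_0"], "add_2")        # two inputs
    relu = make_node("Relu", ["add_2_0"], ["relu_3_0"], "relu_3")   # one input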
@@ -411,14 +406,11 @@ def to_onnx(graph: backend.GraphHandler):
             backend.OpType.Abs,
             backend.OpType.Identity,
         ]:
-            context.push_output(name, outputs[0])
-            x = context.push_input(inputs[0])
-            context.push_node(make_node(ty.name, [x], [name], name))
+            ctx.push_node(make_node(ty.name, inputs, outputs, name))
         elif ty == backend.OpType.Flatten:
             raise Exception("TODO")
         elif ty == backend.OpType.Reshape:
-            context.push_output(name, outputs[0])
-            data = context.push_input(inputs[0])
+            data = ctx.push_input(inputs[0])
             # shape = context.push_data_input(
             #     name,
             #     "shape",
@@ -429,13 +421,11 @@ def to_onnx(graph: backend.GraphHandler):
             # context.push_node(make_node(ty.name, [data, shape], [name], name))
             raise Exception("TODO")
         elif ty == backend.OpType.Concat:
-            context.push_output(name, outputs[0])
-            a = context.push_input(inputs[0])
-            b = context.push_input(inputs[1])
-            axis = backend.concat_dim_of(op)
-            context.push_node(make_node("Concat", [a, b], [name], name, axis=axis))
+            axis = backend.concat_axis_of(op)
+            ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
         elif ty == backend.OpType.Gather:
-            raise Exception("TODO")
+            axis = backend.gather_axis_of(op)
+            ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
         elif ty == backend.OpType.ReduceMean:
             raise Exception("TODO")
         elif ty == backend.OpType.Slice:
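The new Concat and Gather branches forward the axis fetched from the C++ side as an ONNX attribute; make_node passes extra keyword arguments through as node attributes. A self-contained sketch with illustrative names and axis values:

    from onnx.helper import make_node

    # axis would come from backend.concat_axis_of(op) or backend.gather_axis_of(op)
    concat = make_node("Concat", ["a", "b"], ["concat_4_0"], "concat_4", axis=1)
    gather = make_node("Gather", ["data", "indices"], ["gather_5_0"], "gather_5", axis=0)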
@@ -446,13 +436,13 @@ def to_onnx(graph: backend.GraphHandler):
             raise Exception("Unsupported OpType {}".format(ty.name))

     print()
-    print(context.names)
+    print(ctx.names)
     print()
-    print(context.inputs)
+    print(ctx.inputs)
     print()
-    print(context.outputs)
+    print(ctx.outputs)
     print()
-    print(context.nodes)
+    print(ctx.nodes)

 def parse_onnx(model: ModelProto):
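The prints are temporary debug scaffolding; what Context accumulates in nodes, inputs, and outputs is exactly what onnx.helper.make_graph consumes once the TODO in push_output is resolved. Not part of this commit, but a plausible follow-up sketch:

    from onnx.helper import make_graph, make_model

    # Assumes ctx.inputs/ctx.outputs hold ValueInfoProto objects
    # (e.g. built with make_tensor_value_info); "infinitensor" is a
    # hypothetical graph name.
    graph = make_graph(ctx.nodes, "infinitensor", ctx.inputs, ctx.outputs)
    model = make_model(graph)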


@@ -1,5 +1,6 @@
 #include "core/graph_handler.h"
 #include "operators/concat.h"
+#include "operators/gather.h"
 #include <pybind11/stl.h>

 #ifdef USE_CUDA
@@ -91,18 +92,24 @@ static int tensor_dtype(Tensor t) {
     IT_ASSERT(false, "Unsupported data type");
 }

-static int concat_dim_of(Operator op) {
+static int concat_axis_of(Operator op) {
     IT_ASSERT(op->getOpType() == OpType::Concat);
     return reinterpret_cast<const ConcatObj *>(op.get())->getDim();
 }

+static int gather_axis_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::Gather);
+    return reinterpret_cast<const GatherObj *>(op.get())->getAxis();
+}
+
 void init_graph_builder(py::module &m) {
     using Handler = GraphHandlerObj;

     m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
         .def("tensor_dtype", &tensor_dtype)
-        .def("concat_dim_of", &concat_dim_of);
+        .def("concat_axis_of", &concat_axis_of)
+        .def("gather_axis_of", &gather_axis_of);

     py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
     py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
         m, "CpuRuntime");