feat: 封装上下文对象以复用建图代码

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-20 12:12:58 +08:00
parent 0833a2f779
commit eff4c14a85
2 changed files with 46 additions and 77 deletions

View File

@ -2,7 +2,7 @@
from onnx import ModelProto, TensorProto, NodeProto, AttributeProto, TensorShapeProto
from onnx.helper import make_node
from onnx.shape_inference import infer_shapes
from typing import Dict, List, Any
from typing import Dict, List, Any, Tuple
from functools import reduce
runtime = backend.cpu_runtime()
@ -306,23 +306,46 @@ def from_onnx(model: ModelProto) -> backend.GraphHandler:
def to_onnx(graph: backend.GraphHandler):
class Context:
names: Dict[Any, str] = dict() # 记录所有对象的名字,包括张量和算子
nodes: List[NodeProto] = [] # 保存所有算子
count_op: Dict[backend.OpType, int] = dict() # 统计每个算子出现次数,用于命名
count_in = 0 # 统计输入张量数量,用于命名
count_out = 0 # 统计输出张量数量,用于命名
def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:
ty = op.op_type()
name = "{}{}".format(ty.name, self.count_op.setdefault(ty, 0) + 1)
self.names[op] = name
self.count_op[ty] += 1
return ty, name
def push_output(self, name: str, tensor: backend.Tensor) -> None:
self.names[tensor] = name
def push_input(self, tensor: backend.Tensor) -> str:
name = self.names.get(tensor)
if name is None:
self.count_in += 1
name = "input{}".format(self.count_in)
self.names[tensor] = name
return name
def push_node(self, node: NodeProto) -> None:
self.nodes.append(node)
# 拓扑排序
if not graph.topo_sort():
raise Exception("Sorting fails")
ops = graph.operators()
ops = graph.operators() # 图中所有算子(节点)
names: Dict[Any, str] = dict()
nodes: List[NodeProto] = []
count_op: Dict[backend.OpType, int] = dict()
count_in = 0
context = Context()
for op in ops:
ty = op.op_type()
name = "{}{}".format(ty.name, count_op.setdefault(ty, 0) + 1)
ty, name = context.name_op(op)
inputs = op.inputs()
outputs = op.outputs()
names[op] = name
count_op[ty] += 1
if ty == backend.OpType.Matmul:
raise Exception("TODO")
elif ty == backend.OpType.BatchNorm:
@ -331,71 +354,17 @@ def to_onnx(graph: backend.GraphHandler):
raise Exception("TODO")
elif ty == backend.OpType.AvgPool:
raise Exception("TODO")
elif ty == backend.OpType.Add:
names[outputs[0]] = name
if inputs[0] in names:
a = names[inputs[0]]
else:
count_in += 1
a = "input{}".format(count_in)
if inputs[1] in names:
b = names[inputs[1]]
else:
count_in += 1
b = "input{}".format(count_in)
nodes.append(make_node("Add", [a, b], [name], name))
elif ty == backend.OpType.Sub:
names[outputs[0]] = name
if inputs[0] in names:
a = names[inputs[0]]
else:
count_in += 1
a = "input{}".format(count_in)
if inputs[1] in names:
b = names[inputs[1]]
else:
count_in += 1
b = "input{}".format(count_in)
nodes.append(make_node("Sub", [a, b], [name], name))
elif ty == backend.OpType.Mul:
names[outputs[0]] = name
if inputs[0] in names:
a = names[inputs[0]]
else:
count_in += 1
a = "input{}".format(count_in)
if inputs[1] in names:
b = names[inputs[1]]
else:
count_in += 1
b = "input{}".format(count_in)
nodes.append(make_node("Mul", [a, b], [name], name))
elif ty == backend.OpType.Div:
names[outputs[0]] = name
if inputs[0] in names:
a = names[inputs[0]]
else:
count_in += 1
a = "input{}".format(count_in)
if inputs[1] in names:
b = names[inputs[1]]
else:
count_in += 1
b = "input{}".format(count_in)
nodes.append(make_node("Div", [a, b], [name], name))
elif ty == backend.OpType.Pow:
names[outputs[0]] = name
if inputs[0] in names:
a = names[inputs[0]]
else:
count_in += 1
a = "input{}".format(count_in)
if inputs[1] in names:
b = names[inputs[1]]
else:
count_in += 1
b = "input{}".format(count_in)
nodes.append(make_node("Pow", [a, b], [name], name))
elif ty in [
backend.OpType.Add,
backend.OpType.Sub,
backend.OpType.Mul,
backend.OpType.Div,
backend.OpType.Pow,
]:
context.push_output(name, outputs[0])
a = context.push_input(inputs[0])
b = context.push_input(inputs[1])
context.push_node(make_node(ty.name, [a, b], [name], name))
elif ty == backend.OpType.Relu:
raise Exception("TODO")
elif ty == backend.OpType.Sigmoid:
@ -425,7 +394,7 @@ def to_onnx(graph: backend.GraphHandler):
else:
raise Exception("Unsupported OpType {}".format(ty.name))
print(names)
print(context.names)
def parse_onnx(model: ModelProto):

View File

@ -77,7 +77,7 @@ void init_graph_builder(py::module &m) {
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
m, "CpuRuntime");
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj")
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
.def("src", &TensorObj::getOutputOf, policy::move);
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
.def("op_type", &OperatorObj::getOpType, policy::automatic)