forked from jiuyuan/InfiniTensor
feat: 封装上下文对象以复用建图代码
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
0833a2f779
commit
eff4c14a85
|
@ -2,7 +2,7 @@
|
|||
from onnx import ModelProto, TensorProto, NodeProto, AttributeProto, TensorShapeProto
|
||||
from onnx.helper import make_node
|
||||
from onnx.shape_inference import infer_shapes
|
||||
from typing import Dict, List, Any
|
||||
from typing import Dict, List, Any, Tuple
|
||||
from functools import reduce
|
||||
|
||||
runtime = backend.cpu_runtime()
|
||||
|
@ -306,23 +306,46 @@ def from_onnx(model: ModelProto) -> backend.GraphHandler:
|
|||
|
||||
|
||||
def to_onnx(graph: backend.GraphHandler):
|
||||
class Context:
|
||||
names: Dict[Any, str] = dict() # 记录所有对象的名字,包括张量和算子
|
||||
nodes: List[NodeProto] = [] # 保存所有算子
|
||||
count_op: Dict[backend.OpType, int] = dict() # 统计每个算子出现次数,用于命名
|
||||
count_in = 0 # 统计输入张量数量,用于命名
|
||||
count_out = 0 # 统计输出张量数量,用于命名
|
||||
|
||||
def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:
|
||||
ty = op.op_type()
|
||||
name = "{}{}".format(ty.name, self.count_op.setdefault(ty, 0) + 1)
|
||||
self.names[op] = name
|
||||
self.count_op[ty] += 1
|
||||
return ty, name
|
||||
|
||||
def push_output(self, name: str, tensor: backend.Tensor) -> None:
|
||||
self.names[tensor] = name
|
||||
|
||||
def push_input(self, tensor: backend.Tensor) -> str:
|
||||
name = self.names.get(tensor)
|
||||
if name is None:
|
||||
self.count_in += 1
|
||||
name = "input{}".format(self.count_in)
|
||||
self.names[tensor] = name
|
||||
return name
|
||||
|
||||
def push_node(self, node: NodeProto) -> None:
|
||||
self.nodes.append(node)
|
||||
|
||||
# 拓扑排序
|
||||
if not graph.topo_sort():
|
||||
raise Exception("Sorting fails")
|
||||
|
||||
ops = graph.operators()
|
||||
ops = graph.operators() # 图中所有算子(节点)
|
||||
|
||||
names: Dict[Any, str] = dict()
|
||||
nodes: List[NodeProto] = []
|
||||
count_op: Dict[backend.OpType, int] = dict()
|
||||
count_in = 0
|
||||
context = Context()
|
||||
|
||||
for op in ops:
|
||||
ty = op.op_type()
|
||||
name = "{}{}".format(ty.name, count_op.setdefault(ty, 0) + 1)
|
||||
ty, name = context.name_op(op)
|
||||
inputs = op.inputs()
|
||||
outputs = op.outputs()
|
||||
names[op] = name
|
||||
count_op[ty] += 1
|
||||
if ty == backend.OpType.Matmul:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.BatchNorm:
|
||||
|
@ -331,71 +354,17 @@ def to_onnx(graph: backend.GraphHandler):
|
|||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.AvgPool:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Add:
|
||||
names[outputs[0]] = name
|
||||
if inputs[0] in names:
|
||||
a = names[inputs[0]]
|
||||
else:
|
||||
count_in += 1
|
||||
a = "input{}".format(count_in)
|
||||
if inputs[1] in names:
|
||||
b = names[inputs[1]]
|
||||
else:
|
||||
count_in += 1
|
||||
b = "input{}".format(count_in)
|
||||
nodes.append(make_node("Add", [a, b], [name], name))
|
||||
elif ty == backend.OpType.Sub:
|
||||
names[outputs[0]] = name
|
||||
if inputs[0] in names:
|
||||
a = names[inputs[0]]
|
||||
else:
|
||||
count_in += 1
|
||||
a = "input{}".format(count_in)
|
||||
if inputs[1] in names:
|
||||
b = names[inputs[1]]
|
||||
else:
|
||||
count_in += 1
|
||||
b = "input{}".format(count_in)
|
||||
nodes.append(make_node("Sub", [a, b], [name], name))
|
||||
elif ty == backend.OpType.Mul:
|
||||
names[outputs[0]] = name
|
||||
if inputs[0] in names:
|
||||
a = names[inputs[0]]
|
||||
else:
|
||||
count_in += 1
|
||||
a = "input{}".format(count_in)
|
||||
if inputs[1] in names:
|
||||
b = names[inputs[1]]
|
||||
else:
|
||||
count_in += 1
|
||||
b = "input{}".format(count_in)
|
||||
nodes.append(make_node("Mul", [a, b], [name], name))
|
||||
elif ty == backend.OpType.Div:
|
||||
names[outputs[0]] = name
|
||||
if inputs[0] in names:
|
||||
a = names[inputs[0]]
|
||||
else:
|
||||
count_in += 1
|
||||
a = "input{}".format(count_in)
|
||||
if inputs[1] in names:
|
||||
b = names[inputs[1]]
|
||||
else:
|
||||
count_in += 1
|
||||
b = "input{}".format(count_in)
|
||||
nodes.append(make_node("Div", [a, b], [name], name))
|
||||
elif ty == backend.OpType.Pow:
|
||||
names[outputs[0]] = name
|
||||
if inputs[0] in names:
|
||||
a = names[inputs[0]]
|
||||
else:
|
||||
count_in += 1
|
||||
a = "input{}".format(count_in)
|
||||
if inputs[1] in names:
|
||||
b = names[inputs[1]]
|
||||
else:
|
||||
count_in += 1
|
||||
b = "input{}".format(count_in)
|
||||
nodes.append(make_node("Pow", [a, b], [name], name))
|
||||
elif ty in [
|
||||
backend.OpType.Add,
|
||||
backend.OpType.Sub,
|
||||
backend.OpType.Mul,
|
||||
backend.OpType.Div,
|
||||
backend.OpType.Pow,
|
||||
]:
|
||||
context.push_output(name, outputs[0])
|
||||
a = context.push_input(inputs[0])
|
||||
b = context.push_input(inputs[1])
|
||||
context.push_node(make_node(ty.name, [a, b], [name], name))
|
||||
elif ty == backend.OpType.Relu:
|
||||
raise Exception("TODO")
|
||||
elif ty == backend.OpType.Sigmoid:
|
||||
|
@ -425,7 +394,7 @@ def to_onnx(graph: backend.GraphHandler):
|
|||
else:
|
||||
raise Exception("Unsupported OpType {}".format(ty.name))
|
||||
|
||||
print(names)
|
||||
print(context.names)
|
||||
|
||||
|
||||
def parse_onnx(model: ModelProto):
|
||||
|
|
|
@ -77,7 +77,7 @@ void init_graph_builder(py::module &m) {
|
|||
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
|
||||
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
||||
m, "CpuRuntime");
|
||||
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj")
|
||||
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
|
||||
.def("src", &TensorObj::getOutputOf, policy::move);
|
||||
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
|
||||
.def("op_type", &OperatorObj::getOpType, policy::automatic)
|
||||
|
|
Loading…
Reference in New Issue