forked from jiuyuan/InfiniTensor
feat: 封装上下文对象以复用建图代码
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
0833a2f779
commit
eff4c14a85
|
@ -2,7 +2,7 @@
|
||||||
from onnx import ModelProto, TensorProto, NodeProto, AttributeProto, TensorShapeProto
|
from onnx import ModelProto, TensorProto, NodeProto, AttributeProto, TensorShapeProto
|
||||||
from onnx.helper import make_node
|
from onnx.helper import make_node
|
||||||
from onnx.shape_inference import infer_shapes
|
from onnx.shape_inference import infer_shapes
|
||||||
from typing import Dict, List, Any
|
from typing import Dict, List, Any, Tuple
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
|
||||||
runtime = backend.cpu_runtime()
|
runtime = backend.cpu_runtime()
|
||||||
|
@ -306,23 +306,46 @@ def from_onnx(model: ModelProto) -> backend.GraphHandler:
|
||||||
|
|
||||||
|
|
||||||
def to_onnx(graph: backend.GraphHandler):
|
def to_onnx(graph: backend.GraphHandler):
|
||||||
|
class Context:
|
||||||
|
names: Dict[Any, str] = dict() # 记录所有对象的名字,包括张量和算子
|
||||||
|
nodes: List[NodeProto] = [] # 保存所有算子
|
||||||
|
count_op: Dict[backend.OpType, int] = dict() # 统计每个算子出现次数,用于命名
|
||||||
|
count_in = 0 # 统计输入张量数量,用于命名
|
||||||
|
count_out = 0 # 统计输出张量数量,用于命名
|
||||||
|
|
||||||
|
def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:
|
||||||
|
ty = op.op_type()
|
||||||
|
name = "{}{}".format(ty.name, self.count_op.setdefault(ty, 0) + 1)
|
||||||
|
self.names[op] = name
|
||||||
|
self.count_op[ty] += 1
|
||||||
|
return ty, name
|
||||||
|
|
||||||
|
def push_output(self, name: str, tensor: backend.Tensor) -> None:
|
||||||
|
self.names[tensor] = name
|
||||||
|
|
||||||
|
def push_input(self, tensor: backend.Tensor) -> str:
|
||||||
|
name = self.names.get(tensor)
|
||||||
|
if name is None:
|
||||||
|
self.count_in += 1
|
||||||
|
name = "input{}".format(self.count_in)
|
||||||
|
self.names[tensor] = name
|
||||||
|
return name
|
||||||
|
|
||||||
|
def push_node(self, node: NodeProto) -> None:
|
||||||
|
self.nodes.append(node)
|
||||||
|
|
||||||
|
# 拓扑排序
|
||||||
if not graph.topo_sort():
|
if not graph.topo_sort():
|
||||||
raise Exception("Sorting fails")
|
raise Exception("Sorting fails")
|
||||||
|
|
||||||
ops = graph.operators()
|
ops = graph.operators() # 图中所有算子(节点)
|
||||||
|
|
||||||
names: Dict[Any, str] = dict()
|
context = Context()
|
||||||
nodes: List[NodeProto] = []
|
|
||||||
count_op: Dict[backend.OpType, int] = dict()
|
|
||||||
count_in = 0
|
|
||||||
|
|
||||||
for op in ops:
|
for op in ops:
|
||||||
ty = op.op_type()
|
ty, name = context.name_op(op)
|
||||||
name = "{}{}".format(ty.name, count_op.setdefault(ty, 0) + 1)
|
|
||||||
inputs = op.inputs()
|
inputs = op.inputs()
|
||||||
outputs = op.outputs()
|
outputs = op.outputs()
|
||||||
names[op] = name
|
|
||||||
count_op[ty] += 1
|
|
||||||
if ty == backend.OpType.Matmul:
|
if ty == backend.OpType.Matmul:
|
||||||
raise Exception("TODO")
|
raise Exception("TODO")
|
||||||
elif ty == backend.OpType.BatchNorm:
|
elif ty == backend.OpType.BatchNorm:
|
||||||
|
@ -331,71 +354,17 @@ def to_onnx(graph: backend.GraphHandler):
|
||||||
raise Exception("TODO")
|
raise Exception("TODO")
|
||||||
elif ty == backend.OpType.AvgPool:
|
elif ty == backend.OpType.AvgPool:
|
||||||
raise Exception("TODO")
|
raise Exception("TODO")
|
||||||
elif ty == backend.OpType.Add:
|
elif ty in [
|
||||||
names[outputs[0]] = name
|
backend.OpType.Add,
|
||||||
if inputs[0] in names:
|
backend.OpType.Sub,
|
||||||
a = names[inputs[0]]
|
backend.OpType.Mul,
|
||||||
else:
|
backend.OpType.Div,
|
||||||
count_in += 1
|
backend.OpType.Pow,
|
||||||
a = "input{}".format(count_in)
|
]:
|
||||||
if inputs[1] in names:
|
context.push_output(name, outputs[0])
|
||||||
b = names[inputs[1]]
|
a = context.push_input(inputs[0])
|
||||||
else:
|
b = context.push_input(inputs[1])
|
||||||
count_in += 1
|
context.push_node(make_node(ty.name, [a, b], [name], name))
|
||||||
b = "input{}".format(count_in)
|
|
||||||
nodes.append(make_node("Add", [a, b], [name], name))
|
|
||||||
elif ty == backend.OpType.Sub:
|
|
||||||
names[outputs[0]] = name
|
|
||||||
if inputs[0] in names:
|
|
||||||
a = names[inputs[0]]
|
|
||||||
else:
|
|
||||||
count_in += 1
|
|
||||||
a = "input{}".format(count_in)
|
|
||||||
if inputs[1] in names:
|
|
||||||
b = names[inputs[1]]
|
|
||||||
else:
|
|
||||||
count_in += 1
|
|
||||||
b = "input{}".format(count_in)
|
|
||||||
nodes.append(make_node("Sub", [a, b], [name], name))
|
|
||||||
elif ty == backend.OpType.Mul:
|
|
||||||
names[outputs[0]] = name
|
|
||||||
if inputs[0] in names:
|
|
||||||
a = names[inputs[0]]
|
|
||||||
else:
|
|
||||||
count_in += 1
|
|
||||||
a = "input{}".format(count_in)
|
|
||||||
if inputs[1] in names:
|
|
||||||
b = names[inputs[1]]
|
|
||||||
else:
|
|
||||||
count_in += 1
|
|
||||||
b = "input{}".format(count_in)
|
|
||||||
nodes.append(make_node("Mul", [a, b], [name], name))
|
|
||||||
elif ty == backend.OpType.Div:
|
|
||||||
names[outputs[0]] = name
|
|
||||||
if inputs[0] in names:
|
|
||||||
a = names[inputs[0]]
|
|
||||||
else:
|
|
||||||
count_in += 1
|
|
||||||
a = "input{}".format(count_in)
|
|
||||||
if inputs[1] in names:
|
|
||||||
b = names[inputs[1]]
|
|
||||||
else:
|
|
||||||
count_in += 1
|
|
||||||
b = "input{}".format(count_in)
|
|
||||||
nodes.append(make_node("Div", [a, b], [name], name))
|
|
||||||
elif ty == backend.OpType.Pow:
|
|
||||||
names[outputs[0]] = name
|
|
||||||
if inputs[0] in names:
|
|
||||||
a = names[inputs[0]]
|
|
||||||
else:
|
|
||||||
count_in += 1
|
|
||||||
a = "input{}".format(count_in)
|
|
||||||
if inputs[1] in names:
|
|
||||||
b = names[inputs[1]]
|
|
||||||
else:
|
|
||||||
count_in += 1
|
|
||||||
b = "input{}".format(count_in)
|
|
||||||
nodes.append(make_node("Pow", [a, b], [name], name))
|
|
||||||
elif ty == backend.OpType.Relu:
|
elif ty == backend.OpType.Relu:
|
||||||
raise Exception("TODO")
|
raise Exception("TODO")
|
||||||
elif ty == backend.OpType.Sigmoid:
|
elif ty == backend.OpType.Sigmoid:
|
||||||
|
@ -425,7 +394,7 @@ def to_onnx(graph: backend.GraphHandler):
|
||||||
else:
|
else:
|
||||||
raise Exception("Unsupported OpType {}".format(ty.name))
|
raise Exception("Unsupported OpType {}".format(ty.name))
|
||||||
|
|
||||||
print(names)
|
print(context.names)
|
||||||
|
|
||||||
|
|
||||||
def parse_onnx(model: ModelProto):
|
def parse_onnx(model: ModelProto):
|
||||||
|
|
|
@ -77,7 +77,7 @@ void init_graph_builder(py::module &m) {
|
||||||
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
|
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
|
||||||
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
||||||
m, "CpuRuntime");
|
m, "CpuRuntime");
|
||||||
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj")
|
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
|
||||||
.def("src", &TensorObj::getOutputOf, policy::move);
|
.def("src", &TensorObj::getOutputOf, policy::move);
|
||||||
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
|
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
|
||||||
.def("op_type", &OperatorObj::getOpType, policy::automatic)
|
.def("op_type", &OperatorObj::getOpType, policy::automatic)
|
||||||
|
|
Loading…
Reference in New Issue