diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index c98fa7cf..8322ea26 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -2,7 +2,7 @@
 from onnx import ModelProto, TensorProto, NodeProto, AttributeProto, TensorShapeProto
 from onnx.helper import make_node
 from onnx.shape_inference import infer_shapes
-from typing import Dict, List, Any
+from typing import Dict, List, Any, Tuple
 from functools import reduce
 
 runtime = backend.cpu_runtime()
@@ -306,23 +306,46 @@ def from_onnx(model: ModelProto) -> backend.GraphHandler:
 
 
 def to_onnx(graph: backend.GraphHandler):
+    class Context:  # NOTE(review): the containers below are class attributes, not per-instance state
+        names: Dict[Any, str] = dict()  # Names of all objects, both tensors and operators
+        nodes: List[NodeProto] = []  # Holds all operator nodes
+        count_op: Dict[backend.OpType, int] = dict()  # Occurrence count per operator type, used for naming
+        count_in = 0  # Number of input tensors, used for naming
+        count_out = 0  # Number of output tensors, used for naming
+
+        def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:  # Name op "<type><n>", n counted from 1 per type
+            ty = op.op_type()
+            name = "{}{}".format(ty.name, self.count_op.setdefault(ty, 0) + 1)
+            self.names[op] = name
+            self.count_op[ty] += 1
+            return ty, name
+
+        def push_output(self, name: str, tensor: backend.Tensor) -> None:  # Record `name` for a tensor
+            self.names[tensor] = name
+
+        def push_input(self, tensor: backend.Tensor) -> str:  # Reuse a recorded name, else assign the next "input<n>"
+            name = self.names.get(tensor)
+            if name is None:
+                self.count_in += 1
+                name = "input{}".format(self.count_in)
+                self.names[tensor] = name
+            return name
+
+        def push_node(self, node: NodeProto) -> None:  # Append a node to `nodes`
+            self.nodes.append(node)
+
+    # Topological sort
     if not graph.topo_sort():
         raise Exception("Sorting fails")
 
-    ops = graph.operators()
+    ops = graph.operators()  # All operators (nodes) in the graph
 
-    names: Dict[Any, str] = dict()
-    nodes: List[NodeProto] = []
-    count_op: Dict[backend.OpType, int] = dict()
-    count_in = 0
+    context = Context()
 
     for op in ops:
-        ty = op.op_type()
-        name = "{}{}".format(ty.name, count_op.setdefault(ty, 0) + 1)
+        ty, name = context.name_op(op)
         inputs = op.inputs()
         outputs = op.outputs()
-        names[op] = name
-        count_op[ty] += 1
         if ty == backend.OpType.Matmul:
             raise Exception("TODO")
         elif ty == backend.OpType.BatchNorm:
@@ -331,71 +354,17 @@ def to_onnx(graph: backend.GraphHandler):
             raise Exception("TODO")
         elif ty == backend.OpType.AvgPool:
             raise Exception("TODO")
-        elif ty == backend.OpType.Add:
-            names[outputs[0]] = name
-            if inputs[0] in names:
-                a = names[inputs[0]]
-            else:
-                count_in += 1
-                a = "input{}".format(count_in)
-            if inputs[1] in names:
-                b = names[inputs[1]]
-            else:
-                count_in += 1
-                b = "input{}".format(count_in)
-            nodes.append(make_node("Add", [a, b], [name], name))
-        elif ty == backend.OpType.Sub:
-            names[outputs[0]] = name
-            if inputs[0] in names:
-                a = names[inputs[0]]
-            else:
-                count_in += 1
-                a = "input{}".format(count_in)
-            if inputs[1] in names:
-                b = names[inputs[1]]
-            else:
-                count_in += 1
-                b = "input{}".format(count_in)
-            nodes.append(make_node("Sub", [a, b], [name], name))
-        elif ty == backend.OpType.Mul:
-            names[outputs[0]] = name
-            if inputs[0] in names:
-                a = names[inputs[0]]
-            else:
-                count_in += 1
-                a = "input{}".format(count_in)
-            if inputs[1] in names:
-                b = names[inputs[1]]
-            else:
-                count_in += 1
-                b = "input{}".format(count_in)
-            nodes.append(make_node("Mul", [a, b], [name], name))
-        elif ty == backend.OpType.Div:
-            names[outputs[0]] = name
-            if inputs[0] in names:
-                a = names[inputs[0]]
-            else:
-                count_in += 1
-                a = "input{}".format(count_in)
-            if inputs[1] in names:
-                b = names[inputs[1]]
-            else:
-                count_in += 1
-                b = "input{}".format(count_in)
-            nodes.append(make_node("Div", [a, b], [name], name))
-        elif ty == backend.OpType.Pow:
-            names[outputs[0]] = name
-            if inputs[0] in names:
-                a = names[inputs[0]]
-            else:
-                count_in += 1
-                a = "input{}".format(count_in)
-            if inputs[1] in names:
-                b = names[inputs[1]]
-            else:
-                count_in += 1
-                b = "input{}".format(count_in)
-            nodes.append(make_node("Pow", [a, b], [name], name))
+        elif ty in [
+            backend.OpType.Add,
+            backend.OpType.Sub,
+            backend.OpType.Mul,
+            backend.OpType.Div,
+            backend.OpType.Pow,
+        ]:
+            context.push_output(name, outputs[0])
+            a = context.push_input(inputs[0])
+            b = context.push_input(inputs[1])
+            context.push_node(make_node(ty.name, [a, b], [name], name))  # NOTE(review): assumes the OpType member name equals the ONNX op name - confirm
         elif ty == backend.OpType.Relu:
             raise Exception("TODO")
         elif ty == backend.OpType.Sigmoid:
@@ -425,7 +394,7 @@ def to_onnx(graph: backend.GraphHandler):
         else:
             raise Exception("Unsupported OpType {}".format(ty.name))
 
-    print(names)
+    print(context.names)
 
 
 def parse_onnx(model: ModelProto):
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index ceb641d7..45293576 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -77,7 +77,7 @@ void init_graph_builder(py::module &m) {
     py::class_>(m, "Runtime");
     py::class_, RuntimeObj>(
         m, "CpuRuntime");
-    py::class_>(m, "TensorObj")
+    py::class_>(m, "Tensor")
         .def("src", &TensorObj::getOutputOf, policy::move);
     py::class_>(m, "Operator")
         .def("op_type", &OperatorObj::getOpType, policy::automatic)