diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 03339455..802fdb15 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -7,7 +7,13 @@ from onnx import ( TensorShapeProto, ValueInfoProto, ) -from onnx.helper import make_node, make_tensor_value_info, make_tensor +from onnx.helper import ( + make_node, + make_tensor_value_info, + make_tensor, + make_graph, + make_model, +) from onnx.shape_inference import infer_shapes from typing import Dict, List, Any, Tuple, Sequence from functools import reduce @@ -312,7 +318,7 @@ def from_onnx(model: ModelProto) -> backend.GraphHandler: raise Exception('Unsupported operator "{}"'.format(node.op_type)) -def to_onnx(graph: backend.GraphHandler): +def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto: class Context: # saves object names, including tensors and operators names: Dict[Any, str] = dict() @@ -370,6 +376,22 @@ def to_onnx(graph: backend.GraphHandler): def push_node(self, node: NodeProto) -> None: self.nodes.append(node) + def build(self, name: str) -> ModelProto: + print() + print(ctx.names) + print() + print(ctx.inputs) + print() + print(ctx.outputs) + print() + print(ctx.nodes) + + return make_model( + make_graph( + self.nodes, name, self.inputs, self.outputs, self.initializers + ) + ) + # ζ‹“ζ‰‘ζŽ’εΊ if not graph.topo_sort(): raise Exception("Sorting fails") @@ -435,14 +457,7 @@ def to_onnx(graph: backend.GraphHandler): else: raise Exception("Unsupported OpType {}".format(ty.name)) - print() - print(ctx.names) - print() - print(ctx.inputs) - print() - print(ctx.outputs) - print() - print(ctx.nodes) + return ctx.build(name) def parse_onnx(model: ModelProto): diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index a28dd5b9..d512b504 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -309,7 +309,7 @@ class TestStringMethods(unittest.TestCase): handler.add(abc, d, abcd) handler.add(abcd, e, abcde) - to_onnx(handler) + to_onnx(handler, "add") if __name__ == "__main__":