feat: 生成模型对象

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-21 14:01:57 +08:00
parent 6b7af7077b
commit 9e0f8f21bf
2 changed files with 26 additions and 11 deletions

View File

@ -7,7 +7,13 @@ from onnx import (
TensorShapeProto,
ValueInfoProto,
)
from onnx.helper import make_node, make_tensor_value_info, make_tensor
from onnx.helper import (
make_node,
make_tensor_value_info,
make_tensor,
make_graph,
make_model,
)
from onnx.shape_inference import infer_shapes
from typing import Dict, List, Any, Tuple, Sequence
from functools import reduce
@ -312,7 +318,7 @@ def from_onnx(model: ModelProto) -> backend.GraphHandler:
raise Exception('Unsupported operator "{}"'.format(node.op_type))
def to_onnx(graph: backend.GraphHandler):
def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
class Context:
# saves object names, including tensors and operators
names: Dict[Any, str] = dict()
@ -370,6 +376,22 @@ def to_onnx(graph: backend.GraphHandler):
def push_node(self, node: NodeProto) -> None:
self.nodes.append(node)
def build(self, name: str) -> ModelProto:
print()
print(ctx.names)
print()
print(ctx.inputs)
print()
print(ctx.outputs)
print()
print(ctx.nodes)
return make_model(
make_graph(
self.nodes, name, self.inputs, self.outputs, self.initializers
)
)
# 拓扑排序
if not graph.topo_sort():
raise Exception("Sorting fails")
@ -435,14 +457,7 @@ def to_onnx(graph: backend.GraphHandler):
else:
raise Exception("Unsupported OpType {}".format(ty.name))
print()
print(ctx.names)
print()
print(ctx.inputs)
print()
print(ctx.outputs)
print()
print(ctx.nodes)
return ctx.build(name)
def parse_onnx(model: ModelProto):

View File

@ -309,7 +309,7 @@ class TestStringMethods(unittest.TestCase):
handler.add(abc, d, abcd)
handler.add(abcd, e, abcde)
to_onnx(handler)
to_onnx(handler, "add")
if __name__ == "__main__":