forked from jiuyuan/InfiniTensor
feat: 生成模型对象
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
6b7af7077b
commit
9e0f8f21bf
|
@ -7,7 +7,13 @@ from onnx import (
|
|||
TensorShapeProto,
|
||||
ValueInfoProto,
|
||||
)
|
||||
from onnx.helper import make_node, make_tensor_value_info, make_tensor
|
||||
from onnx.helper import (
|
||||
make_node,
|
||||
make_tensor_value_info,
|
||||
make_tensor,
|
||||
make_graph,
|
||||
make_model,
|
||||
)
|
||||
from onnx.shape_inference import infer_shapes
|
||||
from typing import Dict, List, Any, Tuple, Sequence
|
||||
from functools import reduce
|
||||
|
@ -312,7 +318,7 @@ def from_onnx(model: ModelProto) -> backend.GraphHandler:
|
|||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||
|
||||
|
||||
def to_onnx(graph: backend.GraphHandler):
|
||||
def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
|
||||
class Context:
|
||||
# saves object names, including tensors and operators
|
||||
names: Dict[Any, str] = dict()
|
||||
|
@ -370,6 +376,22 @@ def to_onnx(graph: backend.GraphHandler):
|
|||
def push_node(self, node: NodeProto) -> None:
|
||||
self.nodes.append(node)
|
||||
|
||||
def build(self, name: str) -> ModelProto:
|
||||
print()
|
||||
print(ctx.names)
|
||||
print()
|
||||
print(ctx.inputs)
|
||||
print()
|
||||
print(ctx.outputs)
|
||||
print()
|
||||
print(ctx.nodes)
|
||||
|
||||
return make_model(
|
||||
make_graph(
|
||||
self.nodes, name, self.inputs, self.outputs, self.initializers
|
||||
)
|
||||
)
|
||||
|
||||
# 拓扑排序
|
||||
if not graph.topo_sort():
|
||||
raise Exception("Sorting fails")
|
||||
|
@ -435,14 +457,7 @@ def to_onnx(graph: backend.GraphHandler):
|
|||
else:
|
||||
raise Exception("Unsupported OpType {}".format(ty.name))
|
||||
|
||||
print()
|
||||
print(ctx.names)
|
||||
print()
|
||||
print(ctx.inputs)
|
||||
print()
|
||||
print(ctx.outputs)
|
||||
print()
|
||||
print(ctx.nodes)
|
||||
return ctx.build(name)
|
||||
|
||||
|
||||
def parse_onnx(model: ModelProto):
|
||||
|
|
|
@ -309,7 +309,7 @@ class TestStringMethods(unittest.TestCase):
|
|||
handler.add(abc, d, abcd)
|
||||
handler.add(abcd, e, abcde)
|
||||
|
||||
to_onnx(handler)
|
||||
to_onnx(handler, "add")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue