forked from jiuyuan/InfiniTensor
feat: 生成模型对象
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
6b7af7077b
commit
9e0f8f21bf
|
@ -7,7 +7,13 @@ from onnx import (
|
||||||
TensorShapeProto,
|
TensorShapeProto,
|
||||||
ValueInfoProto,
|
ValueInfoProto,
|
||||||
)
|
)
|
||||||
from onnx.helper import make_node, make_tensor_value_info, make_tensor
|
from onnx.helper import (
|
||||||
|
make_node,
|
||||||
|
make_tensor_value_info,
|
||||||
|
make_tensor,
|
||||||
|
make_graph,
|
||||||
|
make_model,
|
||||||
|
)
|
||||||
from onnx.shape_inference import infer_shapes
|
from onnx.shape_inference import infer_shapes
|
||||||
from typing import Dict, List, Any, Tuple, Sequence
|
from typing import Dict, List, Any, Tuple, Sequence
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
@ -312,7 +318,7 @@ def from_onnx(model: ModelProto) -> backend.GraphHandler:
|
||||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||||
|
|
||||||
|
|
||||||
def to_onnx(graph: backend.GraphHandler):
|
def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
|
||||||
class Context:
|
class Context:
|
||||||
# saves object names, including tensors and operators
|
# saves object names, including tensors and operators
|
||||||
names: Dict[Any, str] = dict()
|
names: Dict[Any, str] = dict()
|
||||||
|
@ -370,6 +376,22 @@ def to_onnx(graph: backend.GraphHandler):
|
||||||
def push_node(self, node: NodeProto) -> None:
|
def push_node(self, node: NodeProto) -> None:
|
||||||
self.nodes.append(node)
|
self.nodes.append(node)
|
||||||
|
|
||||||
|
def build(self, name: str) -> ModelProto:
|
||||||
|
print()
|
||||||
|
print(ctx.names)
|
||||||
|
print()
|
||||||
|
print(ctx.inputs)
|
||||||
|
print()
|
||||||
|
print(ctx.outputs)
|
||||||
|
print()
|
||||||
|
print(ctx.nodes)
|
||||||
|
|
||||||
|
return make_model(
|
||||||
|
make_graph(
|
||||||
|
self.nodes, name, self.inputs, self.outputs, self.initializers
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# 拓扑排序
|
# 拓扑排序
|
||||||
if not graph.topo_sort():
|
if not graph.topo_sort():
|
||||||
raise Exception("Sorting fails")
|
raise Exception("Sorting fails")
|
||||||
|
@ -435,14 +457,7 @@ def to_onnx(graph: backend.GraphHandler):
|
||||||
else:
|
else:
|
||||||
raise Exception("Unsupported OpType {}".format(ty.name))
|
raise Exception("Unsupported OpType {}".format(ty.name))
|
||||||
|
|
||||||
print()
|
return ctx.build(name)
|
||||||
print(ctx.names)
|
|
||||||
print()
|
|
||||||
print(ctx.inputs)
|
|
||||||
print()
|
|
||||||
print(ctx.outputs)
|
|
||||||
print()
|
|
||||||
print(ctx.nodes)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_onnx(model: ModelProto):
|
def parse_onnx(model: ModelProto):
|
||||||
|
|
|
@ -309,7 +309,7 @@ class TestStringMethods(unittest.TestCase):
|
||||||
handler.add(abc, d, abcd)
|
handler.add(abc, d, abcd)
|
||||||
handler.add(abcd, e, abcde)
|
handler.add(abcd, e, abcde)
|
||||||
|
|
||||||
to_onnx(handler)
|
to_onnx(handler, "add")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue