feat: check everything

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-21 14:06:32 +08:00
parent 9e0f8f21bf
commit ffd0473bd2
1 changed files with 25 additions and 7 deletions

View File

@ -14,6 +14,13 @@ from onnx.helper import (
make_graph, make_graph,
make_model, make_model,
) )
from onnx.checker import (
check_graph,
check_model,
check_node,
check_value_info,
check_tensor,
)
from onnx.shape_inference import infer_shapes from onnx.shape_inference import infer_shapes
from typing import Dict, List, Any, Tuple, Sequence from typing import Dict, List, Any, Tuple, Sequence
from functools import reduce from functools import reduce
@ -356,7 +363,9 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
self.names[tensor] = name self.names[tensor] = name
shape = tensor.shape() shape = tensor.shape()
dtype = backend.tensor_dtype(tensor) dtype = backend.tensor_dtype(tensor)
self.inputs.append(make_tensor_value_info(name, dtype, shape)) value_info = make_tensor_value_info(name, dtype, shape)
check_value_info(value_info)
self.inputs.append(value_info)
return name return name
@ -369,11 +378,16 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
vals: Any, vals: Any,
) -> str: ) -> str:
name = "{}_{}".format(node_name, attr_name) name = "{}_{}".format(node_name, attr_name)
self.inputs.append(make_tensor_value_info(name, elem_type, shape)) value_info = make_tensor_value_info(name, elem_type, shape)
self.initializers.append(make_tensor(name, elem_type, shape, vals)) tensor = make_tensor(name, elem_type, shape, vals)
check_value_info(value_info)
check_tensor(tensor)
self.inputs.append(value_info)
self.initializers.append(tensor)
return name return name
def push_node(self, node: NodeProto) -> None: def push_node(self, node: NodeProto) -> None:
check_node(node)
self.nodes.append(node) self.nodes.append(node)
def build(self, name: str) -> ModelProto: def build(self, name: str) -> ModelProto:
@ -386,11 +400,15 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
print() print()
print(ctx.nodes) print(ctx.nodes)
return make_model( graph = make_graph(
make_graph( self.nodes, name, self.inputs, self.outputs, self.initializers
self.nodes, name, self.inputs, self.outputs, self.initializers
)
) )
check_graph(graph)
model = make_model(graph)
check_model(model)
return model
# 拓扑排序 # 拓扑排序
if not graph.topo_sort(): if not graph.topo_sort():