diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index 802fdb15..1008b391 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -14,6 +14,13 @@ from onnx.helper import (
     make_graph,
     make_model,
 )
+from onnx.checker import (
+    check_graph,
+    check_model,
+    check_node,
+    check_value_info,
+    check_tensor,
+)
 from onnx.shape_inference import infer_shapes
 from typing import Dict, List, Any, Tuple, Sequence
 from functools import reduce
@@ -356,7 +363,9 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
             self.names[tensor] = name
             shape = tensor.shape()
             dtype = backend.tensor_dtype(tensor)
-            self.inputs.append(make_tensor_value_info(name, dtype, shape))
+            value_info = make_tensor_value_info(name, dtype, shape)
+            check_value_info(value_info)
+            self.inputs.append(value_info)
             return name
@@ -369,11 +378,16 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
             vals: Any,
         ) -> str:
             name = "{}_{}".format(node_name, attr_name)
-            self.inputs.append(make_tensor_value_info(name, elem_type, shape))
-            self.initializers.append(make_tensor(name, elem_type, shape, vals))
+            value_info = make_tensor_value_info(name, elem_type, shape)
+            tensor = make_tensor(name, elem_type, shape, vals)
+            check_value_info(value_info)
+            check_tensor(tensor)
+            self.inputs.append(value_info)
+            self.initializers.append(tensor)
             return name
 
         def push_node(self, node: NodeProto) -> None:
+            check_node(node)
             self.nodes.append(node)
 
         def build(self, name: str) -> ModelProto:
@@ -386,11 +400,15 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
             print()
             print(ctx.nodes)
 
-            return make_model(
-                make_graph(
-                    self.nodes, name, self.inputs, self.outputs, self.initializers
-                )
+            graph = make_graph(
+                self.nodes, name, self.inputs, self.outputs, self.initializers
             )
+            check_graph(graph)
+
+            model = make_model(graph)
+            check_model(model)
+
+            return model
 
     # ζ‹“ζ‰‘ζŽ’εΊ
     if not graph.topo_sort():
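
The pattern the diff introduces is: each `make_*` call from `onnx.helper` is immediately followed by the matching `check_*` from `onnx.checker`, so a malformed proto fails at construction time rather than only at the final `check_model`. Below is a minimal standalone sketch of that validation flow using only the public ONNX APIs; the tensor names `x`, `w`, `y`, the node name `add0`, and the tiny Add graph are illustrative examples, not taken from this patch.

```python
from onnx import TensorProto
from onnx.helper import (
    make_graph,
    make_model,
    make_node,
    make_tensor,
    make_tensor_value_info,
)
from onnx.checker import (
    check_graph,
    check_model,
    check_node,
    check_tensor,
    check_value_info,
)

# Value infos for the graph input, the constant operand, and the output.
# Each check_* call raises onnx.checker.ValidationError on a malformed proto.
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3])
w_info = make_tensor_value_info("w", TensorProto.FLOAT, [1, 3])
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3])
for value_info in (x, w_info, y):
    check_value_info(value_info)

# Initializer backing the "w" input, validated right after construction.
w = make_tensor("w", TensorProto.FLOAT, [1, 3], [1.0, 2.0, 3.0])
check_tensor(w)

# A single Add node, validated in isolation against the default opset.
add = make_node("Add", ["x", "w"], ["y"], name="add0")
check_node(add)

# Assemble and validate the graph, then the whole model.
graph = make_graph([add], "tiny_graph", [x, w_info], [y], [w])
check_graph(graph)

model = make_model(graph)
check_model(model)
```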