forked from jiuyuan/InfiniTensor
feat: check everything
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
9e0f8f21bf
commit
ffd0473bd2
|
@ -14,6 +14,13 @@ from onnx.helper import (
|
||||||
make_graph,
|
make_graph,
|
||||||
make_model,
|
make_model,
|
||||||
)
|
)
|
||||||
|
from onnx.checker import (
|
||||||
|
check_graph,
|
||||||
|
check_model,
|
||||||
|
check_node,
|
||||||
|
check_value_info,
|
||||||
|
check_tensor,
|
||||||
|
)
|
||||||
from onnx.shape_inference import infer_shapes
|
from onnx.shape_inference import infer_shapes
|
||||||
from typing import Dict, List, Any, Tuple, Sequence
|
from typing import Dict, List, Any, Tuple, Sequence
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
@ -356,7 +363,9 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
|
||||||
self.names[tensor] = name
|
self.names[tensor] = name
|
||||||
shape = tensor.shape()
|
shape = tensor.shape()
|
||||||
dtype = backend.tensor_dtype(tensor)
|
dtype = backend.tensor_dtype(tensor)
|
||||||
self.inputs.append(make_tensor_value_info(name, dtype, shape))
|
value_info = make_tensor_value_info(name, dtype, shape)
|
||||||
|
check_value_info(value_info)
|
||||||
|
self.inputs.append(value_info)
|
||||||
|
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
@ -369,11 +378,16 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
|
||||||
vals: Any,
|
vals: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
name = "{}_{}".format(node_name, attr_name)
|
name = "{}_{}".format(node_name, attr_name)
|
||||||
self.inputs.append(make_tensor_value_info(name, elem_type, shape))
|
value_info = make_tensor_value_info(name, elem_type, shape)
|
||||||
self.initializers.append(make_tensor(name, elem_type, shape, vals))
|
tensor = make_tensor(name, elem_type, shape, vals)
|
||||||
|
check_value_info(value_info)
|
||||||
|
check_tensor(tensor)
|
||||||
|
self.inputs.append(value_info)
|
||||||
|
self.initializers.append(tensor)
|
||||||
return name
|
return name
|
||||||
|
|
||||||
def push_node(self, node: NodeProto) -> None:
|
def push_node(self, node: NodeProto) -> None:
|
||||||
|
check_node(node)
|
||||||
self.nodes.append(node)
|
self.nodes.append(node)
|
||||||
|
|
||||||
def build(self, name: str) -> ModelProto:
|
def build(self, name: str) -> ModelProto:
|
||||||
|
@ -386,11 +400,15 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
|
||||||
print()
|
print()
|
||||||
print(ctx.nodes)
|
print(ctx.nodes)
|
||||||
|
|
||||||
return make_model(
|
graph = make_graph(
|
||||||
make_graph(
|
|
||||||
self.nodes, name, self.inputs, self.outputs, self.initializers
|
self.nodes, name, self.inputs, self.outputs, self.initializers
|
||||||
)
|
)
|
||||||
)
|
check_graph(graph)
|
||||||
|
|
||||||
|
model = make_model(graph)
|
||||||
|
check_model(model)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
# 拓扑排序
|
# 拓扑排序
|
||||||
if not graph.topo_sort():
|
if not graph.topo_sort():
|
||||||
|
|
Loading…
Reference in New Issue