forked from jiuyuan/InfiniTensor
feat: check everything
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
9e0f8f21bf
commit
ffd0473bd2
|
@ -14,6 +14,13 @@ from onnx.helper import (
|
|||
make_graph,
|
||||
make_model,
|
||||
)
|
||||
from onnx.checker import (
|
||||
check_graph,
|
||||
check_model,
|
||||
check_node,
|
||||
check_value_info,
|
||||
check_tensor,
|
||||
)
|
||||
from onnx.shape_inference import infer_shapes
|
||||
from typing import Dict, List, Any, Tuple, Sequence
|
||||
from functools import reduce
|
||||
|
@ -356,7 +363,9 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
|
|||
self.names[tensor] = name
|
||||
shape = tensor.shape()
|
||||
dtype = backend.tensor_dtype(tensor)
|
||||
self.inputs.append(make_tensor_value_info(name, dtype, shape))
|
||||
value_info = make_tensor_value_info(name, dtype, shape)
|
||||
check_value_info(value_info)
|
||||
self.inputs.append(value_info)
|
||||
|
||||
return name
|
||||
|
||||
|
@ -369,11 +378,16 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
|
|||
vals: Any,
|
||||
) -> str:
|
||||
name = "{}_{}".format(node_name, attr_name)
|
||||
self.inputs.append(make_tensor_value_info(name, elem_type, shape))
|
||||
self.initializers.append(make_tensor(name, elem_type, shape, vals))
|
||||
value_info = make_tensor_value_info(name, elem_type, shape)
|
||||
tensor = make_tensor(name, elem_type, shape, vals)
|
||||
check_value_info(value_info)
|
||||
check_tensor(tensor)
|
||||
self.inputs.append(value_info)
|
||||
self.initializers.append(tensor)
|
||||
return name
|
||||
|
||||
def push_node(self, node: NodeProto) -> None:
|
||||
check_node(node)
|
||||
self.nodes.append(node)
|
||||
|
||||
def build(self, name: str) -> ModelProto:
|
||||
|
@ -386,11 +400,15 @@ def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
|
|||
print()
|
||||
print(ctx.nodes)
|
||||
|
||||
return make_model(
|
||||
make_graph(
|
||||
self.nodes, name, self.inputs, self.outputs, self.initializers
|
||||
)
|
||||
graph = make_graph(
|
||||
self.nodes, name, self.inputs, self.outputs, self.initializers
|
||||
)
|
||||
check_graph(graph)
|
||||
|
||||
model = make_model(graph)
|
||||
check_model(model)
|
||||
|
||||
return model
|
||||
|
||||
# 拓扑排序
|
||||
if not graph.topo_sort():
|
||||
|
|
Loading…
Reference in New Issue