Add: shape of intermediate tensor in exported ONNX

This commit is contained in:
Liyan Zheng 2023-04-20 10:28:30 +08:00
parent 34ca6bf149
commit 2a343e240e
1 changed files with 13 additions and 10 deletions

View File

@ -1,4 +1,4 @@
import backend
import backend
from onnx import (
ModelProto,
TensorProto,
@ -538,9 +538,6 @@ class OnnxStub:
else:
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
for output in model.graph.output:
self.outputs[output.name] = tensors[output.name]
return ans
@classmethod
@ -571,6 +568,8 @@ class OnnxStub:
outputs: List[ValueInfoProto] = []
# saves global input tensors
initializers: List[TensorProto] = []
# saves global output tensors
value_info: List[ValueInfoProto] = []
enable_check = False
def __init__(self, enable_check):
@ -586,12 +585,15 @@ class OnnxStub:
def push_output(self, name: str, tensor: backend.Tensor) -> str:
self.names[tensor] = name
if not tensor.has_target():
shape = tensor.shape()
dtype = backend.tensor_dtype(tensor)
value_info = make_tensor_value_info(name, dtype, shape)
check_value_info(value_info)
shape = tensor.shape()
dtype = backend.tensor_dtype(tensor)
value_info = make_tensor_value_info(name, dtype, shape)
check_value_info(value_info)
if not tensor.has_target(): # if this output is a global output
self.outputs.append(value_info)
else: # if this output is a local output
self.value_info.append(value_info)
return name
def push_input(
@ -635,7 +637,8 @@ class OnnxStub:
def build(self, name: str) -> ModelProto:
graph = make_graph(
self.nodes, name, self.inputs, self.outputs, self.initializers
self.nodes, name, self.inputs, self.outputs, self.initializers,
value_info=self.value_info
)
if self.enable_check:
check_graph(graph)