forked from jiuyuan/InfiniTensor
Add: shape of intermediate tensor in exported ONNX
This commit is contained in:
parent
34ca6bf149
commit
2a343e240e
|
@ -1,4 +1,4 @@
|
|||
import backend
|
||||
import backend
|
||||
from onnx import (
|
||||
ModelProto,
|
||||
TensorProto,
|
||||
|
@ -538,9 +538,6 @@ class OnnxStub:
|
|||
else:
|
||||
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
|
||||
|
||||
for output in model.graph.output:
|
||||
self.outputs[output.name] = tensors[output.name]
|
||||
|
||||
return ans
|
||||
|
||||
@classmethod
|
||||
|
@ -571,6 +568,8 @@ class OnnxStub:
|
|||
outputs: List[ValueInfoProto] = []
|
||||
# saves global input tensors
|
||||
initializers: List[TensorProto] = []
|
||||
# saves global output tensors
|
||||
value_info: List[ValueInfoProto] = []
|
||||
|
||||
enable_check = False
|
||||
def __init__(self, enable_check):
|
||||
|
@ -586,12 +585,15 @@ class OnnxStub:
|
|||
|
||||
def push_output(self, name: str, tensor: backend.Tensor) -> str:
|
||||
self.names[tensor] = name
|
||||
if not tensor.has_target():
|
||||
|
||||
shape = tensor.shape()
|
||||
dtype = backend.tensor_dtype(tensor)
|
||||
value_info = make_tensor_value_info(name, dtype, shape)
|
||||
check_value_info(value_info)
|
||||
if not tensor.has_target(): # if this output is a global output
|
||||
self.outputs.append(value_info)
|
||||
else: # if this output is a local output
|
||||
self.value_info.append(value_info)
|
||||
return name
|
||||
|
||||
def push_input(
|
||||
|
@ -635,7 +637,8 @@ class OnnxStub:
|
|||
|
||||
def build(self, name: str) -> ModelProto:
|
||||
graph = make_graph(
|
||||
self.nodes, name, self.inputs, self.outputs, self.initializers
|
||||
self.nodes, name, self.inputs, self.outputs, self.initializers,
|
||||
value_info=self.value_info
|
||||
)
|
||||
if self.enable_check:
|
||||
check_graph(graph)
|
||||
|
|
Loading…
Reference in New Issue