From 2a343e240ed972b2231a6ec363091b3bb3fd8d5a Mon Sep 17 00:00:00 2001 From: Liyan Zheng Date: Thu, 20 Apr 2023 10:28:30 +0800 Subject: [PATCH] Add: shape of intermediate tensor in exported ONNX --- pyinfinitensor/src/pyinfinitensor/onnx.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index a5248b28..7fdcfc5f 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -1,4 +1,4 @@ -import backend +import backend from onnx import ( ModelProto, TensorProto, @@ -538,9 +538,6 @@ class OnnxStub: else: assert False, "Unsupported Tensor Type: {}".format(tensor.data_type) - for output in model.graph.output: - self.outputs[output.name] = tensors[output.name] - return ans @classmethod @@ -571,6 +568,8 @@ class OnnxStub: outputs: List[ValueInfoProto] = [] # saves global input tensors initializers: List[TensorProto] = [] + # saves global output tensors + value_info: List[ValueInfoProto] = [] enable_check = False def __init__(self, enable_check): @@ -586,12 +585,15 @@ class OnnxStub: def push_output(self, name: str, tensor: backend.Tensor) -> str: self.names[tensor] = name - if not tensor.has_target(): - shape = tensor.shape() - dtype = backend.tensor_dtype(tensor) - value_info = make_tensor_value_info(name, dtype, shape) - check_value_info(value_info) + + shape = tensor.shape() + dtype = backend.tensor_dtype(tensor) + value_info = make_tensor_value_info(name, dtype, shape) + check_value_info(value_info) + if not tensor.has_target(): # if this output is a global output self.outputs.append(value_info) + else: # if this output is a local output + self.value_info.append(value_info) return name def push_input( @@ -635,7 +637,8 @@ class OnnxStub: def build(self, name: str) -> ModelProto: graph = make_graph( - self.nodes, name, self.inputs, self.outputs, self.initializers + self.nodes, name, self.inputs, self.outputs, self.initializers, + value_info=self.value_info ) if self.enable_check: check_graph(graph)