fix: 为中间结果提供tensor到node的mapping

This commit is contained in:
panzezhong 2023-12-27 10:48:06 +08:00
parent c34946a0d8
commit 85de28ef1e
1 changed files with 3 additions and 0 deletions

View File

@ -52,6 +52,7 @@ class OnnxStub:
self.inputs: Dict[str, backend.Tensor] = {}
self.outputs: Dict[str, backend.Tensor] = {}
self.tensors: Dict[str, backend.Tensor] = {}
self.tensor_node_map: Dict[str, str] = {}
self.initializer: Dict[int, TensorProto] = {}
try:
model = infer_shapes(model)
@ -80,6 +81,8 @@ class OnnxStub:
node.name = str(len(sorted_nodes)) + "_" + node.name
sorted_nodes.append(i)
known_edge.update(node.output)
for t_ in node.output:
self.tensor_node_map[t_] = node.name
updated = True
if not updated:
raise Exception("Graph has cycle")