From 377b3bf391fac8c3cd2cd97cdea16d83a2da2185 Mon Sep 17 00:00:00 2001 From: zhangyunze Date: Tue, 30 Apr 2024 10:53:40 +0800 Subject: [PATCH] fix: onnx resize op input is none bug --- pyinfinitensor/src/pyinfinitensor/onnx.py | 26 +++++++++++++++++------ 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index a9be7d1e..9ba2e456 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -85,7 +85,7 @@ class OnnxStub: while len(sorted_nodes) < len(model.graph.node): updated = False for i, node in enumerate(model.graph.node): - if all(t in known_edge for t in node.input): + if all(t in known_edge or t == "" for t in node.input): node.name = str(len(sorted_nodes)) + "_" + node.name sorted_nodes.append(i) known_edge.update(node.output) @@ -653,15 +653,15 @@ class OnnxStub: "nearest_mode", ] ) - if len(node.input) > 1: + if len(node.input) > 1 and node.input[1] in data: roiVal = _parse_data(data[node.input[1]]) else: roiVal = [] - if len(node.input) > 2: + if len(node.input) > 2 and node.input[2] in data: scalesVal = _parse_data(data[node.input[2]]) else: scalesVal = [] - if len(node.input) > 3: + if len(node.input) > 3 and node.input[3] in data: sizesVal = _parse_data(data[node.input[3]]) else: sizesVal = [] @@ -669,9 +669,21 @@ class OnnxStub: tensors[node.input[0]], output, axes, - tensors[node.input[3]] if len(node.input) > 3 else None, - tensors[node.input[2]] if len(node.input) > 2 else None, - tensors[node.input[1]] if len(node.input) > 1 else None, + ( + tensors[node.input[3]] + if len(node.input) > 3 and node.input[3] != "" + else None + ), + ( + tensors[node.input[2]] + if len(node.input) > 2 and node.input[2] != "" + else None + ), + ( + tensors[node.input[1]] + if len(node.input) > 1 and node.input[1] != "" + else None + ), sizesVal, scalesVal, roiVal,