diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index a9be7d1e..9ba2e456 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -85,7 +85,7 @@ class OnnxStub: while len(sorted_nodes) < len(model.graph.node): updated = False for i, node in enumerate(model.graph.node): - if all(t in known_edge for t in node.input): + if all(t in known_edge or t == "" for t in node.input): node.name = str(len(sorted_nodes)) + "_" + node.name sorted_nodes.append(i) known_edge.update(node.output) @@ -653,15 +653,15 @@ class OnnxStub: "nearest_mode", ] ) - if len(node.input) > 1: + if len(node.input) > 1 and node.input[1] in data: roiVal = _parse_data(data[node.input[1]]) else: roiVal = [] - if len(node.input) > 2: + if len(node.input) > 2 and node.input[2] in data: scalesVal = _parse_data(data[node.input[2]]) else: scalesVal = [] - if len(node.input) > 3: + if len(node.input) > 3 and node.input[3] in data: sizesVal = _parse_data(data[node.input[3]]) else: sizesVal = [] @@ -669,9 +669,21 @@ class OnnxStub: tensors[node.input[0]], output, axes, - tensors[node.input[3]] if len(node.input) > 3 else None, - tensors[node.input[2]] if len(node.input) > 2 else None, - tensors[node.input[1]] if len(node.input) > 1 else None, + ( + tensors[node.input[3]] + if len(node.input) > 3 and node.input[3] != "" + else None + ), + ( + tensors[node.input[2]] + if len(node.input) > 2 and node.input[2] != "" + else None + ), + ( + tensors[node.input[1]] + if len(node.input) > 1 and node.input[1] != "" + else None + ), sizesVal, scalesVal, roiVal,