diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 81a2db64..03b6866f 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -1,6 +1,7 @@ import onnx, backend from onnx.shape_inference import infer_shapes from typing import Dict, List, Any +from functools import reduce runtime = backend.cpu_runtime() @@ -143,21 +144,20 @@ def from_onnx(model: onnx.ModelProto): ( value.type.tensor_type.shape for value in model.graph.value_info - if value.name == node.output[0] + if value.name == node.input[0] ), None, ) or next( - output.type.tensor_type.shape - for output in model.graph.output - if output.name == node.output[0] + input.type.tensor_type.shape + for input in model.graph.input + if input.name == node.input[0] ) - dims = _take_shape_dim(shape) - + [_, _, h, w] = _take_shape_dim(shape) tensors[node.output[0]] = handler.avgPool( tensors[node.input[0]], tensors.get(node.output[0]), - dims[0], - dims[1], + h, + w, 1, 1, 0, @@ -236,10 +236,31 @@ def from_onnx(model: onnx.ModelProto): tensors.get(node.output[0]), ) elif node.op_type == "Reshape": + input_shape = next( + ( + value.type.tensor_type.shape + for value in model.graph.value_info + if value.name == node.input[0] + ), + None, + ) or next( + input.type.tensor_type.shape + for input in model.graph.input + if input.name == node.input[0] + ) + dims = _take_shape_dim(input_shape) + size = reduce(lambda acc, x: acc * x, dims) + output_shape = [int(i) for i in data[node.input[1]].int64_data] + for i, x in enumerate(output_shape): + if x == 0: + output_shape[i] = dims[i] + temp = reduce(lambda acc, x: acc * x, output_shape) + if temp < 0: + output_shape[output_shape.index(-1)] = size // -temp tensors[node.output[0]] = handler.reshape( tensors[node.input[0]], tensors.get(node.output[0]), - [int(i) for i in data[node.input[1]].int64_data], + output_shape, ) elif node.op_type == "Concat": tensors[node.output[0]] = handler.concat(