diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index 81a2db64..03b6866f 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -1,6 +1,7 @@
 import onnx, backend
 from onnx.shape_inference import infer_shapes
 from typing import Dict, List, Any
+from functools import reduce
 
 runtime = backend.cpu_runtime()
 
@@ -143,21 +144,20 @@ def from_onnx(model: onnx.ModelProto):
                 (
                     value.type.tensor_type.shape
                     for value in model.graph.value_info
-                    if value.name == node.output[0]
+                    if value.name == node.input[0]
                 ),
                 None,
             ) or next(
-                output.type.tensor_type.shape
-                for output in model.graph.output
-                if output.name == node.output[0]
+                input.type.tensor_type.shape
+                for input in model.graph.input
+                if input.name == node.input[0]
             )
-            dims = _take_shape_dim(shape)
-
+            [_, _, h, w] = _take_shape_dim(shape)  # NCHW input: h and w are the spatial dims
             tensors[node.output[0]] = handler.avgPool(
                 tensors[node.input[0]],
                 tensors.get(node.output[0]),
-                dims[0],
-                dims[1],
+                h,
+                w,
                 1,
                 1,
                 0,
@@ -236,10 +236,34 @@ def from_onnx(model: onnx.ModelProto):
                 tensors.get(node.output[0]),
             )
         elif node.op_type == "Reshape":
+            input_shape = next(
+                (
+                    value.type.tensor_type.shape
+                    for value in model.graph.value_info
+                    if value.name == node.input[0]
+                ),
+                None,
+            ) or next(
+                input.type.tensor_type.shape
+                for input in model.graph.input
+                if input.name == node.input[0]
+            )
+            dims = _take_shape_dim(input_shape)
+            size = reduce(lambda acc, x: acc * x, dims)  # total element count of the input
+            output_shape = [int(i) for i in data[node.input[1]].int64_data]
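+            # ONNX Reshape: a 0 in the target shape keeps the corresponding input dimension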
+            for i, x in enumerate(output_shape):
+                if x == 0:
+                    output_shape[i] = dims[i]
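+            # a single -1 in the target shape is inferred from the remaining element count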
+            temp = reduce(lambda acc, x: acc * x, output_shape)
+            if temp < 0:
+                output_shape[output_shape.index(-1)] = size // -temp
             tensors[node.output[0]] = handler.reshape(
                 tensors[node.input[0]],
                 tensors.get(node.output[0]),
-                [int(i) for i in data[node.input[1]].int64_data],
+                output_shape,
             )
         elif node.op_type == "Concat":
             tensors[node.output[0]] = handler.concat(