fix: 修正 GlobalAveragePool 和 Reshape 导入

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-23 08:59:06 +08:00
parent 4c7fdf44c5
commit d7e52054e6
1 changed files with 30 additions and 9 deletions

View File

@ -1,6 +1,7 @@
import onnx, backend
from onnx.shape_inference import infer_shapes
from typing import Dict, List, Any
from functools import reduce
runtime = backend.cpu_runtime()
@ -143,21 +144,20 @@ def from_onnx(model: onnx.ModelProto):
(
value.type.tensor_type.shape
for value in model.graph.value_info
if value.name == node.output[0]
if value.name == node.input[0]
),
None,
) or next(
output.type.tensor_type.shape
for output in model.graph.output
if output.name == node.output[0]
input.type.tensor_type.shape
for input in model.graph.input
if input.name == node.input[0]
)
dims = _take_shape_dim(shape)
[_, _, h, w] = _take_shape_dim(shape)
tensors[node.output[0]] = handler.avgPool(
tensors[node.input[0]],
tensors.get(node.output[0]),
dims[0],
dims[1],
h,
w,
1,
1,
0,
@ -236,10 +236,31 @@ def from_onnx(model: onnx.ModelProto):
tensors.get(node.output[0]),
)
elif node.op_type == "Reshape":
input_shape = next(
(
value.type.tensor_type.shape
for value in model.graph.value_info
if value.name == node.input[0]
),
None,
) or next(
input.type.tensor_type.shape
for input in model.graph.input
if input.name == node.input[0]
)
dims = _take_shape_dim(input_shape)
size = reduce(lambda acc, x: acc * x, dims)
output_shape = [int(i) for i in data[node.input[1]].int64_data]
for i, x in enumerate(output_shape):
if x == 0:
output_shape[i] = dims[i]
temp = reduce(lambda acc, x: acc * x, output_shape)
if temp < 0:
output_shape[output_shape.index(-1)] = size // -temp
tensors[node.output[0]] = handler.reshape(
tensors[node.input[0]],
tensors.get(node.output[0]),
[int(i) for i in data[node.input[1]].int64_data],
output_shape,
)
elif node.op_type == "Concat":
tensors[node.output[0]] = handler.concat(