forked from jiuyuan/InfiniTensor
fix: 修正 GlobalAveragePool 和 Reshape 导入
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
4c7fdf44c5
commit
d7e52054e6
|
@ -1,6 +1,7 @@
|
|||
import onnx, backend
|
||||
from onnx.shape_inference import infer_shapes
|
||||
from typing import Dict, List, Any
|
||||
from functools import reduce
|
||||
|
||||
runtime = backend.cpu_runtime()
|
||||
|
||||
|
@ -143,21 +144,20 @@ def from_onnx(model: onnx.ModelProto):
|
|||
(
|
||||
value.type.tensor_type.shape
|
||||
for value in model.graph.value_info
|
||||
if value.name == node.output[0]
|
||||
if value.name == node.input[0]
|
||||
),
|
||||
None,
|
||||
) or next(
|
||||
output.type.tensor_type.shape
|
||||
for output in model.graph.output
|
||||
if output.name == node.output[0]
|
||||
input.type.tensor_type.shape
|
||||
for input in model.graph.input
|
||||
if input.name == node.input[0]
|
||||
)
|
||||
dims = _take_shape_dim(shape)
|
||||
|
||||
[_, _, h, w] = _take_shape_dim(shape)
|
||||
tensors[node.output[0]] = handler.avgPool(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
dims[0],
|
||||
dims[1],
|
||||
h,
|
||||
w,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
|
@ -236,10 +236,31 @@ def from_onnx(model: onnx.ModelProto):
|
|||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Reshape":
|
||||
input_shape = next(
|
||||
(
|
||||
value.type.tensor_type.shape
|
||||
for value in model.graph.value_info
|
||||
if value.name == node.input[0]
|
||||
),
|
||||
None,
|
||||
) or next(
|
||||
input.type.tensor_type.shape
|
||||
for input in model.graph.input
|
||||
if input.name == node.input[0]
|
||||
)
|
||||
dims = _take_shape_dim(input_shape)
|
||||
size = reduce(lambda acc, x: acc * x, dims)
|
||||
output_shape = [int(i) for i in data[node.input[1]].int64_data]
|
||||
for i, x in enumerate(output_shape):
|
||||
if x == 0:
|
||||
output_shape[i] = dims[i]
|
||||
temp = reduce(lambda acc, x: acc * x, output_shape)
|
||||
if temp < 0:
|
||||
output_shape[output_shape.index(-1)] = size // -temp
|
||||
tensors[node.output[0]] = handler.reshape(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
[int(i) for i in data[node.input[1]].int64_data],
|
||||
output_shape,
|
||||
)
|
||||
elif node.op_type == "Concat":
|
||||
tensors[node.output[0]] = handler.concat(
|
||||
|
|
Loading…
Reference in New Issue