forked from jiuyuan/InfiniTensor
fix: 修正 GlobalAveragePool 和 Reshape 导入
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
4c7fdf44c5
commit
d7e52054e6
|
@ -1,6 +1,7 @@
|
||||||
import onnx, backend
|
import onnx, backend
|
||||||
from onnx.shape_inference import infer_shapes
|
from onnx.shape_inference import infer_shapes
|
||||||
from typing import Dict, List, Any
|
from typing import Dict, List, Any
|
||||||
|
from functools import reduce
|
||||||
|
|
||||||
runtime = backend.cpu_runtime()
|
runtime = backend.cpu_runtime()
|
||||||
|
|
||||||
|
@ -143,21 +144,20 @@ def from_onnx(model: onnx.ModelProto):
|
||||||
(
|
(
|
||||||
value.type.tensor_type.shape
|
value.type.tensor_type.shape
|
||||||
for value in model.graph.value_info
|
for value in model.graph.value_info
|
||||||
if value.name == node.output[0]
|
if value.name == node.input[0]
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
) or next(
|
) or next(
|
||||||
output.type.tensor_type.shape
|
input.type.tensor_type.shape
|
||||||
for output in model.graph.output
|
for input in model.graph.input
|
||||||
if output.name == node.output[0]
|
if input.name == node.input[0]
|
||||||
)
|
)
|
||||||
dims = _take_shape_dim(shape)
|
[_, _, h, w] = _take_shape_dim(shape)
|
||||||
|
|
||||||
tensors[node.output[0]] = handler.avgPool(
|
tensors[node.output[0]] = handler.avgPool(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors.get(node.output[0]),
|
tensors.get(node.output[0]),
|
||||||
dims[0],
|
h,
|
||||||
dims[1],
|
w,
|
||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
0,
|
0,
|
||||||
|
@ -236,10 +236,31 @@ def from_onnx(model: onnx.ModelProto):
|
||||||
tensors.get(node.output[0]),
|
tensors.get(node.output[0]),
|
||||||
)
|
)
|
||||||
elif node.op_type == "Reshape":
|
elif node.op_type == "Reshape":
|
||||||
|
input_shape = next(
|
||||||
|
(
|
||||||
|
value.type.tensor_type.shape
|
||||||
|
for value in model.graph.value_info
|
||||||
|
if value.name == node.input[0]
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
) or next(
|
||||||
|
input.type.tensor_type.shape
|
||||||
|
for input in model.graph.input
|
||||||
|
if input.name == node.input[0]
|
||||||
|
)
|
||||||
|
dims = _take_shape_dim(input_shape)
|
||||||
|
size = reduce(lambda acc, x: acc * x, dims)
|
||||||
|
output_shape = [int(i) for i in data[node.input[1]].int64_data]
|
||||||
|
for i, x in enumerate(output_shape):
|
||||||
|
if x == 0:
|
||||||
|
output_shape[i] = dims[i]
|
||||||
|
temp = reduce(lambda acc, x: acc * x, output_shape)
|
||||||
|
if temp < 0:
|
||||||
|
output_shape[output_shape.index(-1)] = size // -temp
|
||||||
tensors[node.output[0]] = handler.reshape(
|
tensors[node.output[0]] = handler.reshape(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors.get(node.output[0]),
|
tensors.get(node.output[0]),
|
||||||
[int(i) for i in data[node.input[1]].int64_data],
|
output_shape,
|
||||||
)
|
)
|
||||||
elif node.op_type == "Concat":
|
elif node.op_type == "Concat":
|
||||||
tensors[node.output[0]] = handler.concat(
|
tensors[node.output[0]] = handler.concat(
|
||||||
|
|
Loading…
Reference in New Issue