fix: 改正 reshap 导入

- 从 initializer 拿到 reshape 的 shape 值
- 但 reshape 仍然无法导入,因为无法分辨 shape 其实不是一个后端张量

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-14 10:14:55 +08:00
parent 7626efbfa8
commit d9e2953425
2 changed files with 16 additions and 10 deletions

View File

@ -7,6 +7,7 @@ def from_onnx(model: onnx.ModelProto):
handler = backend.GraphHandlerObj(runtime)
tensors = dict()
data = dict()
for input in model.graph.input:
dims = [d.dim_value for d in input.type.tensor_type.shape.dim]
@ -16,6 +17,9 @@ def from_onnx(model: onnx.ModelProto):
dims = [d.dim_value for d in output.type.tensor_type.shape.dim]
tensors[output.name] = handler.tensor(dims, output.type.tensor_type.elem_type)
for initializer in model.graph.initializer:
data[initializer.name] = initializer
for node in model.graph.node:
if node.op_type == "MatMul":
tensors[node.output[0]] = handler.matmul(
@ -115,7 +119,7 @@ def from_onnx(model: onnx.ModelProto):
tensors[node.output[0]] = handler.reshape(
tensors[node.input[0]],
tensors.get(node.output[0]),
[int(i) for i in tensors[node.input[1]]],
data[node.input[1]].int32_data,
)
else:
raise Exception('Unsupported operator "{}"'.format(node.op_type))

View File

@ -129,15 +129,17 @@ class TestStringMethods(unittest.TestCase):
flatten = make_node("Flatten", ["x"], ["y"], name="flatten")
make_and_import_model(make_graph([flatten], "flatten", [x], [y]))
# FIXME INT64 类型不支持
# def test_reshape(self):
# data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5])
# shape = make_tensor_value_info("shape", TensorProto.INT64, [3, 5, 8])
# reshaped = make_tensor_value_info("reshaped", TensorProto.FLOAT, [3, 5, 8])
# reshape = make_node("Reshape", ["data", "shape"], ["reshaped"], name="reshape")
# make_and_import_model(
# make_graph([reshape], "reshape", [data, shape], [reshaped])
# )
def test_reshape(self):
data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5])
shape = make_tensor_value_info("shape", TensorProto.INT64, [3, 5, 8])
reshaped = make_tensor_value_info("reshaped", TensorProto.FLOAT, [3, 5, 8])
reshape = make_node("Reshape", ["data", "shape"], ["reshaped"], name="reshape")
# FIXME shape 对于 onnx 来说是输入张量,但对于后端来说不是,导入时无法分辨这个情况。
# tensor 的类型又不支持 INT64所以这里会报一个错。
# 如何分辨 onnx 的张量是不是需要作为张量注册?
# make_and_import_model(
# make_graph([reshape], "reshape", [data, shape], [reshaped])
# )
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
def test_linear(self):