diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index b4a8f6a1..453185ae 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -7,6 +7,7 @@ def from_onnx(model: onnx.ModelProto): handler = backend.GraphHandlerObj(runtime) tensors = dict() + data = dict() for input in model.graph.input: dims = [d.dim_value for d in input.type.tensor_type.shape.dim] @@ -16,6 +17,9 @@ def from_onnx(model: onnx.ModelProto): dims = [d.dim_value for d in output.type.tensor_type.shape.dim] tensors[output.name] = handler.tensor(dims, output.type.tensor_type.elem_type) + for initializer in model.graph.initializer: + data[initializer.name] = initializer + for node in model.graph.node: if node.op_type == "MatMul": tensors[node.output[0]] = handler.matmul( @@ -115,7 +119,7 @@ def from_onnx(model: onnx.ModelProto): tensors[node.output[0]] = handler.reshape( tensors[node.input[0]], tensors.get(node.output[0]), - [int(i) for i in tensors[node.input[1]]], + data[node.input[1]].int32_data, ) else: raise Exception('Unsupported operator "{}"'.format(node.op_type)) diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index e617d519..092b1e05 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -129,15 +129,17 @@ class TestStringMethods(unittest.TestCase): flatten = make_node("Flatten", ["x"], ["y"], name="flatten") make_and_import_model(make_graph([flatten], "flatten", [x], [y])) - # FIXME INT64 类型不支持 - # def test_reshape(self): - # data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5]) - # shape = make_tensor_value_info("shape", TensorProto.INT64, [3, 5, 8]) - # reshaped = make_tensor_value_info("reshaped", TensorProto.FLOAT, [3, 5, 8]) - # reshape = make_node("Reshape", ["data", "shape"], ["reshaped"], name="reshape") - # make_and_import_model( - # make_graph([reshape], "reshape", [data, shape], [reshaped]) - # ) + def test_reshape(self): + data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5]) + shape = make_tensor_value_info("shape", TensorProto.INT64, [3, 5, 8]) + reshaped = make_tensor_value_info("reshaped", TensorProto.FLOAT, [3, 5, 8]) + reshape = make_node("Reshape", ["data", "shape"], ["reshaped"], name="reshape") + # FIXME shape 对于 onnx 来说是输入张量,但对于后端来说不是,导入时无法分辨这个情况。 + # tensor 的类型又不支持 INT64,所以这里会报一个错。 + # 如何分辨 onnx 的张量是不是需要作为张量注册? + # make_and_import_model( + # make_graph([reshape], "reshape", [data, shape], [reshaped]) + # ) # see def test_linear(self):