forked from jiuyuan/InfiniTensor
fix: 改正 reshap 导入
- 从 initializer 拿到 reshape 的 shape 值 - 但 reshape 仍然无法导入,因为无法分辨 shape 其实不是一个后端张量 Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
7626efbfa8
commit
d9e2953425
|
@ -7,6 +7,7 @@ def from_onnx(model: onnx.ModelProto):
|
|||
handler = backend.GraphHandlerObj(runtime)
|
||||
|
||||
tensors = dict()
|
||||
data = dict()
|
||||
|
||||
for input in model.graph.input:
|
||||
dims = [d.dim_value for d in input.type.tensor_type.shape.dim]
|
||||
|
@ -16,6 +17,9 @@ def from_onnx(model: onnx.ModelProto):
|
|||
dims = [d.dim_value for d in output.type.tensor_type.shape.dim]
|
||||
tensors[output.name] = handler.tensor(dims, output.type.tensor_type.elem_type)
|
||||
|
||||
for initializer in model.graph.initializer:
|
||||
data[initializer.name] = initializer
|
||||
|
||||
for node in model.graph.node:
|
||||
if node.op_type == "MatMul":
|
||||
tensors[node.output[0]] = handler.matmul(
|
||||
|
@ -115,7 +119,7 @@ def from_onnx(model: onnx.ModelProto):
|
|||
tensors[node.output[0]] = handler.reshape(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
[int(i) for i in tensors[node.input[1]]],
|
||||
data[node.input[1]].int32_data,
|
||||
)
|
||||
else:
|
||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||
|
|
|
@ -129,15 +129,17 @@ class TestStringMethods(unittest.TestCase):
|
|||
flatten = make_node("Flatten", ["x"], ["y"], name="flatten")
|
||||
make_and_import_model(make_graph([flatten], "flatten", [x], [y]))
|
||||
|
||||
# FIXME INT64 类型不支持
|
||||
# def test_reshape(self):
|
||||
# data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5])
|
||||
# shape = make_tensor_value_info("shape", TensorProto.INT64, [3, 5, 8])
|
||||
# reshaped = make_tensor_value_info("reshaped", TensorProto.FLOAT, [3, 5, 8])
|
||||
# reshape = make_node("Reshape", ["data", "shape"], ["reshaped"], name="reshape")
|
||||
# make_and_import_model(
|
||||
# make_graph([reshape], "reshape", [data, shape], [reshaped])
|
||||
# )
|
||||
def test_reshape(self):
|
||||
data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5])
|
||||
shape = make_tensor_value_info("shape", TensorProto.INT64, [3, 5, 8])
|
||||
reshaped = make_tensor_value_info("reshaped", TensorProto.FLOAT, [3, 5, 8])
|
||||
reshape = make_node("Reshape", ["data", "shape"], ["reshaped"], name="reshape")
|
||||
# FIXME shape 对于 onnx 来说是输入张量,但对于后端来说不是,导入时无法分辨这个情况。
|
||||
# tensor 的类型又不支持 INT64,所以这里会报一个错。
|
||||
# 如何分辨 onnx 的张量是不是需要作为张量注册?
|
||||
# make_and_import_model(
|
||||
# make_graph([reshape], "reshape", [data, shape], [reshaped])
|
||||
# )
|
||||
|
||||
# see <https://onnx.ai/onnx/intro/python.html#a-simple-example-a-linear-regression>
|
||||
def test_linear(self):
|
||||
|
|
Loading…
Reference in New Issue