forked from jiuyuan/InfiniTensor
fix: 改正 reshap 导入
- 从 initializer 拿到 reshape 的 shape 值 - 但 reshape 仍然无法导入,因为无法分辨 shape 其实不是一个后端张量 Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
7626efbfa8
commit
d9e2953425
|
@ -7,6 +7,7 @@ def from_onnx(model: onnx.ModelProto):
|
||||||
handler = backend.GraphHandlerObj(runtime)
|
handler = backend.GraphHandlerObj(runtime)
|
||||||
|
|
||||||
tensors = dict()
|
tensors = dict()
|
||||||
|
data = dict()
|
||||||
|
|
||||||
for input in model.graph.input:
|
for input in model.graph.input:
|
||||||
dims = [d.dim_value for d in input.type.tensor_type.shape.dim]
|
dims = [d.dim_value for d in input.type.tensor_type.shape.dim]
|
||||||
|
@ -16,6 +17,9 @@ def from_onnx(model: onnx.ModelProto):
|
||||||
dims = [d.dim_value for d in output.type.tensor_type.shape.dim]
|
dims = [d.dim_value for d in output.type.tensor_type.shape.dim]
|
||||||
tensors[output.name] = handler.tensor(dims, output.type.tensor_type.elem_type)
|
tensors[output.name] = handler.tensor(dims, output.type.tensor_type.elem_type)
|
||||||
|
|
||||||
|
for initializer in model.graph.initializer:
|
||||||
|
data[initializer.name] = initializer
|
||||||
|
|
||||||
for node in model.graph.node:
|
for node in model.graph.node:
|
||||||
if node.op_type == "MatMul":
|
if node.op_type == "MatMul":
|
||||||
tensors[node.output[0]] = handler.matmul(
|
tensors[node.output[0]] = handler.matmul(
|
||||||
|
@ -115,7 +119,7 @@ def from_onnx(model: onnx.ModelProto):
|
||||||
tensors[node.output[0]] = handler.reshape(
|
tensors[node.output[0]] = handler.reshape(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors.get(node.output[0]),
|
tensors.get(node.output[0]),
|
||||||
[int(i) for i in tensors[node.input[1]]],
|
data[node.input[1]].int32_data,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||||
|
|
|
@ -129,12 +129,14 @@ class TestStringMethods(unittest.TestCase):
|
||||||
flatten = make_node("Flatten", ["x"], ["y"], name="flatten")
|
flatten = make_node("Flatten", ["x"], ["y"], name="flatten")
|
||||||
make_and_import_model(make_graph([flatten], "flatten", [x], [y]))
|
make_and_import_model(make_graph([flatten], "flatten", [x], [y]))
|
||||||
|
|
||||||
# FIXME INT64 类型不支持
|
def test_reshape(self):
|
||||||
# def test_reshape(self):
|
data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5])
|
||||||
# data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5])
|
shape = make_tensor_value_info("shape", TensorProto.INT64, [3, 5, 8])
|
||||||
# shape = make_tensor_value_info("shape", TensorProto.INT64, [3, 5, 8])
|
reshaped = make_tensor_value_info("reshaped", TensorProto.FLOAT, [3, 5, 8])
|
||||||
# reshaped = make_tensor_value_info("reshaped", TensorProto.FLOAT, [3, 5, 8])
|
reshape = make_node("Reshape", ["data", "shape"], ["reshaped"], name="reshape")
|
||||||
# reshape = make_node("Reshape", ["data", "shape"], ["reshaped"], name="reshape")
|
# FIXME shape 对于 onnx 来说是输入张量,但对于后端来说不是,导入时无法分辨这个情况。
|
||||||
|
# tensor 的类型又不支持 INT64,所以这里会报一个错。
|
||||||
|
# 如何分辨 onnx 的张量是不是需要作为张量注册?
|
||||||
# make_and_import_model(
|
# make_and_import_model(
|
||||||
# make_graph([reshape], "reshape", [data, shape], [reshaped])
|
# make_graph([reshape], "reshape", [data, shape], [reshaped])
|
||||||
# )
|
# )
|
||||||
|
|
Loading…
Reference in New Issue