feat: 支持 tensor attribute

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-09-08 16:57:38 +08:00
parent ed66ab6a37
commit ea1d2c0129
2 changed files with 15 additions and 14 deletions

1
example Submodule

@ -0,0 +1 @@
Subproject commit 51d3105277f3774ed31c02ed4cd11fa92925af77

View File

@ -1,23 +1,16 @@
import backend
from onnx import ModelProto, NodeProto, AttributeProto, numpy_helper
from onnx import ModelProto, NodeProto, TensorProto, AttributeProto, numpy_helper
from backend import DimExpr, refactor_tensor
from typing import Any
def build_graph(model: ModelProto):
print("Building graph...")
nodes: dict[str, backend.Node] = dict()
edges: dict[str, backend.Edge] = dict()
topology: dict[str, tuple[list[str], list[str]]] = dict()
for tensor in model.graph.initializer:
edges[tensor.name] = refactor_tensor(
tensor.data_type,
[DimExpr(d) for d in tensor.dims],
[b for b in numpy_helper.to_array(tensor).data.tobytes()],
)
print("tensor: {}".format(tensor.name))
edges[tensor.name] = _parse_tensor(tensor)
for tensor in model.graph.input:
if tensor.name not in edges:
@ -28,14 +21,12 @@ def build_graph(model: ModelProto):
edges[tensor.name] = refactor_tensor(
tensor.type.tensor_type.elem_type, dim, None
)
print("tensor: {}".format(tensor.name))
for node in model.graph.node:
topology[node.name] = ([i for i in node.input], [o for o in node.output])
nodes[node.name] = backend.refactor_operator(
node.op_type, _parse_attribute(node)
)
print("node: {}".format(node.name))
graph = backend.refactor_graph(
topology,
@ -44,7 +35,14 @@ def build_graph(model: ModelProto):
[i.name for i in model.graph.input],
[o.name for o in model.graph.output],
)
print("Graph built.")
def _parse_tensor(tensor: TensorProto) -> backend.Tensor:
refactor_tensor(
tensor.data_type,
[DimExpr(d) for d in tensor.dims],
[b for b in numpy_helper.to_array(tensor).data.tobytes()],
)
def _parse_attribute(node: NodeProto) -> dict[str, Any]:
@ -62,8 +60,10 @@ def _parse_attribute(node: NodeProto) -> dict[str, Any]:
ans[attr.name] = attr.s
elif attr.type == AttributeProto.STRINGS:
ans[attr.name] = attr.strings
# elif attr.type == AttributeProto.TENSOR:
# ans[attr.name] = attr.t
elif attr.type == AttributeProto.TENSOR:
ans[attr.name] = _parse_tensor(attr.t)
elif attr.type == AttributeProto.TENSORS:
ans[attr.name] = [_parse_tensor(t) for t in attr.tensors]
else:
assert False, "Unsupported Attribute Type: {}".format(attr.type)
return ans