forked from jiuyuan/InfiniTensor
feat: 支持 tensor attribute
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
ed66ab6a37
commit
ea1d2c0129
|
@ -0,0 +1 @@
|
|||
Subproject commit 51d3105277f3774ed31c02ed4cd11fa92925af77
|
|
@ -1,23 +1,16 @@
|
|||
import backend
|
||||
from onnx import ModelProto, NodeProto, AttributeProto, numpy_helper
|
||||
from onnx import ModelProto, NodeProto, TensorProto, AttributeProto, numpy_helper
|
||||
from backend import DimExpr, refactor_tensor
|
||||
from typing import Any
|
||||
|
||||
|
||||
def build_graph(model: ModelProto):
|
||||
print("Building graph...")
|
||||
|
||||
nodes: dict[str, backend.Node] = dict()
|
||||
edges: dict[str, backend.Edge] = dict()
|
||||
topology: dict[str, tuple[list[str], list[str]]] = dict()
|
||||
|
||||
for tensor in model.graph.initializer:
|
||||
edges[tensor.name] = refactor_tensor(
|
||||
tensor.data_type,
|
||||
[DimExpr(d) for d in tensor.dims],
|
||||
[b for b in numpy_helper.to_array(tensor).data.tobytes()],
|
||||
)
|
||||
print("tensor: {}".format(tensor.name))
|
||||
edges[tensor.name] = _parse_tensor(tensor)
|
||||
|
||||
for tensor in model.graph.input:
|
||||
if tensor.name not in edges:
|
||||
|
@ -28,14 +21,12 @@ def build_graph(model: ModelProto):
|
|||
edges[tensor.name] = refactor_tensor(
|
||||
tensor.type.tensor_type.elem_type, dim, None
|
||||
)
|
||||
print("tensor: {}".format(tensor.name))
|
||||
|
||||
for node in model.graph.node:
|
||||
topology[node.name] = ([i for i in node.input], [o for o in node.output])
|
||||
nodes[node.name] = backend.refactor_operator(
|
||||
node.op_type, _parse_attribute(node)
|
||||
)
|
||||
print("node: {}".format(node.name))
|
||||
|
||||
graph = backend.refactor_graph(
|
||||
topology,
|
||||
|
@ -44,7 +35,14 @@ def build_graph(model: ModelProto):
|
|||
[i.name for i in model.graph.input],
|
||||
[o.name for o in model.graph.output],
|
||||
)
|
||||
print("Graph built.")
|
||||
|
||||
|
||||
def _parse_tensor(tensor: TensorProto) -> backend.Tensor:
|
||||
refactor_tensor(
|
||||
tensor.data_type,
|
||||
[DimExpr(d) for d in tensor.dims],
|
||||
[b for b in numpy_helper.to_array(tensor).data.tobytes()],
|
||||
)
|
||||
|
||||
|
||||
def _parse_attribute(node: NodeProto) -> dict[str, Any]:
|
||||
|
@ -62,8 +60,10 @@ def _parse_attribute(node: NodeProto) -> dict[str, Any]:
|
|||
ans[attr.name] = attr.s
|
||||
elif attr.type == AttributeProto.STRINGS:
|
||||
ans[attr.name] = attr.strings
|
||||
# elif attr.type == AttributeProto.TENSOR:
|
||||
# ans[attr.name] = attr.t
|
||||
elif attr.type == AttributeProto.TENSOR:
|
||||
ans[attr.name] = _parse_tensor(attr.t)
|
||||
elif attr.type == AttributeProto.TENSORS:
|
||||
ans[attr.name] = [_parse_tensor(t) for t in attr.tensors]
|
||||
else:
|
||||
assert False, "Unsupported Attribute Type: {}".format(attr.type)
|
||||
return ans
|
||||
|
|
Loading…
Reference in New Issue