refactor(frontend): 先排序后构图

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-12-25 17:58:29 +08:00
parent ce23b8356f
commit c34946a0d8
1 changed files with 736 additions and 761 deletions

View File

@ -59,6 +59,31 @@ class OnnxStub:
warnings.warn("infer_shapes failed.")
self.handler = backend.GraphHandler(runtime)
# 处理重名和匿名算子
names = {}
for node in model.graph.node:
if node.name == "":
node.name = "missing_name(" + node.op_type + ")"
if node.name in names:
names[node.name] += 1
node.name += "_" + str(names[node.name])
else:
names[node.name] = 0
# 拓扑排序
sorted_nodes = []
known_edge = set(t.name for t in model.graph.input)
known_edge.update(t.name for t in model.graph.initializer)
while len(sorted_nodes) < len(model.graph.node):
updated = False
for i, node in enumerate(model.graph.node):
if all(t in known_edge for t in node.input):
node.name = str(len(sorted_nodes)) + "_" + node.name
sorted_nodes.append(i)
known_edge.update(node.output)
updated = True
if not updated:
raise Exception("Graph has cycle")
tensors: Dict[str, backend.Tensor] = dict()
data: Dict[str, TensorProto] = dict()
@ -83,17 +108,8 @@ class OnnxStub:
)
tensors[output.name].set_output()
node_name = []
new_node_name = []
for node in model.graph.node:
node_name.append(node.name)
node_list = model.graph.node
while len(node_list) != 0:
for node in model.graph.node:
if node.name not in node_list:
continue
if _analyse_node(node, tensors):
continue
for node_idx in sorted_nodes:
node = model.graph.node[node_idx]
if node.op_type == "Conv":
attributes = _parse_attribute(
node,
@ -201,8 +217,7 @@ class OnnxStub:
node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
)
(alpha, beta, transA, transB) = (
attributes[name]
for name in ["alpha", "beta", "transA", "transB"]
attributes[name] for name in ["alpha", "beta", "transA", "transB"]
)
# FIXME unsupport attributes: `alpha` `beta`
assert alpha == 1.0
@ -571,9 +586,7 @@ class OnnxStub:
tensors[node.output[0]] = self.handler.concat(
[tensors[name] for name in node.input],
tensors.get(node.output[0]),
next(
(attr.i for attr in node.attribute if attr.name == "axis")
),
next((attr.i for attr in node.attribute if attr.name == "axis")),
)
elif node.op_type == "AttentionKVCache":
tensors[node.output[0]] = self.handler.attentionKVCache(
@ -592,11 +605,7 @@ class OnnxStub:
tensors[node.input[0]],
None,
next(
(
attr.i
for attr in node.attribute
if attr.name == "axis"
),
(attr.i for attr in node.attribute if attr.name == "axis"),
0,
),
len(node.output),
@ -629,19 +638,11 @@ class OnnxStub:
tensors.get(node.output[0]),
# NOTE(constroy): `axes` is an attribute until opset version 13.
next(
(
attr.ints
for attr in node.attribute
if attr.name == "axes"
),
(attr.ints for attr in node.attribute if attr.name == "axes"),
None,
),
next(
(
attr.i
for attr in node.attribute
if attr.name == "keepdims"
),
(attr.i for attr in node.attribute if attr.name == "keepdims"),
1,
)
!= 0,
@ -669,9 +670,7 @@ class OnnxStub:
tensors[node.input[0]],
tensors.get(node.output[0]),
_parse_data(data[node.input[1]]),
_parse_data(data[node.input[3]])
if len(node.input) > 3
else None,
_parse_data(data[node.input[3]]) if len(node.input) > 3 else None,
)
elif node.op_type == "Dropout":
for name, tensor in zip(
@ -679,9 +678,7 @@ class OnnxStub:
self.handler.dropout(
tensors[node.input[0]],
tensors.get(node.output[0]),
tensors.get(node.output[1])
if len(node.output) > 1
else None,
tensors.get(node.output[1]) if len(node.output) > 1 else None,
_parse_data(data[node.input[1]])[0]
if len(node.input) > 1
else 0.5,
@ -785,11 +782,7 @@ class OnnxStub:
0,
)
destination = next(
(
attr.i
for attr in node.attribute
if attr.name == "destination"
),
(attr.i for attr in node.attribute if attr.name == "destination"),
0,
)
@ -805,11 +798,7 @@ class OnnxStub:
0,
)
destination = next(
(
attr.i
for attr in node.attribute
if attr.name == "destination"
),
(attr.i for attr in node.attribute if attr.name == "destination"),
0,
)
@ -861,16 +850,12 @@ class OnnxStub:
elif node.op_type == "DynamicQuantizeLinear":
for name, tensor in zip(
node.output,
self.handler.dynamicQuantizeLinear(
tensors[node.input[0]], None
),
self.handler.dynamicQuantizeLinear(tensors[node.input[0]], None),
):
tensors[name] = tensor
elif node.op_type == "DequantizeLinear":
(inputX, inputScale) = (tensors[node.input[i]] for i in [0, 1])
inputZeroPoint = (
None if len(node.input) < 3 else tensors[node.input[2]]
)
inputZeroPoint = None if len(node.input) < 3 else tensors[node.input[2]]
output = tensors.get(node.output[0])
axis = next(
(attr.i for attr in node.attribute if attr.name == "axis"),
@ -893,9 +878,6 @@ class OnnxStub:
)
else:
raise Exception('Unsupported operator "{}"'.format(node.op_type))
new_node_name.append(node.name)
# update the node_list
node_list = list(set(node_name) - set(new_node_name))
################################
# Allocate memory space for data
@ -1361,10 +1343,3 @@ def _parse_data_fp16(tensor: TensorProto):
def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
def _analyse_node(node: NodeProto, tensors) -> bool:
for i in node.input:
if i not in tensors:
return True
return False