forked from jiuyuan/InfiniTensor
refactor(frontend): sort nodes before building the graph
Signed-off-by: YdrMaster <ydrml@hotmail.com>
parent ce23b8356f
commit c34946a0d8
@@ -59,6 +59,31 @@ class OnnxStub:
             warnings.warn("infer_shapes failed.")
         self.handler = backend.GraphHandler(runtime)
 
+        # Handle duplicate and anonymous operator names
+        names = {}
+        for node in model.graph.node:
+            if node.name == "":
+                node.name = "missing_name(" + node.op_type + ")"
+            if node.name in names:
+                names[node.name] += 1
+                node.name += "_" + str(names[node.name])
+            else:
+                names[node.name] = 0
+        # Topological sort
+        sorted_nodes = []
+        known_edge = set(t.name for t in model.graph.input)
+        known_edge.update(t.name for t in model.graph.initializer)
+        while len(sorted_nodes) < len(model.graph.node):
+            updated = False
+            for i, node in enumerate(model.graph.node):
+                if all(t in known_edge for t in node.input):
+                    node.name = str(len(sorted_nodes)) + "_" + node.name
+                    sorted_nodes.append(i)
+                    known_edge.update(node.output)
+                    updated = True
+            if not updated:
+                raise Exception("Graph has cycle")
+
         tensors: Dict[str, backend.Tensor] = dict()
         data: Dict[str, TensorProto] = dict()
 
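The block added above is a naive O(n²) topological sort: a node becomes schedulable once every one of its inputs is a graph input, an initializer, or an output of an already-scheduled node, and a full sweep that schedules nothing means the graph has a cycle (the ONNX IR spec requires node lists to be topologically sorted, so well-formed models finish in one sweep). A minimal standalone sketch of the same idea, for illustration only: Node, topo_sort, and the done flag are hypothetical stand-ins, not code from this commit, and done additionally keeps a node from being appended twice when more than one sweep is needed.

from typing import List, NamedTuple, Set

class Node(NamedTuple):
    # Stand-in for onnx.NodeProto with just the fields the sort reads.
    name: str
    input: List[str]
    output: List[str]

def topo_sort(nodes: List[Node], known_edge: Set[str]) -> List[int]:
    # Naive O(n^2) sort mirroring the loop above: sweep until every
    # node is scheduled, or a sweep makes no progress (cycle).
    order: List[int] = []
    done = [False] * len(nodes)
    while len(order) < len(nodes):
        updated = False
        for i, node in enumerate(nodes):
            if not done[i] and all(t in known_edge for t in node.input):
                done[i] = True
                order.append(i)
                known_edge.update(node.output)  # outputs become known edges
                updated = True
        if not updated:
            raise Exception("Graph has cycle")
    return order

# "b" consumes t1, which "a" produces from graph input x, so "a" (index 1)
# must be scheduled before "b" (index 0).
nodes = [Node("b", ["t1"], ["t2"]), Node("a", ["x"], ["t1"])]
assert topo_sort(nodes, {"x"}) == [1, 0]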
@@ -83,17 +108,8 @@ class OnnxStub:
             )
             tensors[output.name].set_output()
 
-        node_name = []
-        new_node_name = []
-        for node in model.graph.node:
-            node_name.append(node.name)
-        node_list = model.graph.node
-        while len(node_list) != 0:
-            for node in model.graph.node:
-                if node.name not in node_list:
-                    continue
-                if _analyse_node(node, tensors):
-                    continue
+        for node_idx in sorted_nodes:
+            node = model.graph.node[node_idx]
             if node.op_type == "Conv":
                 attributes = _parse_attribute(
                     node,
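With the order precomputed, the hunk above drops the node_name/new_node_name bookkeeping and the per-sweep readiness test (_analyse_node, deleted at the end of this diff): the loop walks sorted_nodes once, and every node it visits already has all of its input tensors materialized. A short continuation of the sketch above (build_graph is hypothetical; it only records tensor names where the real loop calls into self.handler):

def build_graph(nodes: List[Node], order: List[int]) -> Set[str]:
    known: Set[str] = {"x"}  # graph inputs / initializers, seeded up front
    for node_idx in order:
        node = nodes[node_idx]
        # Holds by construction of `order`; no readiness re-check needed.
        assert all(t in known for t in node.input)
        known.update(node.output)
    return known

assert build_graph(nodes, topo_sort(nodes, {"x"})) == {"x", "t1", "t2"}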
@@ -201,8 +217,7 @@ class OnnxStub:
                     node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
                 )
                 (alpha, beta, transA, transB) = (
-                    attributes[name]
-                    for name in ["alpha", "beta", "transA", "transB"]
+                    attributes[name] for name in ["alpha", "beta", "transA", "transB"]
                 )
                 # FIXME unsupport attributes: `alpha` `beta`
                 assert alpha == 1.0
@@ -571,9 +586,7 @@ class OnnxStub:
                 tensors[node.output[0]] = self.handler.concat(
                     [tensors[name] for name in node.input],
                     tensors.get(node.output[0]),
-                    next(
-                        (attr.i for attr in node.attribute if attr.name == "axis")
-                    ),
+                    next((attr.i for attr in node.attribute if attr.name == "axis")),
                 )
             elif node.op_type == "AttentionKVCache":
                 tensors[node.output[0]] = self.handler.attentionKVCache(
@@ -592,11 +605,7 @@ class OnnxStub:
                     tensors[node.input[0]],
                     None,
                     next(
-                        (
-                            attr.i
-                            for attr in node.attribute
-                            if attr.name == "axis"
-                        ),
+                        (attr.i for attr in node.attribute if attr.name == "axis"),
                         0,
                     ),
                     len(node.output),
@@ -629,19 +638,11 @@ class OnnxStub:
                     tensors.get(node.output[0]),
                     # NOTE(constroy): `axes` is an attribute until opset version 13.
                     next(
-                        (
-                            attr.ints
-                            for attr in node.attribute
-                            if attr.name == "axes"
-                        ),
+                        (attr.ints for attr in node.attribute if attr.name == "axes"),
                         None,
                     ),
                     next(
-                        (
-                            attr.i
-                            for attr in node.attribute
-                            if attr.name == "keepdims"
-                        ),
+                        (attr.i for attr in node.attribute if attr.name == "keepdims"),
                         1,
                     )
                     != 0,
@@ -669,9 +670,7 @@ class OnnxStub:
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     _parse_data(data[node.input[1]]),
-                    _parse_data(data[node.input[3]])
-                    if len(node.input) > 3
-                    else None,
+                    _parse_data(data[node.input[3]]) if len(node.input) > 3 else None,
                 )
             elif node.op_type == "Dropout":
                 for name, tensor in zip(
@@ -679,9 +678,7 @@ class OnnxStub:
                     self.handler.dropout(
                         tensors[node.input[0]],
                         tensors.get(node.output[0]),
-                        tensors.get(node.output[1])
-                        if len(node.output) > 1
-                        else None,
+                        tensors.get(node.output[1]) if len(node.output) > 1 else None,
                         _parse_data(data[node.input[1]])[0]
                         if len(node.input) > 1
                         else 0.5,
@@ -785,11 +782,7 @@ class OnnxStub:
                     0,
                 )
                 destination = next(
-                    (
-                        attr.i
-                        for attr in node.attribute
-                        if attr.name == "destination"
-                    ),
+                    (attr.i for attr in node.attribute if attr.name == "destination"),
                     0,
                 )
 
@@ -805,11 +798,7 @@ class OnnxStub:
                     0,
                 )
                 destination = next(
-                    (
-                        attr.i
-                        for attr in node.attribute
-                        if attr.name == "destination"
-                    ),
+                    (attr.i for attr in node.attribute if attr.name == "destination"),
                     0,
                 )
 
@@ -861,16 +850,12 @@ class OnnxStub:
             elif node.op_type == "DynamicQuantizeLinear":
                 for name, tensor in zip(
                     node.output,
-                    self.handler.dynamicQuantizeLinear(
-                        tensors[node.input[0]], None
-                    ),
+                    self.handler.dynamicQuantizeLinear(tensors[node.input[0]], None),
                 ):
                     tensors[name] = tensor
             elif node.op_type == "DequantizeLinear":
                 (inputX, inputScale) = (tensors[node.input[i]] for i in [0, 1])
-                inputZeroPoint = (
-                    None if len(node.input) < 3 else tensors[node.input[2]]
-                )
+                inputZeroPoint = None if len(node.input) < 3 else tensors[node.input[2]]
                 output = tensors.get(node.output[0])
                 axis = next(
                     (attr.i for attr in node.attribute if attr.name == "axis"),
@@ -893,9 +878,6 @@ class OnnxStub:
                 )
             else:
                 raise Exception('Unsupported operator "{}"'.format(node.op_type))
-            new_node_name.append(node.name)
-        # update the node_list
-        node_list = list(set(node_name) - set(new_node_name))
 
         ################################
         # Allocate memory space for data
@@ -1361,10 +1343,3 @@ def _parse_data_fp16(tensor: TensorProto):
 
 def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
     return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
-
-
-def _analyse_node(node: NodeProto, tensors) -> bool:
-    for i in node.input:
-        if i not in tensors:
-            return True
-    return False