diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index 74ded944..3c66de4d 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -59,6 +59,31 @@ class OnnxStub:
             warnings.warn("infer_shapes failed.")
         self.handler = backend.GraphHandler(runtime)
 
+        # Handle anonymous and duplicate operator names
+        names = {}
+        for node in model.graph.node:
+            if node.name == "":
+                node.name = "missing_name(" + node.op_type + ")"
+            if node.name in names:
+                names[node.name] += 1
+                node.name += "_" + str(names[node.name])
+            else:
+                names[node.name] = 0
+        # Topological sort
+        sorted_nodes = []
+        known_edge = set(t.name for t in model.graph.input)
+        known_edge.update(t.name for t in model.graph.initializer)
+        while len(sorted_nodes) < len(model.graph.node):
+            updated = False
+            for i, node in enumerate(model.graph.node):
+                if i not in sorted_nodes and all(t in known_edge for t in node.input):
+                    node.name = str(len(sorted_nodes)) + "_" + node.name
+                    sorted_nodes.append(i)
+                    known_edge.update(node.output)
+                    updated = True
+            if not updated:
+                raise Exception("Graph has cycle")
+
         tensors: Dict[str, backend.Tensor] = dict()
         data: Dict[str, TensorProto] = dict()
 
@@ -83,98 +108,64 @@ class OnnxStub:
             )
             tensors[output.name].set_output()
 
-        node_name = []
-        new_node_name = []
-        for node in model.graph.node:
-            node_name.append(node.name)
-        node_list = model.graph.node
-        while len(node_list) != 0:
-            for node in model.graph.node:
-                if node.name not in node_list:
-                    continue
-                if _analyse_node(node, tensors):
-                    continue
-                if node.op_type == "Conv":
-                    attributes = _parse_attribute(
-                        node,
-                        {
-                            "dilations": [1, 1],
-                            "pads": [0, 0, 0, 0],
-                            "strides": [1, 1],
-                        },
+        for node_idx in sorted_nodes:
+            node = model.graph.node[node_idx]
+            if node.op_type == "Conv":
+                attributes = _parse_attribute(
+                    node,
+                    {
+                        "dilations": [1, 1],
+                        "pads": [0, 0, 0, 0],
+                        "strides": [1, 1],
+                    },
+                )
+                (d, p, s) = (
+                    attributes[name] for name in ["dilations", "pads", "strides"]
+                )
+                if p[0] != p[2] or p[1] != p[3]:
+                    adapt = "{}-adapt".format(node.output[0])
+                    tensors[adapt] = self.handler.pad(
+                        tensors[node.input[0]], None, p, [-2, -1]
                     )
-                    (d, p, s) = (
-                        attributes[name] for name in ["dilations", "pads", "strides"]
-                    )
-                    if p[0] != p[2] or p[1] != p[3]:
-                        adapt = "{}-adapt".format(node.output[0])
-                        tensors[adapt] = self.handler.pad(
-                            tensors[node.input[0]], None, p, [-2, -1]
-                        )
-                        p = [0, 0, 0, 0]
-                    else:
-                        adapt = node.input[0]
+                    p = [0, 0, 0, 0]
+                else:
+                    adapt = node.input[0]
 
-                    if len(node.input) > 2:
-                        bias = "{}-bias".format(node.output[0])
-                        reshape = "{}-reshape".format(node.output[0])
-                        tensors[bias] = self.handler.conv(
-                            tensors[adapt],
-                            tensors[node.input[1]],
-                            None,
-                            p[0],
-                            p[1],
-                            s[0],
-                            s[1],
-                            d[0],
-                            d[1],
-                        )
-                        tensors[reshape] = self.handler.reshape(
-                            tensors[node.input[2]],
-                            None,
-                            [
-                                1,
-                                reduce(
-                                    lambda acc, x: acc * x,
-                                    _search_shape(model, node.input[2]),
-                                ),
-                                1,
-                                1,
-                            ],
-                        )
-                        tensors[node.output[0]] = self.handler.add(
-                            tensors[bias],
-                            tensors[reshape],
-                            tensors.get(node.output[0]),
-                        )
-                    else:
-                        tensors[node.output[0]] = self.handler.conv(
-                            tensors[adapt],
-                            tensors[node.input[1]],
-                            tensors.get(node.output[0]),
-                            p[0],
-                            p[1],
-                            s[0],
-                            s[1],
-                            d[0],
-                            d[1],
-                        )
-                elif node.op_type == "ConvTranspose":
-                    attributes = _parse_attribute(
-                        node,
-                        {
-                            "dilations": [1, 1],
-                            "pads": [0, 0],
-                            "strides": [1, 1],
-                            "output_padding": [0, 0],
-                        },
+                if len(node.input) > 2:
+                    bias = "{}-bias".format(node.output[0])
+                    reshape = "{}-reshape".format(node.output[0])
+                    tensors[bias] = 
self.handler.conv( + tensors[adapt], + tensors[node.input[1]], + None, + p[0], + p[1], + s[0], + s[1], + d[0], + d[1], ) - (d, p, s, op) = ( - attributes[name] - for name in ["dilations", "pads", "strides", "output_padding"] + tensors[reshape] = self.handler.reshape( + tensors[node.input[2]], + None, + [ + 1, + reduce( + lambda acc, x: acc * x, + _search_shape(model, node.input[2]), + ), + 1, + 1, + ], ) - tensors[node.output[0]] = self.handler.convTransposed2d( - tensors[node.input[0]], + tensors[node.output[0]] = self.handler.add( + tensors[bias], + tensors[reshape], + tensors.get(node.output[0]), + ) + else: + tensors[node.output[0]] = self.handler.conv( + tensors[adapt], tensors[node.input[1]], tensors.get(node.output[0]), p[0], @@ -183,459 +174,547 @@ class OnnxStub: s[1], d[0], d[1], - op[0], - op[1], ) - elif node.op_type == "MatMul": - tensors[node.output[0]] = self.handler.matmul( - tensors[node.input[0]], - tensors[node.input[1]], + elif node.op_type == "ConvTranspose": + attributes = _parse_attribute( + node, + { + "dilations": [1, 1], + "pads": [0, 0], + "strides": [1, 1], + "output_padding": [0, 0], + }, + ) + (d, p, s, op) = ( + attributes[name] + for name in ["dilations", "pads", "strides", "output_padding"] + ) + tensors[node.output[0]] = self.handler.convTransposed2d( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + p[0], + p[1], + s[0], + s[1], + d[0], + d[1], + op[0], + op[1], + ) + elif node.op_type == "MatMul": + tensors[node.output[0]] = self.handler.matmul( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + False, + False, + None, + backend.ActType.Linear, + ) + elif node.op_type == "Gemm": + attributes = _parse_attribute( + node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0} + ) + (alpha, beta, transA, transB) = ( + attributes[name] for name in ["alpha", "beta", "transA", "transB"] + ) + # FIXME unsupport attributes: `alpha` `beta` + assert alpha == 1.0 + assert beta == 1.0 + tensors[node.output[0]] = self.handler.matmul( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + transA == 1, + transB == 1, + tensors[node.input[2]] if len(node.input) > 2 else None, + backend.ActType.Linear, + ) + elif node.op_type == "BatchNormalization": + (input, mean, var, scale, bias) = ( + tensors[node.input[i]] for i in [0, 3, 4, 1, 2] + ) + output = tensors.get(node.output[0]) + attributes = _parse_attribute( + node, {"momentum": 0.9, "epsilon": 1e-05, "training_mode": 0} + ) + (momentum, eps, training) = ( + attributes[name] + for name in ["momentum", "epsilon", "training_mode"] + ) + tensors[node.output[0]] = self.handler.batchNormalization( + input, + output, + mean, + var, + scale, + bias, + momentum, + eps, + training != 0, + ) + elif node.op_type == "LayerNormalization": + (input, scale) = (tensors[node.input[i]] for i in [0, 1]) + bias = None if len(node.input) < 3 else tensors[node.input[2]] + output = tensors.get(node.output[0]) + attributes = _parse_attribute( + node, {"axis": -1, "epsilon": 1e-05, "stash_type": 1} + ) + (axis, eps, stash_type) = ( + attributes[name] for name in ["axis", "epsilon", "stash_type"] + ) + tensors[node.output[0]] = self.handler.layerNormalization( + input, + scale, + output, + bias, + eps, + axis, + stash_type, + ) + elif node.op_type == "MaxPool": + attributes = _parse_attribute( + node, + { + "kernel_shape": None, + "dilations": [1, 1], + "pads": [0, 0, 0, 0], + "strides": [1, 1], + "ceil_mode": 0, + }, + ) + (k, d, p, s, ceil_mode) 
= ( + attributes[name] + for name in [ + "kernel_shape", + "dilations", + "pads", + "strides", + "ceil_mode", + ] + ) + if p[0] != p[2] or p[1] != p[3]: + adapt = "{}-adapt".format(node.output[0]) + tensors[adapt] = self.handler.pad( + tensors.get(node.input[0]), None, p, [-2, -1] + ) + tensors[node.output[0]] = self.handler.maxPool( + tensors[adapt], tensors.get(node.output[0]), - False, - False, - None, - backend.ActType.Linear, + k[0], + k[1], + d[0], + d[1], + 0, + 0, + s[0], + s[1], + ceil_mode, ) - elif node.op_type == "Gemm": - attributes = _parse_attribute( - node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0} - ) - (alpha, beta, transA, transB) = ( - attributes[name] - for name in ["alpha", "beta", "transA", "transB"] - ) - # FIXME unsupport attributes: `alpha` `beta` - assert alpha == 1.0 - assert beta == 1.0 - tensors[node.output[0]] = self.handler.matmul( + else: + tensors[node.output[0]] = self.handler.maxPool( tensors[node.input[0]], - tensors[node.input[1]], tensors.get(node.output[0]), - transA == 1, - transB == 1, - tensors[node.input[2]] if len(node.input) > 2 else None, - backend.ActType.Linear, + k[0], + k[1], + d[0], + d[1], + p[0], + p[1], + s[0], + s[1], + ceil_mode, ) - elif node.op_type == "BatchNormalization": - (input, mean, var, scale, bias) = ( - tensors[node.input[i]] for i in [0, 3, 4, 1, 2] + elif node.op_type == "AveragePool": + attributes = _parse_attribute( + node, + { + "kernel_shape": None, + "pads": [0, 0, 0, 0], + "strides": [1, 1], + "ceil_mode": 0, + }, + ) + (k, p, s, ceil_mode) = ( + attributes[name] + for name in ["kernel_shape", "pads", "strides", "ceil_mode"] + ) + if p[0] != p[2] or p[1] != p[3]: + adapt = "{}-adapt".format(node.output[0]) + tensors[adapt] = self.handler.pad( + tensors.get(node.input[0]), None, p, [-2, -1] ) - output = tensors.get(node.output[0]) - attributes = _parse_attribute( - node, {"momentum": 0.9, "epsilon": 1e-05, "training_mode": 0} + tensors[node.output[0]] = self.handler.avgPool( + tensors[adapt], + tensors.get(node.output[0]), + k[0], + k[1], + 1, + 1, + 0, + 0, + s[0], + s[1], + ceil_mode, ) - (momentum, eps, training) = ( - attributes[name] - for name in ["momentum", "epsilon", "training_mode"] - ) - tensors[node.output[0]] = self.handler.batchNormalization( - input, - output, - mean, - var, - scale, - bias, - momentum, - eps, - training != 0, - ) - elif node.op_type == "LayerNormalization": - (input, scale) = (tensors[node.input[i]] for i in [0, 1]) - bias = None if len(node.input) < 3 else tensors[node.input[2]] - output = tensors.get(node.output[0]) - attributes = _parse_attribute( - node, {"axis": -1, "epsilon": 1e-05, "stash_type": 1} - ) - (axis, eps, stash_type) = ( - attributes[name] for name in ["axis", "epsilon", "stash_type"] - ) - tensors[node.output[0]] = self.handler.layerNormalization( - input, - scale, - output, - bias, - eps, - axis, - stash_type, - ) - elif node.op_type == "MaxPool": - attributes = _parse_attribute( - node, - { - "kernel_shape": None, - "dilations": [1, 1], - "pads": [0, 0, 0, 0], - "strides": [1, 1], - "ceil_mode": 0, - }, - ) - (k, d, p, s, ceil_mode) = ( - attributes[name] - for name in [ - "kernel_shape", - "dilations", - "pads", - "strides", - "ceil_mode", - ] - ) - if p[0] != p[2] or p[1] != p[3]: - adapt = "{}-adapt".format(node.output[0]) - tensors[adapt] = self.handler.pad( - tensors.get(node.input[0]), None, p, [-2, -1] - ) - tensors[node.output[0]] = self.handler.maxPool( - tensors[adapt], - tensors.get(node.output[0]), - k[0], - k[1], - d[0], - d[1], - 0, - 
0, - s[0], - s[1], - ceil_mode, - ) - else: - tensors[node.output[0]] = self.handler.maxPool( - tensors[node.input[0]], - tensors.get(node.output[0]), - k[0], - k[1], - d[0], - d[1], - p[0], - p[1], - s[0], - s[1], - ceil_mode, - ) - elif node.op_type == "AveragePool": - attributes = _parse_attribute( - node, - { - "kernel_shape": None, - "pads": [0, 0, 0, 0], - "strides": [1, 1], - "ceil_mode": 0, - }, - ) - (k, p, s, ceil_mode) = ( - attributes[name] - for name in ["kernel_shape", "pads", "strides", "ceil_mode"] - ) - if p[0] != p[2] or p[1] != p[3]: - adapt = "{}-adapt".format(node.output[0]) - tensors[adapt] = self.handler.pad( - tensors.get(node.input[0]), None, p, [-2, -1] - ) - tensors[node.output[0]] = self.handler.avgPool( - tensors[adapt], - tensors.get(node.output[0]), - k[0], - k[1], - 1, - 1, - 0, - 0, - s[0], - s[1], - ceil_mode, - ) - else: - tensors[node.output[0]] = self.handler.avgPool( - tensors[node.input[0]], - tensors.get(node.output[0]), - k[0], - k[1], - 1, - 1, - p[0], - p[1], - s[0], - s[1], - ceil_mode, - ) - elif node.op_type == "GlobalAveragePool": - [_, _, h, w] = _search_shape(model, node.input[0]) + else: tensors[node.output[0]] = self.handler.avgPool( tensors[node.input[0]], tensors.get(node.output[0]), - h, - w, + k[0], + k[1], 1, 1, - 0, - 0, + p[0], + p[1], + s[0], + s[1], + ceil_mode, + ) + elif node.op_type == "GlobalAveragePool": + [_, _, h, w] = _search_shape(model, node.input[0]) + tensors[node.output[0]] = self.handler.avgPool( + tensors[node.input[0]], + tensors.get(node.output[0]), + h, + w, + 1, + 1, + 0, + 0, + 1, + 1, + 0, + ) + elif node.op_type == "Add": + tensors[node.output[0]] = self.handler.add( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Sub": + tensors[node.output[0]] = self.handler.sub( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Mul": + tensors[node.output[0]] = self.handler.mul( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Div": + tensors[node.output[0]] = self.handler.div( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Pow": + tensors[node.output[0]] = self.handler.pow( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Min": + tensors[node.output[0]] = self.handler.min( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Max": + tensors[node.output[0]] = self.handler.max( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Relu": + tensors[node.output[0]] = self.handler.relu( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Gelu": + tensors[node.output[0]] = self.handler.gelu( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Sigmoid": + tensors[node.output[0]] = self.handler.sigmoid( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "HardSigmoid": + tensors[node.output[0]] = self.handler.hardSigmoid( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "HardSwish": + tensors[node.output[0]] = self.handler.hardSwish( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Tanh": + tensors[node.output[0]] = self.handler.tanh( + 
tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Softmax": + tensors[node.output[0]] = self.handler.softmax( + tensors[node.input[0]], + tensors.get(node.output[0]), + next( + (attr.i for attr in node.attribute if attr.name == "axis"), + -1, + ), + ) + elif node.op_type == "Abs": + tensors[node.output[0]] = self.handler.abs( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Sqrt": + tensors[node.output[0]] = self.handler.sqrt( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Neg": + tensors[node.output[0]] = self.handler.neg( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Shape": + tensors[node.output[0]] = self.handler.shape( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Identity": + tensors[node.output[0]] = self.handler.identity( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Flatten": + tensors[node.output[0]] = self.handler.flatten( + tensors[node.input[0]], + tensors.get(node.output[0]), + next( + (attr.i for attr in node.attribute if attr.name == "axis"), 1, - 1, - 0, - ) - elif node.op_type == "Add": - tensors[node.output[0]] = self.handler.add( + ), + ) + elif node.op_type == "PRelu": + tensors[node.output[0]] = self.handler.pRelu( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Clip": + tensors[node.output[0]] = self.handler.clip( + tensors[node.input[0]], + tensors.get(node.output[0]), + next(_parse_data(data[node.input[1]]).__iter__(), None) + if len(node.input) > 1 + else None, + next(_parse_data(data[node.input[2]]).__iter__(), None) + if len(node.input) > 2 + else None, + ) + elif node.op_type == "Transpose": + perm = next( + (attr.ints for attr in node.attribute if attr.name == "perm"), + None, + ) + tensors[node.output[0]] = self.handler.transpose( + tensors[node.input[0]], + tensors.get(node.output[0]), + perm, + ) + elif node.op_type == "DepthToSpace": + blocksize = next( + (attr.i for attr in node.attribute if attr.name == "blocksize"), + None, + ) + mode = next( + (attr.s for attr in node.attribute if attr.name == "mode"), + None, + ) + tensors[node.output[0]] = self.handler.depthToSpace( + tensors[node.input[0]], + tensors.get(node.output[0]), + blocksize, + mode, + ) + elif node.op_type == "Reshape": + shape = _parse_data(data[node.input[1]]) + tensors[node.output[0]] = self.handler.reshape( + tensors[node.input[0]], + tensors.get(node.output[0]), + shape, + ) + elif node.op_type == "Squeeze": + input_shape = _search_shape(model, node.input[0]) + axes = set( + [int(i) for i in data[node.input[1]].int64_data] + if len(node.input) > 1 + else _parse_attribute(node, {"axes": None})["axes"] + ) + assert all(input_shape[d] == 1 for d in axes) + output_shape = [] + for i, x in enumerate(input_shape): + if i not in axes: + output_shape.append(x) + tensors[node.output[0]] = self.handler.reshape( + tensors[node.input[0]], + tensors.get(node.output[0]), + output_shape, + ) + elif node.op_type == "Unsqueeze": + input_shape = _search_shape(model, node.input[0]) + axes = ( + [int(i) for i in data[node.input[1]].int64_data] + if len(node.input) > 1 + else _parse_attribute(node, {"axes": None})["axes"] + ) + for i in axes: + input_shape.insert(i, 1) + tensors[node.output[0]] = self.handler.reshape( + tensors[node.input[0]], + tensors.get(node.output[0]), + input_shape, + ) + elif node.op_type == "Concat": 
+ tensors[node.output[0]] = self.handler.concat( + [tensors[name] for name in node.input], + tensors.get(node.output[0]), + next((attr.i for attr in node.attribute if attr.name == "axis")), + ) + elif node.op_type == "AttentionKVCache": + tensors[node.output[0]] = self.handler.attentionKVCache( + tensors[node.input[0]], + tensors[node.input[1]], + tensors[node.input[2]], + tensors[node.input[3]], + tensors[node.input[4]], + tensors[node.input[5]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Split": + for name, tensor in zip( + node.output, + self.handler.split( tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Sub": - tensors[node.output[0]] = self.handler.sub( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Mul": - tensors[node.output[0]] = self.handler.mul( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Div": - tensors[node.output[0]] = self.handler.div( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Pow": - tensors[node.output[0]] = self.handler.pow( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Min": - tensors[node.output[0]] = self.handler.min( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Max": - tensors[node.output[0]] = self.handler.max( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Relu": - tensors[node.output[0]] = self.handler.relu( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Gelu": - tensors[node.output[0]] = self.handler.gelu( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Sigmoid": - tensors[node.output[0]] = self.handler.sigmoid( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "HardSigmoid": - tensors[node.output[0]] = self.handler.hardSigmoid( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "HardSwish": - tensors[node.output[0]] = self.handler.hardSwish( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Tanh": - tensors[node.output[0]] = self.handler.tanh( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Softmax": - tensors[node.output[0]] = self.handler.softmax( - tensors[node.input[0]], - tensors.get(node.output[0]), + None, next( (attr.i for attr in node.attribute if attr.name == "axis"), - -1, + 0, ), + len(node.output), + ), + ): + tensors[name] = tensor + elif node.op_type == "Gather": + tensors[node.output[0]] = self.handler.gather( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + next( + (attr.i for attr in node.attribute if attr.name == "axis"), + 0, + ), + ) + elif node.op_type == "GatherElements": + tensors[node.output[0]] = self.handler.gatherElements( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + next( + (attr.i for attr in node.attribute if attr.name == "axis"), + 0, + ), + ) + elif node.op_type == "ReduceMean": + tensors[node.output[0]] = self.handler.reduceMean( + tensors[node.input[0]], + tensors.get(node.output[0]), + # NOTE(constroy): `axes` is an attribute until opset version 13. 
+ next( + (attr.ints for attr in node.attribute if attr.name == "axes"), + None, + ), + next( + (attr.i for attr in node.attribute if attr.name == "keepdims"), + 1, ) - elif node.op_type == "Abs": - tensors[node.output[0]] = self.handler.abs( + != 0, + ) + elif node.op_type == "Slice": + + def clamp(nums): + MAX_INT = 0x7FFFFFFF + return [min(x, MAX_INT) for x in nums] + + tensors[node.output[0]] = self.handler.slice( + tensors[node.input[0]], + tensors.get(node.output[0]), + clamp(_parse_data(data[node.input[1]])), + clamp(_parse_data(data[node.input[2]])), + clamp(_parse_data(data[node.input[3]])) + if len(node.input) > 3 + else None, + clamp(_parse_data(data[node.input[4]])) + if len(node.input) > 4 + else None, + ) + elif node.op_type == "Pad": + tensors[node.output[0]] = self.handler.pad( + tensors[node.input[0]], + tensors.get(node.output[0]), + _parse_data(data[node.input[1]]), + _parse_data(data[node.input[3]]) if len(node.input) > 3 else None, + ) + elif node.op_type == "Dropout": + for name, tensor in zip( + node.output, + self.handler.dropout( tensors[node.input[0]], tensors.get(node.output[0]), - ) - elif node.op_type == "Sqrt": - tensors[node.output[0]] = self.handler.sqrt( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Neg": - tensors[node.output[0]] = self.handler.neg( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Shape": - tensors[node.output[0]] = self.handler.shape( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Identity": - tensors[node.output[0]] = self.handler.identity( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Flatten": - tensors[node.output[0]] = self.handler.flatten( - tensors[node.input[0]], - tensors.get(node.output[0]), - next( - (attr.i for attr in node.attribute if attr.name == "axis"), - 1, - ), - ) - elif node.op_type == "PRelu": - tensors[node.output[0]] = self.handler.pRelu( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Clip": - tensors[node.output[0]] = self.handler.clip( - tensors[node.input[0]], - tensors.get(node.output[0]), - next(_parse_data(data[node.input[1]]).__iter__(), None) + tensors.get(node.output[1]) if len(node.output) > 1 else None, + _parse_data(data[node.input[1]])[0] if len(node.input) > 1 - else None, - next(_parse_data(data[node.input[2]]).__iter__(), None) + else 0.5, + _parse_data(data[node.input[2]])[0] if len(node.input) > 2 - else None, - ) - elif node.op_type == "Transpose": - perm = next( - (attr.ints for attr in node.attribute if attr.name == "perm"), - None, - ) - tensors[node.output[0]] = self.handler.transpose( + else False, + ), + ): + tensors[name] = tensor + elif node.op_type == "Cast": + tensors[node.output[0]] = self.handler.cast( + tensors[node.input[0]], + tensors.get(node.output[0]), + next((attr.i for attr in node.attribute if attr.name == "to")), + ) + elif node.op_type == "ReduceSum": + if any(attr.name == "communicator" for attr in node.attribute): + # ReduceSum with communicator is treated as allReduceSum. 
+ tensors[node.output[0]] = self.handler.allReduceSum( tensors[node.input[0]], tensors.get(node.output[0]), - perm, ) - elif node.op_type == "DepthToSpace": - blocksize = next( - (attr.i for attr in node.attribute if attr.name == "blocksize"), - None, - ) - mode = next( - (attr.s for attr in node.attribute if attr.name == "mode"), - None, - ) - tensors[node.output[0]] = self.handler.depthToSpace( - tensors[node.input[0]], - tensors.get(node.output[0]), - blocksize, - mode, - ) - elif node.op_type == "Reshape": - shape = _parse_data(data[node.input[1]]) - tensors[node.output[0]] = self.handler.reshape( - tensors[node.input[0]], - tensors.get(node.output[0]), - shape, - ) - elif node.op_type == "Squeeze": - input_shape = _search_shape(model, node.input[0]) - axes = set( - [int(i) for i in data[node.input[1]].int64_data] - if len(node.input) > 1 - else _parse_attribute(node, {"axes": None})["axes"] - ) - assert all(input_shape[d] == 1 for d in axes) - output_shape = [] - for i, x in enumerate(input_shape): - if i not in axes: - output_shape.append(x) - tensors[node.output[0]] = self.handler.reshape( - tensors[node.input[0]], - tensors.get(node.output[0]), - output_shape, - ) - elif node.op_type == "Unsqueeze": - input_shape = _search_shape(model, node.input[0]) - axes = ( - [int(i) for i in data[node.input[1]].int64_data] - if len(node.input) > 1 - else _parse_attribute(node, {"axes": None})["axes"] - ) - for i in axes: - input_shape.insert(i, 1) - tensors[node.output[0]] = self.handler.reshape( - tensors[node.input[0]], - tensors.get(node.output[0]), - input_shape, - ) - elif node.op_type == "Concat": - tensors[node.output[0]] = self.handler.concat( - [tensors[name] for name in node.input], - tensors.get(node.output[0]), - next( - (attr.i for attr in node.attribute if attr.name == "axis") - ), - ) - elif node.op_type == "AttentionKVCache": - tensors[node.output[0]] = self.handler.attentionKVCache( - tensors[node.input[0]], - tensors[node.input[1]], - tensors[node.input[2]], - tensors[node.input[3]], - tensors[node.input[4]], - tensors[node.input[5]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Split": - for name, tensor in zip( - node.output, - self.handler.split( - tensors[node.input[0]], - None, - next( - ( - attr.i - for attr in node.attribute - if attr.name == "axis" - ), - 0, - ), - len(node.output), - ), - ): - tensors[name] = tensor - elif node.op_type == "Gather": - tensors[node.output[0]] = self.handler.gather( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - next( - (attr.i for attr in node.attribute if attr.name == "axis"), - 0, - ), - ) - elif node.op_type == "GatherElements": - tensors[node.output[0]] = self.handler.gatherElements( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - next( - (attr.i for attr in node.attribute if attr.name == "axis"), - 0, - ), - ) - elif node.op_type == "ReduceMean": - tensors[node.output[0]] = self.handler.reduceMean( - tensors[node.input[0]], - tensors.get(node.output[0]), - # NOTE(constroy): `axes` is an attribute until opset version 13. - next( + else: + # NOTE: `axes` is an attribute until opset version 13. 
+ if len(node.input) > 1: + axis = _parse_data(data[node.input[1]]) + else: + axis = next( ( attr.ints for attr in node.attribute if attr.name == "axes" ), None, - ), + ) + keepdims = ( next( ( attr.i @@ -644,258 +723,161 @@ class OnnxStub: ), 1, ) - != 0, - ) - elif node.op_type == "Slice": - - def clamp(nums): - MAX_INT = 0x7FFFFFFF - return [min(x, MAX_INT) for x in nums] - - tensors[node.output[0]] = self.handler.slice( - tensors[node.input[0]], - tensors.get(node.output[0]), - clamp(_parse_data(data[node.input[1]])), - clamp(_parse_data(data[node.input[2]])), - clamp(_parse_data(data[node.input[3]])) - if len(node.input) > 3 - else None, - clamp(_parse_data(data[node.input[4]])) - if len(node.input) > 4 - else None, - ) - elif node.op_type == "Pad": - tensors[node.output[0]] = self.handler.pad( - tensors[node.input[0]], - tensors.get(node.output[0]), - _parse_data(data[node.input[1]]), - _parse_data(data[node.input[3]]) - if len(node.input) > 3 - else None, - ) - elif node.op_type == "Dropout": - for name, tensor in zip( - node.output, - self.handler.dropout( - tensors[node.input[0]], - tensors.get(node.output[0]), - tensors.get(node.output[1]) - if len(node.output) > 1 - else None, - _parse_data(data[node.input[1]])[0] - if len(node.input) > 1 - else 0.5, - _parse_data(data[node.input[2]])[0] - if len(node.input) > 2 - else False, - ), - ): - tensors[name] = tensor - elif node.op_type == "Cast": - tensors[node.output[0]] = self.handler.cast( - tensors[node.input[0]], - tensors.get(node.output[0]), - next((attr.i for attr in node.attribute if attr.name == "to")), - ) - elif node.op_type == "ReduceSum": - if any(attr.name == "communicator" for attr in node.attribute): - # ReduceSum with communicator is treated as allReduceSum. - tensors[node.output[0]] = self.handler.allReduceSum( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - else: - # NOTE: `axes` is an attribute until opset version 13. 
- if len(node.input) > 1: - axis = _parse_data(data[node.input[1]]) - else: - axis = next( - ( - attr.ints - for attr in node.attribute - if attr.name == "axes" - ), - None, - ) - keepdims = ( - next( - ( - attr.i - for attr in node.attribute - if attr.name == "keepdims" - ), - 1, - ) - != 0 - ) - - tensors[node.output[0]] = self.handler.reduceSum( - tensors[node.input[0]], - tensors.get(node.output[0]), - axis, - keepdims, - ) - elif node.op_type == "AllReduceSum": - tensors[node.output[0]] = self.handler.allReduceSum( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "AllReduceProd": - tensors[node.output[0]] = self.handler.allReduceProd( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "AllReduceMin": - tensors[node.output[0]] = self.handler.allReduceMin( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "AllReduceMax": - tensors[node.output[0]] = self.handler.allReduceMax( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "AllReduceAvg": - tensors[node.output[0]] = self.handler.allReduceAvg( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "AllGather": - for name, tensor in zip( - node.output, - self.handler.allGather( - tensors[node.input[0]], - None, - len(node.output), - ), - ): - tensors[name] = tensor - elif node.op_type == "Broadcast": - tensors[node.output[0]] = self.handler.broadcast( - tensors[node.input[0]], - tensors.get(node.output[0]), - next( - (attr.i for attr in node.attribute if attr.name == "root"), - 0, - ), - ) - elif node.op_type == "Send": - source = next( - (attr.i for attr in node.attribute if attr.name == "source"), - 0, - ) - destination = next( - ( - attr.i - for attr in node.attribute - if attr.name == "destination" - ), - 0, + != 0 ) - self.handler.send( - tensors[node.input[0]], - source, - destination, - None, - ) - elif node.op_type == "Recv": - source = next( - (attr.i for attr in node.attribute if attr.name == "source"), - 0, - ) - destination = next( - ( - attr.i - for attr in node.attribute - if attr.name == "destination" - ), - 0, - ) - - for attr in node.attribute: - if attr.name == "shape": - shapeBasic = attr.ints - shape = [] - for item in shapeBasic: - shape.append(item) - - for attr in node.attribute: - if attr.name == "dataType": - outputType = attr.i - tensors[node.output[0]] = self.handler.recv( - tensors.get(node.output[0]), - source, - destination, - shape, - outputType, - None, - ) - elif node.op_type == "Expand": - shape = _parse_data(data[node.input[1]]) - tensors[node.output[0]] = self.handler.expand( + tensors[node.output[0]] = self.handler.reduceSum( tensors[node.input[0]], tensors.get(node.output[0]), - shape, - ) - elif node.op_type == "Erf": - tensors[node.output[0]] = self.handler.erf( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Where": - tensors[node.output[0]] = self.handler.where( - tensors[node.input[1]], - tensors[node.input[2]], - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Constant": - output_name = node.output[0] - attributes = _parse_attribute(node) - tensor = attributes["value"] - dims = [d for d in tensor.dims] - tensors[output_name] = self.handler.tensor(dims, tensor.data_type) - data[output_name] = tensor - tensors[output_name].set_weight() - elif node.op_type == "DynamicQuantizeLinear": - for name, tensor in zip( - node.output, - self.handler.dynamicQuantizeLinear( - 
tensors[node.input[0]], None - ), - ): - tensors[name] = tensor - elif node.op_type == "DequantizeLinear": - (inputX, inputScale) = (tensors[node.input[i]] for i in [0, 1]) - inputZeroPoint = ( - None if len(node.input) < 3 else tensors[node.input[2]] - ) - output = tensors.get(node.output[0]) - axis = next( - (attr.i for attr in node.attribute if attr.name == "axis"), - 0, - ) - tensors[node.output[0]] = self.handler.dequantizeLinear( - inputX, - inputScale, - output, - inputZeroPoint, axis, + keepdims, ) - elif node.op_type == "MatMulInteger": - tensors[node.output[0]] = self.handler.matmulInteger( + elif node.op_type == "AllReduceSum": + tensors[node.output[0]] = self.handler.allReduceSum( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "AllReduceProd": + tensors[node.output[0]] = self.handler.allReduceProd( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "AllReduceMin": + tensors[node.output[0]] = self.handler.allReduceMin( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "AllReduceMax": + tensors[node.output[0]] = self.handler.allReduceMax( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "AllReduceAvg": + tensors[node.output[0]] = self.handler.allReduceAvg( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "AllGather": + for name, tensor in zip( + node.output, + self.handler.allGather( tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - tensors[node.input[2]] if len(node.input) > 2 else None, - tensors[node.input[3]] if len(node.input) > 3 else None, - ) - else: - raise Exception('Unsupported operator "{}"'.format(node.op_type)) - new_node_name.append(node.name) - # update the node_list - node_list = list(set(node_name) - set(new_node_name)) + None, + len(node.output), + ), + ): + tensors[name] = tensor + elif node.op_type == "Broadcast": + tensors[node.output[0]] = self.handler.broadcast( + tensors[node.input[0]], + tensors.get(node.output[0]), + next( + (attr.i for attr in node.attribute if attr.name == "root"), + 0, + ), + ) + elif node.op_type == "Send": + source = next( + (attr.i for attr in node.attribute if attr.name == "source"), + 0, + ) + destination = next( + (attr.i for attr in node.attribute if attr.name == "destination"), + 0, + ) + + self.handler.send( + tensors[node.input[0]], + source, + destination, + None, + ) + elif node.op_type == "Recv": + source = next( + (attr.i for attr in node.attribute if attr.name == "source"), + 0, + ) + destination = next( + (attr.i for attr in node.attribute if attr.name == "destination"), + 0, + ) + + for attr in node.attribute: + if attr.name == "shape": + shapeBasic = attr.ints + shape = [] + for item in shapeBasic: + shape.append(item) + + for attr in node.attribute: + if attr.name == "dataType": + outputType = attr.i + tensors[node.output[0]] = self.handler.recv( + tensors.get(node.output[0]), + source, + destination, + shape, + outputType, + None, + ) + elif node.op_type == "Expand": + shape = _parse_data(data[node.input[1]]) + tensors[node.output[0]] = self.handler.expand( + tensors[node.input[0]], + tensors.get(node.output[0]), + shape, + ) + elif node.op_type == "Erf": + tensors[node.output[0]] = self.handler.erf( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Where": + tensors[node.output[0]] = self.handler.where( + tensors[node.input[1]], + tensors[node.input[2]], + 
tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Constant": + output_name = node.output[0] + attributes = _parse_attribute(node) + tensor = attributes["value"] + dims = [d for d in tensor.dims] + tensors[output_name] = self.handler.tensor(dims, tensor.data_type) + data[output_name] = tensor + tensors[output_name].set_weight() + elif node.op_type == "DynamicQuantizeLinear": + for name, tensor in zip( + node.output, + self.handler.dynamicQuantizeLinear(tensors[node.input[0]], None), + ): + tensors[name] = tensor + elif node.op_type == "DequantizeLinear": + (inputX, inputScale) = (tensors[node.input[i]] for i in [0, 1]) + inputZeroPoint = None if len(node.input) < 3 else tensors[node.input[2]] + output = tensors.get(node.output[0]) + axis = next( + (attr.i for attr in node.attribute if attr.name == "axis"), + 0, + ) + tensors[node.output[0]] = self.handler.dequantizeLinear( + inputX, + inputScale, + output, + inputZeroPoint, + axis, + ) + elif node.op_type == "MatMulInteger": + tensors[node.output[0]] = self.handler.matmulInteger( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + tensors[node.input[2]] if len(node.input) > 2 else None, + tensors[node.input[3]] if len(node.input) > 3 else None, + ) + else: + raise Exception('Unsupported operator "{}"'.format(node.op_type)) ################################ # Allocate memory space for data @@ -1361,10 +1343,3 @@ def _parse_data_fp16(tensor: TensorProto): def _take_shape_dim(shape: TensorShapeProto) -> List[int]: return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim] - - -def _analyse_node(node: NodeProto, tensors) -> bool: - for i in node.input: - if i not in tensors: - return True - return False
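
For reviewers who want to try the new preprocessing in isolation, the following is a minimal, self-contained sketch (not part of the patch) of the two passes introduced at the top of the diff: placeholder naming plus suffix-based deduplication, then the sweep-based topological sort. The toy graph, its node names, and the printed result are illustrative assumptions; only the dedup/sort logic is taken from the hunk at line 59, including the `i not in sorted_nodes` guard that keeps an already-sorted node from being appended (and renamed) a second time on a later sweep. Assumes the `onnx` package is installed.

    from onnx import TensorProto, helper

    # Toy graph: two anonymous Add nodes, deliberately listed out of
    # topological order (the consumer of "t" appears before its producer).
    consumer = helper.make_node("Add", ["t", "t"], ["out"], name="")
    producer = helper.make_node("Add", ["x", "x"], ["t"], name="")
    graph = helper.make_graph(
        [consumer, producer],
        "toy",
        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2])],
        [helper.make_tensor_value_info("out", TensorProto.FLOAT, [2])],
    )
    model = helper.make_model(graph)

    # Pass 1: name anonymous nodes after their op type and disambiguate
    # duplicates with a numeric suffix.
    names = {}
    for node in model.graph.node:
        if node.name == "":
            node.name = "missing_name(" + node.op_type + ")"
        if node.name in names:
            names[node.name] += 1
            node.name += "_" + str(names[node.name])
        else:
            names[node.name] = 0

    # Pass 2: repeatedly sweep the node list, emitting any unsorted node
    # whose inputs are all known; a sweep with no progress means a cycle.
    sorted_nodes = []
    known_edge = set(t.name for t in model.graph.input)
    known_edge.update(t.name for t in model.graph.initializer)
    while len(sorted_nodes) < len(model.graph.node):
        updated = False
        for i, node in enumerate(model.graph.node):
            if i not in sorted_nodes and all(t in known_edge for t in node.input):
                node.name = str(len(sorted_nodes)) + "_" + node.name
                sorted_nodes.append(i)
                known_edge.update(node.output)
                updated = True
        if not updated:
            raise Exception("Graph has cycle")

    print([model.graph.node[i].name for i in sorted_nodes])
    # ['0_missing_name(Add)_1', '1_missing_name(Add)']

The sweep is worst-case quadratic, which is acceptable here: the ONNX IR spec requires graph nodes to already be topologically sorted, so conforming models finish in a single pass, and the extra sweeps and the cycle check only come into play on non-conforming input like the toy graph above.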