Compare commits

..

2 Commits

Author SHA1 Message Date
crapromer 20f651b1d3 implement instance norm in front 2024-05-08 17:44:54 +08:00
zhangyunze 3001274969 fix: onnx resize op input is none bug 2024-05-07 17:07:22 +08:00
1 changed files with 59 additions and 7 deletions

View File

@ -85,7 +85,7 @@ class OnnxStub:
while len(sorted_nodes) < len(model.graph.node): while len(sorted_nodes) < len(model.graph.node):
updated = False updated = False
for i, node in enumerate(model.graph.node): for i, node in enumerate(model.graph.node):
if all(t in known_edge for t in node.input): if all(t in known_edge or t == "" for t in node.input):
node.name = str(len(sorted_nodes)) + "_" + node.name node.name = str(len(sorted_nodes)) + "_" + node.name
sorted_nodes.append(i) sorted_nodes.append(i)
known_edge.update(node.output) known_edge.update(node.output)
@ -280,6 +280,46 @@ class OnnxStub:
axis, axis,
stash_type, stash_type,
) )
elif node.op_type == "InstanceNormalization":
shape = tensors[node.input[0]].shape()
if len(shape) == 4:
(input, scale, bias) = (tensors[node.input[i]] for i in [0, 1, 2])
output = tensors.get(node.output[0])
attributes = _parse_attribute(
node, {"momentum": 1, "training_mode": 0}
)
(momentum,epsilon,training) = (attributes[name] for name in
["momentum", "epsilon", "training_mode"])
reshape = "{}-reshape".format(node.output[0])
unreshape = "{}-unreshape".format(node.output[0])
shape = tensors[node.input[0]].shape()
mean = self.handler.tensor(
[shape[1]], 1
)
var = self.handler.tensor(
[shape[1]], 1
)
tensors[reshape] = self.handler.reshape(
tensors[node.input[0]],
tensors.get(reshape),
(1,shape[0]*shape[1],shape[2],shape[3])
)
tensors[unreshape] = self.handler.batchNormalization(
tensors[reshape],
tensors.get(unreshape),
mean,
var,
scale,
bias,
momentum,
epsilon,
training != 0
)
tensors[node.output[0]] = self.handler.reshape(
tensors[unreshape],
tensors.get(node.output[0]),
shape
)
elif node.op_type == "RMSNorm": elif node.op_type == "RMSNorm":
tensors[node.output[0]] = self.handler.RMSNorm( tensors[node.output[0]] = self.handler.RMSNorm(
tensors[node.input[0]], tensors[node.input[0]],
@ -601,15 +641,15 @@ class OnnxStub:
"nearest_mode", "nearest_mode",
] ]
) )
if len(node.input) > 1: if len(node.input) > 1 and node.input[1] in data:
roiVal = _parse_data(data[node.input[1]]) roiVal = _parse_data(data[node.input[1]])
else: else:
roiVal = [] roiVal = []
if len(node.input) > 2: if len(node.input) > 2 and node.input[2] in data:
scalesVal = _parse_data(data[node.input[2]]) scalesVal = _parse_data(data[node.input[2]])
else: else:
scalesVal = [] scalesVal = []
if len(node.input) > 3: if len(node.input) > 3 and node.input[3] in data:
sizesVal = _parse_data(data[node.input[3]]) sizesVal = _parse_data(data[node.input[3]])
else: else:
sizesVal = [] sizesVal = []
@ -617,9 +657,21 @@ class OnnxStub:
tensors[node.input[0]], tensors[node.input[0]],
output, output,
axes, axes,
tensors[node.input[3]] if len(node.input) > 3 else None, (
tensors[node.input[2]] if len(node.input) > 2 else None, tensors[node.input[3]]
tensors[node.input[1]] if len(node.input) > 1 else None, if len(node.input) > 3 and node.input[3] != ""
else None
),
(
tensors[node.input[2]]
if len(node.input) > 2 and node.input[2] != ""
else None
),
(
tensors[node.input[1]]
if len(node.input) > 1 and node.input[1] != ""
else None
),
sizesVal, sizesVal,
scalesVal, scalesVal,
roiVal, roiVal,