Compare commits

...

2 Commits

Author SHA1 Message Date
crapromer 20f651b1d3 implement instance norm in front 2024-05-08 17:44:54 +08:00
zhangyunze 3001274969 fix: onnx resize op input is none bug 2024-05-07 17:07:22 +08:00
1 changed files with 59 additions and 7 deletions

View File

@ -85,7 +85,7 @@ class OnnxStub:
while len(sorted_nodes) < len(model.graph.node):
updated = False
for i, node in enumerate(model.graph.node):
if all(t in known_edge for t in node.input):
if all(t in known_edge or t == "" for t in node.input):
node.name = str(len(sorted_nodes)) + "_" + node.name
sorted_nodes.append(i)
known_edge.update(node.output)
@ -280,6 +280,46 @@ class OnnxStub:
axis,
stash_type,
)
elif node.op_type == "InstanceNormalization":
shape = tensors[node.input[0]].shape()
if len(shape) == 4:
(input, scale, bias) = (tensors[node.input[i]] for i in [0, 1, 2])
output = tensors.get(node.output[0])
attributes = _parse_attribute(
node, {"momentum": 1, "training_mode": 0}
)
(momentum,epsilon,training) = (attributes[name] for name in
["momentum", "epsilon", "training_mode"])
reshape = "{}-reshape".format(node.output[0])
unreshape = "{}-unreshape".format(node.output[0])
shape = tensors[node.input[0]].shape()
mean = self.handler.tensor(
[shape[1]], 1
)
var = self.handler.tensor(
[shape[1]], 1
)
tensors[reshape] = self.handler.reshape(
tensors[node.input[0]],
tensors.get(reshape),
(1,shape[0]*shape[1],shape[2],shape[3])
)
tensors[unreshape] = self.handler.batchNormalization(
tensors[reshape],
tensors.get(unreshape),
mean,
var,
scale,
bias,
momentum,
epsilon,
training != 0
)
tensors[node.output[0]] = self.handler.reshape(
tensors[unreshape],
tensors.get(node.output[0]),
shape
)
elif node.op_type == "RMSNorm":
tensors[node.output[0]] = self.handler.RMSNorm(
tensors[node.input[0]],
@ -601,15 +641,15 @@ class OnnxStub:
"nearest_mode",
]
)
if len(node.input) > 1:
if len(node.input) > 1 and node.input[1] in data:
roiVal = _parse_data(data[node.input[1]])
else:
roiVal = []
if len(node.input) > 2:
if len(node.input) > 2 and node.input[2] in data:
scalesVal = _parse_data(data[node.input[2]])
else:
scalesVal = []
if len(node.input) > 3:
if len(node.input) > 3 and node.input[3] in data:
sizesVal = _parse_data(data[node.input[3]])
else:
sizesVal = []
@ -617,9 +657,21 @@ class OnnxStub:
tensors[node.input[0]],
output,
axes,
tensors[node.input[3]] if len(node.input) > 3 else None,
tensors[node.input[2]] if len(node.input) > 2 else None,
tensors[node.input[1]] if len(node.input) > 1 else None,
(
tensors[node.input[3]]
if len(node.input) > 3 and node.input[3] != ""
else None
),
(
tensors[node.input[2]]
if len(node.input) > 2 and node.input[2] != ""
else None
),
(
tensors[node.input[1]]
if len(node.input) > 1 and node.input[1] != ""
else None
),
sizesVal,
scalesVal,
roiVal,