implement instance norm in front

This commit is contained in:
crapromer 2024-05-08 17:44:54 +08:00
parent 3001274969
commit 20f651b1d3
1 changed files with 40 additions and 0 deletions

View File

@ -280,6 +280,46 @@ class OnnxStub:
axis,
stash_type,
)
elif node.op_type == "InstanceNormalization":
shape = tensors[node.input[0]].shape()
if len(shape) == 4:
(input, scale, bias) = (tensors[node.input[i]] for i in [0, 1, 2])
output = tensors.get(node.output[0])
attributes = _parse_attribute(
node, {"momentum": 1, "training_mode": 0}
)
(momentum,epsilon,training) = (attributes[name] for name in
["momentum", "epsilon", "training_mode"])
reshape = "{}-reshape".format(node.output[0])
unreshape = "{}-unreshape".format(node.output[0])
shape = tensors[node.input[0]].shape()
mean = self.handler.tensor(
[shape[1]], 1
)
var = self.handler.tensor(
[shape[1]], 1
)
tensors[reshape] = self.handler.reshape(
tensors[node.input[0]],
tensors.get(reshape),
(1,shape[0]*shape[1],shape[2],shape[3])
)
tensors[unreshape] = self.handler.batchNormalization(
tensors[reshape],
tensors.get(unreshape),
mean,
var,
scale,
bias,
momentum,
epsilon,
training != 0
)
tensors[node.output[0]] = self.handler.reshape(
tensors[unreshape],
tensors.get(node.output[0]),
shape
)
elif node.op_type == "RMSNorm":
tensors[node.output[0]] = self.handler.RMSNorm(
tensors[node.input[0]],