From 20f651b1d301a90641a8f878e9cd485dac14a317 Mon Sep 17 00:00:00 2001 From: crapromer Date: Wed, 8 May 2024 17:44:54 +0800 Subject: [PATCH] implement instance norm in front --- pyinfinitensor/src/pyinfinitensor/onnx.py | 40 +++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index c5764af8..69d6ee32 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -280,6 +280,46 @@ class OnnxStub: axis, stash_type, ) + elif node.op_type == "InstanceNormalization": + shape = tensors[node.input[0]].shape() + if len(shape) == 4: + (input, scale, bias) = (tensors[node.input[i]] for i in [0, 1, 2]) + output = tensors.get(node.output[0]) + attributes = _parse_attribute( + node, {"momentum": 1, "training_mode": 0} + ) + (momentum,epsilon,training) = (attributes[name] for name in + ["momentum", "epsilon", "training_mode"]) + reshape = "{}-reshape".format(node.output[0]) + unreshape = "{}-unreshape".format(node.output[0]) + shape = tensors[node.input[0]].shape() + mean = self.handler.tensor( + [shape[1]], 1 + ) + var = self.handler.tensor( + [shape[1]], 1 + ) + tensors[reshape] = self.handler.reshape( + tensors[node.input[0]], + tensors.get(reshape), + (1,shape[0]*shape[1],shape[2],shape[3]) + ) + tensors[unreshape] = self.handler.batchNormalization( + tensors[reshape], + tensors.get(unreshape), + mean, + var, + scale, + bias, + momentum, + epsilon, + training != 0 + ) + tensors[node.output[0]] = self.handler.reshape( + tensors[unreshape], + tensors.get(node.output[0]), + shape + ) elif node.op_type == "RMSNorm": tensors[node.output[0]] = self.handler.RMSNorm( tensors[node.input[0]],