forked from jiuyuan/InfiniTensor
implement instance norm in front
This commit is contained in:
parent
3001274969
commit
20f651b1d3
|
@ -280,6 +280,46 @@ class OnnxStub:
|
||||||
axis,
|
axis,
|
||||||
stash_type,
|
stash_type,
|
||||||
)
|
)
|
||||||
|
elif node.op_type == "InstanceNormalization":
|
||||||
|
shape = tensors[node.input[0]].shape()
|
||||||
|
if len(shape) == 4:
|
||||||
|
(input, scale, bias) = (tensors[node.input[i]] for i in [0, 1, 2])
|
||||||
|
output = tensors.get(node.output[0])
|
||||||
|
attributes = _parse_attribute(
|
||||||
|
node, {"momentum": 1, "training_mode": 0}
|
||||||
|
)
|
||||||
|
(momentum,epsilon,training) = (attributes[name] for name in
|
||||||
|
["momentum", "epsilon", "training_mode"])
|
||||||
|
reshape = "{}-reshape".format(node.output[0])
|
||||||
|
unreshape = "{}-unreshape".format(node.output[0])
|
||||||
|
shape = tensors[node.input[0]].shape()
|
||||||
|
mean = self.handler.tensor(
|
||||||
|
[shape[1]], 1
|
||||||
|
)
|
||||||
|
var = self.handler.tensor(
|
||||||
|
[shape[1]], 1
|
||||||
|
)
|
||||||
|
tensors[reshape] = self.handler.reshape(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors.get(reshape),
|
||||||
|
(1,shape[0]*shape[1],shape[2],shape[3])
|
||||||
|
)
|
||||||
|
tensors[unreshape] = self.handler.batchNormalization(
|
||||||
|
tensors[reshape],
|
||||||
|
tensors.get(unreshape),
|
||||||
|
mean,
|
||||||
|
var,
|
||||||
|
scale,
|
||||||
|
bias,
|
||||||
|
momentum,
|
||||||
|
epsilon,
|
||||||
|
training != 0
|
||||||
|
)
|
||||||
|
tensors[node.output[0]] = self.handler.reshape(
|
||||||
|
tensors[unreshape],
|
||||||
|
tensors.get(node.output[0]),
|
||||||
|
shape
|
||||||
|
)
|
||||||
elif node.op_type == "RMSNorm":
|
elif node.op_type == "RMSNorm":
|
||||||
tensors[node.output[0]] = self.handler.RMSNorm(
|
tensors[node.output[0]] = self.handler.RMSNorm(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
|
|
Loading…
Reference in New Issue