forked from jiuyuan/InfiniTensor
Compare commits
No commits in common. "instance_norm" and "master" have entirely different histories.
instance_n
...
master
|
@ -85,7 +85,7 @@ class OnnxStub:
|
||||||
while len(sorted_nodes) < len(model.graph.node):
|
while len(sorted_nodes) < len(model.graph.node):
|
||||||
updated = False
|
updated = False
|
||||||
for i, node in enumerate(model.graph.node):
|
for i, node in enumerate(model.graph.node):
|
||||||
if all(t in known_edge or t == "" for t in node.input):
|
if all(t in known_edge for t in node.input):
|
||||||
node.name = str(len(sorted_nodes)) + "_" + node.name
|
node.name = str(len(sorted_nodes)) + "_" + node.name
|
||||||
sorted_nodes.append(i)
|
sorted_nodes.append(i)
|
||||||
known_edge.update(node.output)
|
known_edge.update(node.output)
|
||||||
|
@ -280,46 +280,6 @@ class OnnxStub:
|
||||||
axis,
|
axis,
|
||||||
stash_type,
|
stash_type,
|
||||||
)
|
)
|
||||||
elif node.op_type == "InstanceNormalization":
|
|
||||||
shape = tensors[node.input[0]].shape()
|
|
||||||
if len(shape) == 4:
|
|
||||||
(input, scale, bias) = (tensors[node.input[i]] for i in [0, 1, 2])
|
|
||||||
output = tensors.get(node.output[0])
|
|
||||||
attributes = _parse_attribute(
|
|
||||||
node, {"momentum": 1, "training_mode": 0}
|
|
||||||
)
|
|
||||||
(momentum,epsilon,training) = (attributes[name] for name in
|
|
||||||
["momentum", "epsilon", "training_mode"])
|
|
||||||
reshape = "{}-reshape".format(node.output[0])
|
|
||||||
unreshape = "{}-unreshape".format(node.output[0])
|
|
||||||
shape = tensors[node.input[0]].shape()
|
|
||||||
mean = self.handler.tensor(
|
|
||||||
[shape[1]], 1
|
|
||||||
)
|
|
||||||
var = self.handler.tensor(
|
|
||||||
[shape[1]], 1
|
|
||||||
)
|
|
||||||
tensors[reshape] = self.handler.reshape(
|
|
||||||
tensors[node.input[0]],
|
|
||||||
tensors.get(reshape),
|
|
||||||
(1,shape[0]*shape[1],shape[2],shape[3])
|
|
||||||
)
|
|
||||||
tensors[unreshape] = self.handler.batchNormalization(
|
|
||||||
tensors[reshape],
|
|
||||||
tensors.get(unreshape),
|
|
||||||
mean,
|
|
||||||
var,
|
|
||||||
scale,
|
|
||||||
bias,
|
|
||||||
momentum,
|
|
||||||
epsilon,
|
|
||||||
training != 0
|
|
||||||
)
|
|
||||||
tensors[node.output[0]] = self.handler.reshape(
|
|
||||||
tensors[unreshape],
|
|
||||||
tensors.get(node.output[0]),
|
|
||||||
shape
|
|
||||||
)
|
|
||||||
elif node.op_type == "RMSNorm":
|
elif node.op_type == "RMSNorm":
|
||||||
tensors[node.output[0]] = self.handler.RMSNorm(
|
tensors[node.output[0]] = self.handler.RMSNorm(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
|
@ -641,15 +601,15 @@ class OnnxStub:
|
||||||
"nearest_mode",
|
"nearest_mode",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
if len(node.input) > 1 and node.input[1] in data:
|
if len(node.input) > 1:
|
||||||
roiVal = _parse_data(data[node.input[1]])
|
roiVal = _parse_data(data[node.input[1]])
|
||||||
else:
|
else:
|
||||||
roiVal = []
|
roiVal = []
|
||||||
if len(node.input) > 2 and node.input[2] in data:
|
if len(node.input) > 2:
|
||||||
scalesVal = _parse_data(data[node.input[2]])
|
scalesVal = _parse_data(data[node.input[2]])
|
||||||
else:
|
else:
|
||||||
scalesVal = []
|
scalesVal = []
|
||||||
if len(node.input) > 3 and node.input[3] in data:
|
if len(node.input) > 3:
|
||||||
sizesVal = _parse_data(data[node.input[3]])
|
sizesVal = _parse_data(data[node.input[3]])
|
||||||
else:
|
else:
|
||||||
sizesVal = []
|
sizesVal = []
|
||||||
|
@ -657,21 +617,9 @@ class OnnxStub:
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
output,
|
output,
|
||||||
axes,
|
axes,
|
||||||
(
|
tensors[node.input[3]] if len(node.input) > 3 else None,
|
||||||
tensors[node.input[3]]
|
tensors[node.input[2]] if len(node.input) > 2 else None,
|
||||||
if len(node.input) > 3 and node.input[3] != ""
|
tensors[node.input[1]] if len(node.input) > 1 else None,
|
||||||
else None
|
|
||||||
),
|
|
||||||
(
|
|
||||||
tensors[node.input[2]]
|
|
||||||
if len(node.input) > 2 and node.input[2] != ""
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
(
|
|
||||||
tensors[node.input[1]]
|
|
||||||
if len(node.input) > 1 and node.input[1] != ""
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
sizesVal,
|
sizesVal,
|
||||||
scalesVal,
|
scalesVal,
|
||||||
roiVal,
|
roiVal,
|
||||||
|
|
Loading…
Reference in New Issue