forked from jiuyuan/InfiniTensor
fix: correct the batchNorm implementation
- onnx and pytorch treat the 4 batchNorm parameters as having shape [c], while cuDNN may expect [1,c,1,...]. Shape inference has been switched to [c], but the cuDNN inference kernel has not been updated yet. Signed-off-by: YdrMaster <ydrml@hotmail.com>
parent dd5d091dbc
commit a27391fcdc
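
Note that the mismatch described in the message is only about how the parameters are described, not how they are stored: a [c] tensor and a [1,c,1,...] tensor both hold c contiguous values. Below is a rough sketch of what the still-unchanged cuDNN inference path would need; the function and parameter names are illustrative and this is not InfiniTensor's actual kernel code. Once the parameter descriptor is set to 1 x C x 1 x 1, the [c] device pointers can be passed to cuDNN as-is.

// Sketch only: cuDNN batch-norm inference with per-channel parameters
// stored as plain [c] buffers. Only the descriptor needs the 1 x C x 1 x 1
// view; the data pointers are unchanged. Status checks omitted for brevity.
#include <cudnn.h>

void batchNormInferSketch(cudnnHandle_t handle,
                          cudnnTensorDescriptor_t xDesc, const float *x,
                          cudnnTensorDescriptor_t yDesc, float *y,
                          const float *scale, const float *bias,
                          const float *mean, const float *var,
                          double epsilon) {
    cudnnTensorDescriptor_t bnDesc;
    cudnnCreateTensorDescriptor(&bnDesc);
    // Derive the 1 x C x 1 x 1 parameter descriptor that
    // CUDNN_BATCHNORM_SPATIAL expects from the input descriptor.
    cudnnDeriveBNTensorDescriptor(bnDesc, xDesc, CUDNN_BATCHNORM_SPATIAL);

    const float alpha = 1.0f, beta = 0.0f;
    cudnnBatchNormalizationForwardInference(
        handle, CUDNN_BATCHNORM_SPATIAL, &alpha, &beta, xDesc, x, yDesc, y,
        bnDesc, scale, bias, mean, var, epsilon);

    cudnnDestroyTensorDescriptor(bnDesc);
}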
@@ -66,10 +66,10 @@ class TestStringMethods(unittest.TestCase):
     def test_batch_norm(self):
         x = make_tensor_value_info("x", TensorProto.UINT32, [1, 3, 2, 2])
-        scale = make_tensor_value_info("scale", TensorProto.FLOAT, [1, 3, 1, 1])
-        b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 1, 1])
-        mean = make_tensor_value_info("mean", TensorProto.FLOAT, [1, 3, 1, 1])
-        var = make_tensor_value_info("var", TensorProto.FLOAT, [1, 3, 1, 1])
+        scale = make_tensor_value_info("scale", TensorProto.FLOAT, [3])
+        b = make_tensor_value_info("b", TensorProto.FLOAT, [3])
+        mean = make_tensor_value_info("mean", TensorProto.FLOAT, [3])
+        var = make_tensor_value_info("var", TensorProto.FLOAT, [3])
         y = make_tensor_value_info("y", TensorProto.UINT32, [1, 3, 2, 2])
         batch_norm = make_node(
             "BatchNormalization",
@@ -19,12 +19,9 @@ BatchNormObj::inferShape(const TensorVec &inputs) const {
     auto var = inputs[2];
     auto scale = inputs[3];
     auto bias = inputs[4];
-    if (input->getDims().size() < 2)
-        return {};
-    Shape dims(input->getDims().size(), 1);
-    dims[1] = input->getDims()[1]; //
-    if (mean->getDims() != dims || var->getDims() != dims ||
-        scale->getDims() != dims || bias->getDims() != dims)
+    auto c = std::vector<int>{input->getDims()[1]};
+    if (mean->getDims() != c || var->getDims() != c || scale->getDims() != c ||
+        bias->getDims() != c)
         return {};
     return {{input->getDims()}};
 }
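
For reference, a standalone sketch of the shape rule the patched inferShape enforces, using plain std::vector<int> instead of InfiniTensor's Shape/Tensor types (names are illustrative): each of the four parameters must be 1-D with length equal to the input's channel dimension.

// Standalone sketch of the per-channel shape check (not the actual API).
#include <cassert>
#include <vector>

using Dims = std::vector<int>;

bool batchNormParamsMatch(const Dims &input, const Dims &mean,
                          const Dims &var, const Dims &scale,
                          const Dims &bias) {
    if (input.size() < 2)
        return false;            // need at least an N and a C axis
    const Dims c{input[1]};      // expected parameter shape: [c]
    return mean == c && var == c && scale == c && bias == c;
}

int main() {
    // The new [c] layout passes ...
    assert(batchNormParamsMatch({1, 3, 2, 2}, {3}, {3}, {3}, {3}));
    // ... while the old [1,c,1,1] layout is rejected.
    assert(!batchNormParamsMatch({1, 3, 2, 2}, {1, 3, 1, 1}, {3}, {3}, {3}));
    return 0;
}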
@@ -9,10 +9,10 @@ TEST(BatchNorm, ShapeInference) {
     {
         Graph g = make_ref<GraphObj>(cpuRuntime);
         Tensor i = g->addTensor({1, 3, 2, 2}, DataType::UInt32);
-        Tensor mean = g->addTensor({1, 3, 1, 1}, DataType::Float32);
-        Tensor var = g->addTensor({1, 3, 1, 1}, DataType::Float32);
-        Tensor scaler = g->addTensor({1, 3, 1, 1}, DataType::Float32);
-        Tensor bias = g->addTensor({1, 3, 1, 1}, DataType::Float32);
+        Tensor mean = g->addTensor({3}, DataType::Float32);
+        Tensor var = g->addTensor({3}, DataType::Float32);
+        Tensor scaler = g->addTensor({3}, DataType::Float32);
+        Tensor bias = g->addTensor({3}, DataType::Float32);
         auto op = g->addOp<BatchNormObj>(i, nullptr, mean, var, scaler, bias,
                                          0.9, 1e-5);
         EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 2}));