From a27391fcdcefcff3b355f3d42faf0ab7dbc40cf5 Mon Sep 17 00:00:00 2001
From: YdrMaster
Date: Thu, 23 Feb 2023 11:48:28 +0800
Subject: [PATCH] fix: correct the batchNorm implementation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- onnx and pytorch treat the 4 batchNorm parameters as having shape [c],
  while cuDNN may treat them as [1,c,1,...].
  The optimization side has been changed to [c], but the cuDNN inference
  has not been changed yet.

Signed-off-by: YdrMaster
---
 pyinfinitensor/tests/test_onnx.py | 8 ++++----
 src/operators/batch_norm.cc       | 9 +++------
 test/operators/test_batch_norm.cc | 8 ++++----
 3 files changed, 11 insertions(+), 14 deletions(-)

diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py
index 1f839256..0db20fc9 100644
--- a/pyinfinitensor/tests/test_onnx.py
+++ b/pyinfinitensor/tests/test_onnx.py
@@ -66,10 +66,10 @@ class TestStringMethods(unittest.TestCase):
 
     def test_batch_norm(self):
         x = make_tensor_value_info("x", TensorProto.UINT32, [1, 3, 2, 2])
-        scale = make_tensor_value_info("scale", TensorProto.FLOAT, [1, 3, 1, 1])
-        b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 1, 1])
-        mean = make_tensor_value_info("mean", TensorProto.FLOAT, [1, 3, 1, 1])
-        var = make_tensor_value_info("var", TensorProto.FLOAT, [1, 3, 1, 1])
+        scale = make_tensor_value_info("scale", TensorProto.FLOAT, [3])
+        b = make_tensor_value_info("b", TensorProto.FLOAT, [3])
+        mean = make_tensor_value_info("mean", TensorProto.FLOAT, [3])
+        var = make_tensor_value_info("var", TensorProto.FLOAT, [3])
         y = make_tensor_value_info("y", TensorProto.UINT32, [1, 3, 2, 2])
         batch_norm = make_node(
             "BatchNormalization",
diff --git a/src/operators/batch_norm.cc b/src/operators/batch_norm.cc
index b9be9425..b744df5a 100644
--- a/src/operators/batch_norm.cc
+++ b/src/operators/batch_norm.cc
@@ -19,12 +19,9 @@ BatchNormObj::inferShape(const TensorVec &inputs) const {
     auto var = inputs[2];
     auto scale = inputs[3];
     auto bias = inputs[4];
-    if (input->getDims().size() < 2)
-        return {};
-    Shape dims(input->getDims().size(), 1);
-    dims[1] = input->getDims()[1]; //
-    if (mean->getDims() != dims || var->getDims() != dims ||
-        scale->getDims() != dims || bias->getDims() != dims)
+    auto c = std::vector{input->getDims()[1]};
+    if (mean->getDims() != c || var->getDims() != c || scale->getDims() != c ||
+        bias->getDims() != c)
         return {};
     return {{input->getDims()}};
 }
diff --git a/test/operators/test_batch_norm.cc b/test/operators/test_batch_norm.cc
index d904edd1..438db6e6 100644
--- a/test/operators/test_batch_norm.cc
+++ b/test/operators/test_batch_norm.cc
@@ -9,10 +9,10 @@ TEST(BatchNorm, ShapeInference) {
     {
         Graph g = make_ref<GraphObj>(cpuRuntime);
         Tensor i = g->addTensor({1, 3, 2, 2}, DataType::UInt32);
-        Tensor mean = g->addTensor({1, 3, 1, 1}, DataType::Float32);
-        Tensor var = g->addTensor({1, 3, 1, 1}, DataType::Float32);
-        Tensor scaler = g->addTensor({1, 3, 1, 1}, DataType::Float32);
-        Tensor bias = g->addTensor({1, 3, 1, 1}, DataType::Float32);
+        Tensor mean = g->addTensor({3}, DataType::Float32);
+        Tensor var = g->addTensor({3}, DataType::Float32);
+        Tensor scaler = g->addTensor({3}, DataType::Float32);
+        Tensor bias = g->addTensor({3}, DataType::Float32);
         auto op = g->addOp<BatchNormObj>(i, nullptr, mean, var, scaler, bias,
                                          0.9, 1e-5);
         EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 2}));
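
Note on the parameter shape convention: a minimal check, assuming PyTorch is installed, that the four batchNorm parameters (scale, bias, mean, var) are 1-D per-channel tensors of shape [c], which is the convention the patched inferShape now enforces.

import torch

# BatchNorm2d over 3 channels: all four parameters are 1-D tensors of length 3.
bn = torch.nn.BatchNorm2d(num_features=3)
print(bn.weight.shape)        # torch.Size([3])  -- scale
print(bn.bias.shape)          # torch.Size([3])  -- bias
print(bn.running_mean.shape)  # torch.Size([3])  -- mean
print(bn.running_var.shape)   # torch.Size([3])  -- var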
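
For the cuDNN side that the commit message says is still unchanged, a plain-numpy sketch (not the repository's kernel) of batchNorm over NCHW input: the [c]-shaped parameters have to be broadcast as [1, c, 1, 1], which is the [1,c,1,...] layout mentioned above.

import numpy as np

def batch_norm_nchw(x, scale, bias, mean, var, eps=1e-5):
    # Reshape the per-channel [c] parameters to [1, c, 1, 1] so they broadcast
    # over the N, H and W axes of the NCHW input.
    c = x.shape[1]
    shape = (1, c, 1, 1)
    scale, bias, mean, var = (t.reshape(shape) for t in (scale, bias, mean, var))
    return scale * (x - mean) / np.sqrt(var + eps) + bias

x = np.random.rand(1, 3, 2, 2).astype(np.float32)
y = batch_norm_nchw(x,
                    scale=np.ones(3, np.float32),
                    bias=np.zeros(3, np.float32),
                    mean=np.zeros(3, np.float32),
                    var=np.ones(3, np.float32))
print(y.shape)  # (1, 3, 2, 2)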