diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py
index 1f839256..0db20fc9 100644
--- a/pyinfinitensor/tests/test_onnx.py
+++ b/pyinfinitensor/tests/test_onnx.py
@@ -66,10 +66,10 @@ class TestStringMethods(unittest.TestCase):
 
     def test_batch_norm(self):
         x = make_tensor_value_info("x", TensorProto.UINT32, [1, 3, 2, 2])
-        scale = make_tensor_value_info("scale", TensorProto.FLOAT, [1, 3, 1, 1])
-        b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 1, 1])
-        mean = make_tensor_value_info("mean", TensorProto.FLOAT, [1, 3, 1, 1])
-        var = make_tensor_value_info("var", TensorProto.FLOAT, [1, 3, 1, 1])
+        scale = make_tensor_value_info("scale", TensorProto.FLOAT, [3])
+        b = make_tensor_value_info("b", TensorProto.FLOAT, [3])
+        mean = make_tensor_value_info("mean", TensorProto.FLOAT, [3])
+        var = make_tensor_value_info("var", TensorProto.FLOAT, [3])
         y = make_tensor_value_info("y", TensorProto.UINT32, [1, 3, 2, 2])
         batch_norm = make_node(
             "BatchNormalization",
diff --git a/src/operators/batch_norm.cc b/src/operators/batch_norm.cc
index b9be9425..b744df5a 100644
--- a/src/operators/batch_norm.cc
+++ b/src/operators/batch_norm.cc
@@ -19,12 +19,9 @@ BatchNormObj::inferShape(const TensorVec &inputs) const {
     auto var = inputs[2];
     auto scale = inputs[3];
     auto bias = inputs[4];
-    if (input->getDims().size() < 2)
-        return {};
-    Shape dims(input->getDims().size(), 1);
-    dims[1] = input->getDims()[1]; //
-    if (mean->getDims() != dims || var->getDims() != dims ||
-        scale->getDims() != dims || bias->getDims() != dims)
+    auto c = std::vector{input->getDims()[1]};
+    if (mean->getDims() != c || var->getDims() != c || scale->getDims() != c ||
+        bias->getDims() != c)
         return {};
     return {{input->getDims()}};
 }
diff --git a/test/operators/test_batch_norm.cc b/test/operators/test_batch_norm.cc
index d904edd1..438db6e6 100644
--- a/test/operators/test_batch_norm.cc
+++ b/test/operators/test_batch_norm.cc
@@ -9,10 +9,10 @@ TEST(BatchNorm, ShapeInference) {
     {
         Graph g = make_ref<GraphObj>(cpuRuntime);
         Tensor i = g->addTensor({1, 3, 2, 2}, DataType::UInt32);
-        Tensor mean = g->addTensor({1, 3, 1, 1}, DataType::Float32);
-        Tensor var = g->addTensor({1, 3, 1, 1}, DataType::Float32);
-        Tensor scaler = g->addTensor({1, 3, 1, 1}, DataType::Float32);
-        Tensor bias = g->addTensor({1, 3, 1, 1}, DataType::Float32);
+        Tensor mean = g->addTensor({3}, DataType::Float32);
+        Tensor var = g->addTensor({3}, DataType::Float32);
+        Tensor scaler = g->addTensor({3}, DataType::Float32);
+        Tensor bias = g->addTensor({3}, DataType::Float32);
         auto op = g->addOp<BatchNormObj>(i, nullptr, mean, var, scaler, bias,
                                          0.9, 1e-5);
         EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 2}));
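
Note (reviewer context, not part of the diff): the ONNX BatchNormalization operator declares scale, B, input_mean and input_var as 1-D tensors of shape [C], which is what this change aligns BatchNormObj::inferShape and both tests with. The standalone sketch below rebuilds the same graph the updated test_onnx.py constructs and runs it through onnx's own checker and shape inference; it assumes a recent onnx package and uses FLOAT for x/y instead of the test's UINT32 so the operator's type constraints are satisfied.

from onnx import TensorProto, checker, shape_inference
from onnx.helper import (
    make_graph,
    make_model,
    make_node,
    make_tensor_value_info,
)

# Per-channel parameters are 1-D tensors of shape [C] (here C = 3),
# matching what BatchNormObj::inferShape now expects.
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 2, 2])
scale = make_tensor_value_info("scale", TensorProto.FLOAT, [3])
b = make_tensor_value_info("b", TensorProto.FLOAT, [3])
mean = make_tensor_value_info("mean", TensorProto.FLOAT, [3])
var = make_tensor_value_info("var", TensorProto.FLOAT, [3])
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 2, 2])

batch_norm = make_node(
    "BatchNormalization",
    ["x", "scale", "b", "mean", "var"],
    ["y"],
    name="batchNormalization",
)
graph = make_graph([batch_norm], "batchNormalization", [x, scale, b, mean, var], [y])
model = make_model(graph)

checker.check_model(model)  # structural validity of the 5-input node
inferred = shape_inference.infer_shapes(model)
# ONNX's own shape inference keeps y at [1, 3, 2, 2] with the [3]-shaped parameters.
print(inferred.graph.output[0].type.tensor_type.shape)

This mirrors the check in the C++ hunk: the mean, var, scale and bias inputs must all have shape {C}, where C is input->getDims()[1], rather than the previously accepted {1, C, 1, 1} layout.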