fix: 修正 batchNorm 实现

- onnx 和 pytorch 认为 batchNorm 的 4 个参数是 [c] 形状的,cuDNN 可能认为是 [1,c,1,...]。
优化已改为 [c],但 cuDNN 推理没有改;

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-23 11:48:28 +08:00
parent dd5d091dbc
commit a27391fcdc
3 changed files with 11 additions and 14 deletions

View File

@ -66,10 +66,10 @@ class TestStringMethods(unittest.TestCase):
def test_batch_norm(self):
x = make_tensor_value_info("x", TensorProto.UINT32, [1, 3, 2, 2])
scale = make_tensor_value_info("scale", TensorProto.FLOAT, [1, 3, 1, 1])
b = make_tensor_value_info("b", TensorProto.FLOAT, [1, 3, 1, 1])
mean = make_tensor_value_info("mean", TensorProto.FLOAT, [1, 3, 1, 1])
var = make_tensor_value_info("var", TensorProto.FLOAT, [1, 3, 1, 1])
scale = make_tensor_value_info("scale", TensorProto.FLOAT, [3])
b = make_tensor_value_info("b", TensorProto.FLOAT, [3])
mean = make_tensor_value_info("mean", TensorProto.FLOAT, [3])
var = make_tensor_value_info("var", TensorProto.FLOAT, [3])
y = make_tensor_value_info("y", TensorProto.UINT32, [1, 3, 2, 2])
batch_norm = make_node(
"BatchNormalization",

View File

@ -19,12 +19,9 @@ BatchNormObj::inferShape(const TensorVec &inputs) const {
auto var = inputs[2];
auto scale = inputs[3];
auto bias = inputs[4];
if (input->getDims().size() < 2)
return {};
Shape dims(input->getDims().size(), 1);
dims[1] = input->getDims()[1]; //
if (mean->getDims() != dims || var->getDims() != dims ||
scale->getDims() != dims || bias->getDims() != dims)
auto c = std::vector<int>{input->getDims()[1]};
if (mean->getDims() != c || var->getDims() != c || scale->getDims() != c ||
bias->getDims() != c)
return {};
return {{input->getDims()}};
}

View File

@ -9,10 +9,10 @@ TEST(BatchNorm, ShapeInference) {
{
Graph g = make_ref<GraphObj>(cpuRuntime);
Tensor i = g->addTensor({1, 3, 2, 2}, DataType::UInt32);
Tensor mean = g->addTensor({1, 3, 1, 1}, DataType::Float32);
Tensor var = g->addTensor({1, 3, 1, 1}, DataType::Float32);
Tensor scaler = g->addTensor({1, 3, 1, 1}, DataType::Float32);
Tensor bias = g->addTensor({1, 3, 1, 1}, DataType::Float32);
Tensor mean = g->addTensor({3}, DataType::Float32);
Tensor var = g->addTensor({3}, DataType::Float32);
Tensor scaler = g->addTensor({3}, DataType::Float32);
Tensor bias = g->addTensor({3}, DataType::Float32);
auto op = g->addOp<BatchNormObj>(i, nullptr, mean, var, scaler, bias,
0.9, 1e-5);
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 2}));