From 77fd137dcbeffc8d370ccdf58376cc13417981e6 Mon Sep 17 00:00:00 2001 From: zhangyunze Date: Thu, 25 Apr 2024 16:40:15 +0800 Subject: [PATCH] fix: support batchnorm cudnn 2 dimension input --- src/kernels/cuda/batch_norm.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/kernels/cuda/batch_norm.cc b/src/kernels/cuda/batch_norm.cc index b083ad9c..2e409a32 100644 --- a/src/kernels/cuda/batch_norm.cc +++ b/src/kernels/cuda/batch_norm.cc @@ -18,9 +18,16 @@ class BatchNormCudnn : public CudaKernelWithoutConfig { void *const scaleData = (op->getInputs(3)->getRawDataPtr()); void *const biasData = (op->getInputs(4)->getRawDataPtr()); - auto dims = op->getInputs(0)->getDims(); // Only 4D and 5D tensors are supported by // cudnnBatchNormalizationForwardInference + if (auto dims = op->getInputs(0)->getDims(); dims.size() < 4) { + auto dims_t = dims; + for (size_t i = dims_t.size(); i < 4; ++i) { + dims_t.push_back(1); + } + op->getInputs(0)->setShape(dims_t); + } + auto dims = op->getInputs(0)->getDims(); IT_ASSERT(dims.size() == 4); int dimArray[4], strideArray[4], dimPArray[4], stridePArray[4];