From 514666591e145586b8b30b7829cf69e75f217ba4 Mon Sep 17 00:00:00 2001 From: wanghailu Date: Fri, 24 Feb 2023 13:55:53 +0800 Subject: [PATCH] add batch_norm --- src/kernels/cuda/batch_norm.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/kernels/cuda/batch_norm.cc b/src/kernels/cuda/batch_norm.cc index 7ca75ba6..b150aaa5 100644 --- a/src/kernels/cuda/batch_norm.cc +++ b/src/kernels/cuda/batch_norm.cc @@ -28,9 +28,11 @@ class BatchNormCudnn : public CudaKernelWithoutConfig { for (size_t i = 0; i < dims.size(); ++i) { dimArray[i] = dims[i]; strideArray[i] = op->getInputs(0)->getStride()[i]; - dimPArray[i] = op->getInputs(1)->getDims()[i]; - stridePArray[i] = op->getInputs(1)->getStride()[i]; + dimPArray[i] = 1; + stridePArray[i] = 1; } + dimPArray[1] = op->getInputs(0)->getDims()[1]; + stridePArray[1] = op->getInputs(0)->getStride()[1]; // get inputs cudnnTensorDescriptor_t inDesc; checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));