diff --git a/src/kernels/cuda/batch_norm.cc b/src/kernels/cuda/batch_norm.cc index 7ca75ba6..b150aaa5 100644 --- a/src/kernels/cuda/batch_norm.cc +++ b/src/kernels/cuda/batch_norm.cc @@ -28,9 +28,11 @@ class BatchNormCudnn : public CudaKernelWithoutConfig { for (size_t i = 0; i < dims.size(); ++i) { dimArray[i] = dims[i]; strideArray[i] = op->getInputs(0)->getStride()[i]; - dimPArray[i] = op->getInputs(1)->getDims()[i]; - stridePArray[i] = op->getInputs(1)->getStride()[i]; + dimPArray[i] = 1; + stridePArray[i] = 1; } + dimPArray[1] = op->getInputs(0)->getDims()[1]; + stridePArray[1] = op->getInputs(0)->getStride()[1]; // get inputs cudnnTensorDescriptor_t inDesc; checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));