add batch_norm

This commit is contained in:
wanghailu 2023-02-24 13:55:53 +08:00 committed by YdrMaster
parent 3d122aebfe
commit 514666591e
1 changed file with 4 additions and 2 deletions

View File

@@ -28,9 +28,11 @@ class BatchNormCudnn : public CudaKernelWithoutConfig {
for (size_t i = 0; i < dims.size(); ++i) { for (size_t i = 0; i < dims.size(); ++i) {
dimArray[i] = dims[i]; dimArray[i] = dims[i];
strideArray[i] = op->getInputs(0)->getStride()[i]; strideArray[i] = op->getInputs(0)->getStride()[i];
dimPArray[i] = op->getInputs(1)->getDims()[i]; dimPArray[i] = 1;
stridePArray[i] = op->getInputs(1)->getStride()[i]; stridePArray[i] = 1;
} }
dimPArray[1] = op->getInputs(0)->getDims()[1];
stridePArray[1] = op->getInputs(0)->getStride()[1];
// get inputs // get inputs
cudnnTensorDescriptor_t inDesc; cudnnTensorDescriptor_t inDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&inDesc)); checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));