fix: support batchnorm cudnn 2 dimension input

This commit is contained in:
zhangyunze 2024-04-25 16:40:15 +08:00
parent c6de91ee82
commit 77fd137dcb
1 changed files with 8 additions and 1 deletions

View File

@ -18,9 +18,16 @@ class BatchNormCudnn : public CudaKernelWithoutConfig {
void *const scaleData = (op->getInputs(3)->getRawDataPtr<void *>()); void *const scaleData = (op->getInputs(3)->getRawDataPtr<void *>());
void *const biasData = (op->getInputs(4)->getRawDataPtr<void *>()); void *const biasData = (op->getInputs(4)->getRawDataPtr<void *>());
auto dims = op->getInputs(0)->getDims();
// Only 4D and 5D tensors are supported by // Only 4D and 5D tensors are supported by
// cudnnBatchNormalizationForwardInference // cudnnBatchNormalizationForwardInference
if (auto dims = op->getInputs(0)->getDims(); dims.size() < 4) {
auto dims_t = dims;
for (size_t i = dims_t.size(); i < 4; ++i) {
dims_t.push_back(1);
}
op->getInputs(0)->setShape(dims_t);
}
auto dims = op->getInputs(0)->getDims();
IT_ASSERT(dims.size() == 4); IT_ASSERT(dims.size() == 4);
int dimArray[4], strideArray[4], dimPArray[4], stridePArray[4]; int dimArray[4], strideArray[4], dimPArray[4], stridePArray[4];