forked from jiuyuan/InfiniTensor
fix: support batchnorm cudnn 2 dimension input
This commit is contained in:
parent
c6de91ee82
commit
77fd137dcb
|
@ -18,9 +18,16 @@ class BatchNormCudnn : public CudaKernelWithoutConfig {
|
|||
void *const scaleData = (op->getInputs(3)->getRawDataPtr<void *>());
|
||||
void *const biasData = (op->getInputs(4)->getRawDataPtr<void *>());
|
||||
|
||||
auto dims = op->getInputs(0)->getDims();
|
||||
// Only 4D and 5D tensors are supported by
|
||||
// cudnnBatchNormalizationForwardInference
|
||||
if (auto dims = op->getInputs(0)->getDims(); dims.size() < 4) {
|
||||
auto dims_t = dims;
|
||||
for (size_t i = dims_t.size(); i < 4; ++i) {
|
||||
dims_t.push_back(1);
|
||||
}
|
||||
op->getInputs(0)->setShape(dims_t);
|
||||
}
|
||||
auto dims = op->getInputs(0)->getDims();
|
||||
IT_ASSERT(dims.size() == 4);
|
||||
|
||||
int dimArray[4], strideArray[4], dimPArray[4], stridePArray[4];
|
||||
|
|
Loading…
Reference in New Issue