forked from jiuyuan/InfiniTensor
add batch_norm
This commit is contained in:
parent
3d122aebfe
commit
514666591e
|
@ -28,9 +28,11 @@ class BatchNormCudnn : public CudaKernelWithoutConfig {
|
||||||
for (size_t i = 0; i < dims.size(); ++i) {
|
for (size_t i = 0; i < dims.size(); ++i) {
|
||||||
dimArray[i] = dims[i];
|
dimArray[i] = dims[i];
|
||||||
strideArray[i] = op->getInputs(0)->getStride()[i];
|
strideArray[i] = op->getInputs(0)->getStride()[i];
|
||||||
dimPArray[i] = op->getInputs(1)->getDims()[i];
|
dimPArray[i] = 1;
|
||||||
stridePArray[i] = op->getInputs(1)->getStride()[i];
|
stridePArray[i] = 1;
|
||||||
}
|
}
|
||||||
|
dimPArray[1] = op->getInputs(0)->getDims()[1];
|
||||||
|
stridePArray[1] = op->getInputs(0)->getStride()[1];
|
||||||
// get inputs
|
// get inputs
|
||||||
cudnnTensorDescriptor_t inDesc;
|
cudnnTensorDescriptor_t inDesc;
|
||||||
checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
|
checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
|
||||||
|
|
Loading…
Reference in New Issue