From 384407421bbf1c29905c5062f647190e57cef688 Mon Sep 17 00:00:00 2001 From: constroy Li Date: Tue, 22 Aug 2023 14:21:59 +0800 Subject: [PATCH] cudnn activations support ND-Tensor (#116) * refine TensorObj::getStride * ActivationCudnn supports ND-Tensor --- include/core/tensor.h | 2 +- src/core/tensor.cc | 15 +++++++-------- src/kernels/cuda/unary.cc | 20 +++++++++++++------- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/include/core/tensor.h b/include/core/tensor.h index 89f7d9be..fe0b536e 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -37,7 +37,7 @@ class TensorObj : public TensorBaseObj { Shape getDims() const { return shape; } size_t getRank() const { return shape.size(); } - vector getStride() const; + Shape getStride() const; size_t getOffset(const vector &ds) const; void dataMalloc(); UidBaseType getFuid() const { return fuid; } diff --git a/src/core/tensor.cc b/src/core/tensor.cc index 77b3b49b..f52127e6 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -51,15 +51,14 @@ size_t TensorObj::getOffset(const vector &pos) const { return idx; } -vector TensorObj::getStride() const { - vector ret; - size_t stride = 1; - for (int i = shape.size() - 1; i >= 1; i--) { - ret.emplace(ret.begin(), stride); - stride *= shape.at(i); +Shape TensorObj::getStride() const { + Shape stride(getRank()); + ShapeElem p = 1; + for (auto i = getRank(); i > 0; --i) { + stride[i - 1] = p; + p = p * shape[i - 1]; } - ret.emplace(ret.begin(), stride); - return ret; + return stride; } void TensorObj::printData() const { diff --git a/src/kernels/cuda/unary.cc b/src/kernels/cuda/unary.cc index 317f45b8..abc8b0bc 100644 --- a/src/kernels/cuda/unary.cc +++ b/src/kernels/cuda/unary.cc @@ -25,19 +25,25 @@ class ActivationCudnn : public CudaKernelWithoutConfig { cudnnTensorDescriptor_t inputDesc, outputDesc; auto dim = op->getInputs(0)->getDims(); - if (dim.size() != 4) - IT_TODO_HALT(); - int n = dim[0], c = dim[1], h = dim[2], w = dim[3]; + // assume input and output have the same strides. + auto stride = op->getInputs(0)->getStride(); + // CUDNN requires that dim >= 4. + while (dim.size() < 4) + dim.push_back(1); + while (stride.size() < 4) + stride.push_back(1); // get inputs checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc)); - checkCudnnError(cudnnSetTensor4dDescriptor( - inputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w)); + checkCudnnError(cudnnSetTensorNdDescriptor(inputDesc, CUDNN_DATA_FLOAT, + dim.size(), dim.data(), + stride.data())); // get outputs checkCudnnError(cudnnCreateTensorDescriptor(&outputDesc)); - checkCudnnError(cudnnSetTensor4dDescriptor( - outputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w)); + checkCudnnError(cudnnSetTensorNdDescriptor(outputDesc, CUDNN_DATA_FLOAT, + dim.size(), dim.data(), + stride.data())); // get op descriptor cudnnActivationDescriptor_t activationDesc;