cudnn activations support ND-Tensor (#116)

* refine TensorObj::getStride

* ActivationCudnn supports ND-Tensor
This commit is contained in:
constroy Li 2023-08-22 14:21:59 +08:00 committed by GitHub
parent 48847958d0
commit 384407421b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 21 additions and 16 deletions

View File

@ -37,7 +37,7 @@ class TensorObj : public TensorBaseObj {
Shape getDims() const { return shape; } Shape getDims() const { return shape; }
size_t getRank() const { return shape.size(); } size_t getRank() const { return shape.size(); }
vector<size_t> getStride() const; Shape getStride() const;
size_t getOffset(const vector<int> &ds) const; size_t getOffset(const vector<int> &ds) const;
void dataMalloc(); void dataMalloc();
UidBaseType getFuid() const { return fuid; } UidBaseType getFuid() const { return fuid; }

View File

@ -51,15 +51,14 @@ size_t TensorObj::getOffset(const vector<int> &pos) const {
return idx; return idx;
} }
vector<size_t> TensorObj::getStride() const { Shape TensorObj::getStride() const {
vector<size_t> ret; Shape stride(getRank());
size_t stride = 1; ShapeElem p = 1;
for (int i = shape.size() - 1; i >= 1; i--) { for (auto i = getRank(); i > 0; --i) {
ret.emplace(ret.begin(), stride); stride[i - 1] = p;
stride *= shape.at(i); p = p * shape[i - 1];
} }
ret.emplace(ret.begin(), stride); return stride;
return ret;
} }
void TensorObj::printData() const { void TensorObj::printData() const {

View File

@ -25,19 +25,25 @@ class ActivationCudnn : public CudaKernelWithoutConfig {
cudnnTensorDescriptor_t inputDesc, outputDesc; cudnnTensorDescriptor_t inputDesc, outputDesc;
auto dim = op->getInputs(0)->getDims(); auto dim = op->getInputs(0)->getDims();
if (dim.size() != 4) // assume input and output have the same strides.
IT_TODO_HALT(); auto stride = op->getInputs(0)->getStride();
int n = dim[0], c = dim[1], h = dim[2], w = dim[3]; // CUDNN requires that dim >= 4.
while (dim.size() < 4)
dim.push_back(1);
while (stride.size() < 4)
stride.push_back(1);
// get inputs // get inputs
checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc)); checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc));
checkCudnnError(cudnnSetTensor4dDescriptor( checkCudnnError(cudnnSetTensorNdDescriptor(inputDesc, CUDNN_DATA_FLOAT,
inputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w)); dim.size(), dim.data(),
stride.data()));
// get outputs // get outputs
checkCudnnError(cudnnCreateTensorDescriptor(&outputDesc)); checkCudnnError(cudnnCreateTensorDescriptor(&outputDesc));
checkCudnnError(cudnnSetTensor4dDescriptor( checkCudnnError(cudnnSetTensorNdDescriptor(outputDesc, CUDNN_DATA_FLOAT,
outputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w)); dim.size(), dim.data(),
stride.data()));
// get op descriptor // get op descriptor
cudnnActivationDescriptor_t activationDesc; cudnnActivationDescriptor_t activationDesc;