cudnn activations support ND-Tensor (#116)

* refine TensorObj::getStride

* ActivationCudnn supports ND-Tensor
This commit is contained in:
constroy Li 2023-08-22 14:21:59 +08:00 committed by GitHub
parent 48847958d0
commit 384407421b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 21 additions and 16 deletions

View File

@ -37,7 +37,7 @@ class TensorObj : public TensorBaseObj {
Shape getDims() const { return shape; }
size_t getRank() const { return shape.size(); }
vector<size_t> getStride() const;
Shape getStride() const;
size_t getOffset(const vector<int> &ds) const;
void dataMalloc();
UidBaseType getFuid() const { return fuid; }

View File

@ -51,15 +51,14 @@ size_t TensorObj::getOffset(const vector<int> &pos) const {
return idx;
}
// NOTE(review): this span is a diff hunk rendered without +/- markers, so two
// versions of TensorObj::getStride are interleaved and the braces do not
// balance on their own. Judging by the variables each line uses, lines that
// touch `ret` belong to the vector<size_t> version, and lines that touch `p`
// or index into `stride` belong to the Shape version; the two closing braces
// appear to be shared context lines. Confirm against the actual commit.
//
// vector<size_t> version: walks the dimensions from the last one down to
// index 1, inserting the running product at the front of `ret` each time,
// then inserts once more after the loop for dimension 0. Here `stride` is
// the scalar running product. Because the final insert is unconditional,
// this version returns one element even when `shape` is empty.
vector<size_t> TensorObj::getStride() const {
vector<size_t> ret;
size_t stride = 1;
for (int i = shape.size() - 1; i >= 1; i--) {
ret.emplace(ret.begin(), stride);
stride *= shape.at(i);
// Shape version: `stride` is now the result container, sized to the rank, and
// `p` is the running product. Each stride[i - 1] receives the product of all
// extents to the right of dimension i - 1, so the last dimension gets 1.
// Counting `i` from getRank() down to 1 and indexing with i - 1 keeps the
// loop variable non-negative.
// NOTE(review): presumably Shape is a vector-like type, in which case an empty
// shape yields an empty result here, unlike the version above; verify that no
// caller relied on the old single-element result.
Shape TensorObj::getStride() const {
Shape stride(getRank());
ShapeElem p = 1;
for (auto i = getRank(); i > 0; --i) {
stride[i - 1] = p;
p = p * shape[i - 1];
}
// Tail of the vector<size_t> version: stride entry for dimension 0, then return.
ret.emplace(ret.begin(), stride);
return ret;
// Tail of the Shape version.
return stride;
}
void TensorObj::printData() const {

View File

@ -25,19 +25,25 @@ class ActivationCudnn : public CudaKernelWithoutConfig {
cudnnTensorDescriptor_t inputDesc, outputDesc;
auto dim = op->getInputs(0)->getDims();
if (dim.size() != 4)
IT_TODO_HALT();
int n = dim[0], c = dim[1], h = dim[2], w = dim[3];
// assume input and output have the same strides.
auto stride = op->getInputs(0)->getStride();
// CUDNN requires that dim >= 4.
while (dim.size() < 4)
dim.push_back(1);
while (stride.size() < 4)
stride.push_back(1);
// get inputs
checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(
inputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
checkCudnnError(cudnnSetTensorNdDescriptor(inputDesc, CUDNN_DATA_FLOAT,
dim.size(), dim.data(),
stride.data()));
// get outputs
checkCudnnError(cudnnCreateTensorDescriptor(&outputDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(
outputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
checkCudnnError(cudnnSetTensorNdDescriptor(outputDesc, CUDNN_DATA_FLOAT,
dim.size(), dim.data(),
stride.data()));
// get op descriptor
cudnnActivationDescriptor_t activationDesc;