diff --git a/CMakeLists.txt b/CMakeLists.txt index 1101a8c2..aa167ee2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -212,7 +212,7 @@ if(USE_CUDA) ${CMAKE_CXX_COMPILER} CACHE STRING "Set cuda host compiler path") # CMP0104 requires CUDA_ARCHITECTURES - set_target_properties(InfiniTensor PROPERTIES CUDA_ARCHITECTURES "70;80") + set_target_properties(InfiniTensor PROPERTIES CUDA_ARCHITECTURES 80) enable_language(CUDA) find_package(CUDAToolkit) # For nvrtc and cuda driver target_link_libraries(InfiniTensor cudnn CUDA::curand CUDA::cublas CUDA::nvrtc CUDA::cudart CUDA::cuda_driver) diff --git a/include/cuda/cuda_transpose.h b/include/cuda/cuda_transpose.h index b168cf0e..89d080ed 100644 --- a/include/cuda/cuda_transpose.h +++ b/include/cuda/cuda_transpose.h @@ -5,7 +5,7 @@ namespace infini { -void transpose_kernel(float *input, float *output, int nDims, int size, +void transpose_kernel(int dType, void *input, void *output, int nDims, int size, SmallArray strides, SmallArray outputShape); }; // namespace infini diff --git a/src/kernels/cuda/transpose.cc b/src/kernels/cuda/transpose.cc index e6b3a84f..b22ee3dd 100644 --- a/src/kernels/cuda/transpose.cc +++ b/src/kernels/cuda/transpose.cc @@ -38,8 +38,9 @@ class TransposeCuda : public CudaKernelWithoutConfig { outputDims.data[i] = outputShape[i]; } - transpose_kernel((float *)inputData, (float *)outputData, nDims, size, - strides, outputDims); + const int dType = op->getDType().getIndex(); + transpose_kernel(dType, inputData, outputData, nDims, size, strides, + outputDims); } }; @@ -82,9 +83,9 @@ class DepthToSpaceCuda : public CudaKernelWithoutConfig { for (int i = 0; i < nDims; ++i) { outputDims.data[i] = transpose[i]; } - - transpose_kernel((float *)inputData, (float *)outputData, nDims, size, - strides, outputDims); + const int dType = op->getDType().getIndex(); + transpose_kernel(dType, inputData, outputData, nDims, size, strides, + outputDims); } }; diff --git a/src/kernels/cuda/transpose.cu b/src/kernels/cuda/transpose.cu index f753217c..917afde3 100644 --- a/src/kernels/cuda/transpose.cu +++ b/src/kernels/cuda/transpose.cu @@ -1,12 +1,14 @@ #include "core/common.h" #include "cuda/cuda_common.h" +#include "cuda/cuda_utility.h" #include "utils/small_array.h" constexpr unsigned int num_threads() { return 32 * 4; } constexpr int thread_work_size() { return 4; } constexpr int block_work_size() { return thread_work_size() * num_threads(); } -__global__ void _transpose_kernel(float *input, float *output, int nDims, +template +__global__ void _transpose_kernel(void *input, void *output, int nDims, int size, infini::SmallArray strides, infini::SmallArray outputShape) { int outputIdx = blockIdx.x * blockDim.x + threadIdx.x; @@ -17,21 +19,61 @@ __global__ void _transpose_kernel(float *input, float *output, int nDims, inputIdx += v % outputShape.data[i] * strides.data[i]; v /= outputShape.data[i]; } -#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) - output[outputIdx] = __ldg(input + inputIdx); -#else - output[outputIdx] = input[inputIdx]; -#endif + ((T *)output)[outputIdx] = ((T *)input)[inputIdx]; } } +#define CASE(T) \ + _transpose_kernel::t><<>>( \ + input, output, nDims, size, strides, outputShape); + +#define SWITCH_DTYPE(DTYPE) \ + switch (DTYPE) { \ + case 1: \ + CASE(1) \ + break; \ + case 2: \ + CASE(2) \ + break; \ + case 3: \ + CASE(3) \ + break; \ + case 4: \ + CASE(4) \ + break; \ + case 5: \ + CASE(5) \ + break; \ + case 6: \ + CASE(6) \ + break; \ + case 7: \ + CASE(7) \ + break; \ + case 10: \ + CASE(10) \ + break; \ + case 11: \ + CASE(11) \ + break; \ + case 12: \ + CASE(12) \ + break; \ + case 13: \ + CASE(13) \ + break; \ + case 16: \ + CASE(16) \ + break; \ + default: \ + IT_TODO_HALT(); \ + } namespace infini { -void transpose_kernel(float *input, float *output, int nDims, int size, +void transpose_kernel(int dType, void *input, void *output, int nDims, int size, SmallArray strides, SmallArray outputShape) { int blocksize = block_work_size(); int gridsize = (size + block_work_size() - 1) / block_work_size(); - _transpose_kernel<<>>(input, output, nDims, size, - strides, outputShape); + SWITCH_DTYPE(dType) } } // namespace infini