forked from jiuyuan/InfiniTensor
feat: support transpose fp16
This commit is contained in:
parent
4b02de7e17
commit
d5e775397d
|
@ -212,7 +212,7 @@ if(USE_CUDA)
|
|||
${CMAKE_CXX_COMPILER}
|
||||
CACHE STRING "Set cuda host compiler path")
|
||||
# CMP0104 requires CUDA_ARCHITECTURES
|
||||
set_target_properties(InfiniTensor PROPERTIES CUDA_ARCHITECTURES "70;80")
|
||||
set_target_properties(InfiniTensor PROPERTIES CUDA_ARCHITECTURES 80)
|
||||
enable_language(CUDA)
|
||||
find_package(CUDAToolkit) # For nvrtc and cuda driver
|
||||
target_link_libraries(InfiniTensor cudnn CUDA::curand CUDA::cublas CUDA::nvrtc CUDA::cudart CUDA::cuda_driver)
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
namespace infini {
|
||||
|
||||
void transpose_kernel(float *input, float *output, int nDims, int size,
|
||||
void transpose_kernel(int dType, void *input, void *output, int nDims, int size,
|
||||
SmallArray strides, SmallArray outputShape);
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -38,8 +38,9 @@ class TransposeCuda : public CudaKernelWithoutConfig {
|
|||
outputDims.data[i] = outputShape[i];
|
||||
}
|
||||
|
||||
transpose_kernel((float *)inputData, (float *)outputData, nDims, size,
|
||||
strides, outputDims);
|
||||
const int dType = op->getDType().getIndex();
|
||||
transpose_kernel(dType, inputData, outputData, nDims, size, strides,
|
||||
outputDims);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -82,9 +83,9 @@ class DepthToSpaceCuda : public CudaKernelWithoutConfig {
|
|||
for (int i = 0; i < nDims; ++i) {
|
||||
outputDims.data[i] = transpose[i];
|
||||
}
|
||||
|
||||
transpose_kernel((float *)inputData, (float *)outputData, nDims, size,
|
||||
strides, outputDims);
|
||||
const int dType = op->getDType().getIndex();
|
||||
transpose_kernel(dType, inputData, outputData, nDims, size, strides,
|
||||
outputDims);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
#include "core/common.h"
|
||||
#include "cuda/cuda_common.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "utils/small_array.h"
|
||||
|
||||
constexpr unsigned int num_threads() { return 32 * 4; }
|
||||
constexpr int thread_work_size() { return 4; }
|
||||
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
|
||||
|
||||
__global__ void _transpose_kernel(float *input, float *output, int nDims,
|
||||
template <class T>
|
||||
__global__ void _transpose_kernel(void *input, void *output, int nDims,
|
||||
int size, infini::SmallArray strides,
|
||||
infini::SmallArray outputShape) {
|
||||
int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
@ -17,21 +19,61 @@ __global__ void _transpose_kernel(float *input, float *output, int nDims,
|
|||
inputIdx += v % outputShape.data[i] * strides.data[i];
|
||||
v /= outputShape.data[i];
|
||||
}
|
||||
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
|
||||
output[outputIdx] = __ldg(input + inputIdx);
|
||||
#else
|
||||
output[outputIdx] = input[inputIdx];
|
||||
#endif
|
||||
((T *)output)[outputIdx] = ((T *)input)[inputIdx];
|
||||
}
|
||||
}
|
||||
#define CASE(T) \
|
||||
_transpose_kernel<DT_CUDA<T>::t><<<gridsize, blocksize>>>( \
|
||||
input, output, nDims, size, strides, outputShape);
|
||||
|
||||
#define SWITCH_DTYPE(DTYPE) \
|
||||
switch (DTYPE) { \
|
||||
case 1: \
|
||||
CASE(1) \
|
||||
break; \
|
||||
case 2: \
|
||||
CASE(2) \
|
||||
break; \
|
||||
case 3: \
|
||||
CASE(3) \
|
||||
break; \
|
||||
case 4: \
|
||||
CASE(4) \
|
||||
break; \
|
||||
case 5: \
|
||||
CASE(5) \
|
||||
break; \
|
||||
case 6: \
|
||||
CASE(6) \
|
||||
break; \
|
||||
case 7: \
|
||||
CASE(7) \
|
||||
break; \
|
||||
case 10: \
|
||||
CASE(10) \
|
||||
break; \
|
||||
case 11: \
|
||||
CASE(11) \
|
||||
break; \
|
||||
case 12: \
|
||||
CASE(12) \
|
||||
break; \
|
||||
case 13: \
|
||||
CASE(13) \
|
||||
break; \
|
||||
case 16: \
|
||||
CASE(16) \
|
||||
break; \
|
||||
default: \
|
||||
IT_TODO_HALT(); \
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
void transpose_kernel(float *input, float *output, int nDims, int size,
|
||||
void transpose_kernel(int dType, void *input, void *output, int nDims, int size,
|
||||
SmallArray strides, SmallArray outputShape) {
|
||||
int blocksize = block_work_size();
|
||||
int gridsize = (size + block_work_size() - 1) / block_work_size();
|
||||
_transpose_kernel<<<gridsize, blocksize>>>(input, output, nDims, size,
|
||||
strides, outputShape);
|
||||
SWITCH_DTYPE(dType)
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
Loading…
Reference in New Issue