memcopy instead of special kernel

This commit is contained in:
xgqdut2016 2024-05-06 14:49:39 +08:00
parent 73e3f1fc6f
commit 7146294baa
3 changed files with 5 additions and 47 deletions

View File

@ -7,5 +7,5 @@ namespace infini {
void transpose_kernel(int dType, void *input, void *output, int nDims, int size,
SmallArray strides, SmallArray outputShape);
void transposeSpecial_kernel(int dType, void *input, void *output, int size);
}; // namespace infini

View File

@ -40,7 +40,10 @@ class TransposeCuda : public CudaKernelWithoutConfig {
}
//----------------
if (condition) {
transposeSpecial_kernel(dType, inputData, outputData, size);
cudaMemcpyAsync(outputData, inputData, op->getInputs(0)->getBytes(),
cudaMemcpyDeviceToDevice,
CUDAStream::getCurrentStream());
} else {
const auto &perm = op->getPermute();

View File

@ -68,47 +68,7 @@ __global__ void _transpose_kernel(void *input, void *output, int nDims,
default: \
IT_TODO_HALT(); \
}
template <class T>
__global__ void _transposeSpecial_kernel(void *input, void *output, int size) {
int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (outputIdx < size) {
((T *)output)[outputIdx] = ((T *)input)[outputIdx];
}
}
#define CASESpecial(T) \
_transposeSpecial_kernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( \
input, output, size);
#define SWITCHSpecial_DTYPE(DTYPE) \
switch (DTYPE) { \
case 1: \
CASESpecial(1) break; \
case 2: \
CASESpecial(2) break; \
case 3: \
CASESpecial(3) break; \
case 4: \
CASESpecial(4) break; \
case 5: \
CASESpecial(5) break; \
case 6: \
CASESpecial(6) break; \
case 7: \
CASESpecial(7) break; \
case 10: \
CASESpecial(10) break; \
case 11: \
CASESpecial(11) break; \
case 12: \
CASESpecial(12) break; \
case 13: \
CASESpecial(13) break; \
case 16: \
CASESpecial(16) break; \
default: \
IT_TODO_HALT(); \
}
namespace infini {
void transpose_kernel(int dType, void *input, void *output, int nDims, int size,
SmallArray strides, SmallArray outputShape) {
@ -116,10 +76,5 @@ void transpose_kernel(int dType, void *input, void *output, int nDims, int size,
int gridsize = (size + block_work_size() - 1) / block_work_size();
SWITCH_DTYPE(dType)
}
void transposeSpecial_kernel(int dType, void *input, void *output, int size) {
int blocksize = block_work_size();
int gridsize = (size + block_work_size() - 1) / block_work_size();
SWITCHSpecial_DTYPE(dType)
}
} // namespace infini