forked from jiuyuan/InfiniTensor
memcopy instead of special kernel
This commit is contained in:
parent
73e3f1fc6f
commit
7146294baa
|
@ -7,5 +7,5 @@ namespace infini {
|
|||
|
||||
void transpose_kernel(int dType, void *input, void *output, int nDims, int size,
|
||||
SmallArray strides, SmallArray outputShape);
|
||||
void transposeSpecial_kernel(int dType, void *input, void *output, int size);
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -40,7 +40,10 @@ class TransposeCuda : public CudaKernelWithoutConfig {
|
|||
}
|
||||
//----------------
|
||||
if (condition) {
|
||||
transposeSpecial_kernel(dType, inputData, outputData, size);
|
||||
cudaMemcpyAsync(outputData, inputData, op->getInputs(0)->getBytes(),
|
||||
cudaMemcpyDeviceToDevice,
|
||||
CUDAStream::getCurrentStream());
|
||||
|
||||
} else {
|
||||
const auto &perm = op->getPermute();
|
||||
|
||||
|
|
|
@ -68,47 +68,7 @@ __global__ void _transpose_kernel(void *input, void *output, int nDims,
|
|||
default: \
|
||||
IT_TODO_HALT(); \
|
||||
}
|
||||
template <class T>
|
||||
__global__ void _transposeSpecial_kernel(void *input, void *output, int size) {
|
||||
int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (outputIdx < size) {
|
||||
((T *)output)[outputIdx] = ((T *)input)[outputIdx];
|
||||
}
|
||||
}
|
||||
#define CASESpecial(T) \
|
||||
_transposeSpecial_kernel<DT_CUDA<T>::t> \
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( \
|
||||
input, output, size);
|
||||
|
||||
#define SWITCHSpecial_DTYPE(DTYPE) \
|
||||
switch (DTYPE) { \
|
||||
case 1: \
|
||||
CASESpecial(1) break; \
|
||||
case 2: \
|
||||
CASESpecial(2) break; \
|
||||
case 3: \
|
||||
CASESpecial(3) break; \
|
||||
case 4: \
|
||||
CASESpecial(4) break; \
|
||||
case 5: \
|
||||
CASESpecial(5) break; \
|
||||
case 6: \
|
||||
CASESpecial(6) break; \
|
||||
case 7: \
|
||||
CASESpecial(7) break; \
|
||||
case 10: \
|
||||
CASESpecial(10) break; \
|
||||
case 11: \
|
||||
CASESpecial(11) break; \
|
||||
case 12: \
|
||||
CASESpecial(12) break; \
|
||||
case 13: \
|
||||
CASESpecial(13) break; \
|
||||
case 16: \
|
||||
CASESpecial(16) break; \
|
||||
default: \
|
||||
IT_TODO_HALT(); \
|
||||
}
|
||||
namespace infini {
|
||||
void transpose_kernel(int dType, void *input, void *output, int nDims, int size,
|
||||
SmallArray strides, SmallArray outputShape) {
|
||||
|
@ -116,10 +76,5 @@ void transpose_kernel(int dType, void *input, void *output, int nDims, int size,
|
|||
int gridsize = (size + block_work_size() - 1) / block_work_size();
|
||||
SWITCH_DTYPE(dType)
|
||||
}
|
||||
void transposeSpecial_kernel(int dType, void *input, void *output, int size) {
|
||||
int blocksize = block_work_size();
|
||||
int gridsize = (size + block_work_size() - 1) / block_work_size();
|
||||
SWITCHSpecial_DTYPE(dType)
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
Loading…
Reference in New Issue