feat: support sliceOp fp16

This commit is contained in:
zhangyunze 2023-12-13 16:55:16 +08:00
parent d5e775397d
commit ee4ecd27e2
1 changed files with 63 additions and 18 deletions

View File

@ -1,6 +1,7 @@
#include "core/data_type.h"
#include "cuda/cuda_common.h"
#include "cuda/cuda_pad_slice.h"
#include "cuda/cuda_utility.h"
__device__ int WholeTensorOffset2PartTensorOffset(int wholeOffset,
TransMetaData metaData,
@ -21,39 +22,83 @@ __device__ int WholeTensorOffset2PartTensorOffset(int wholeOffset,
}
template <typename T>
__global__ void _pad_slice_kernel(T *part, T *whole, TransMetaData metaData,
int nDims, int num, bool isPad) {
__global__ void _pad_slice_kernel(void *part, void *whole,
TransMetaData metaData, int nDims, int num,
bool isPad) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= num)
if (tid >= num) {
return;
}
int stride = blockDim.x * gridDim.x;
while (tid < num) {
int offset = WholeTensorOffset2PartTensorOffset(tid, metaData, nDims);
if (isPad)
if (offset < 0)
whole[tid] = 0;
else
whole[tid] = part[offset];
else
part[offset] = whole[tid];
if (isPad) {
if (offset < 0) {
((T *)whole)[tid] = 0;
} else {
((T *)whole)[tid] = ((T *)part)[offset];
}
} else {
((T *)part)[offset] = ((T *)whole)[tid];
}
tid += stride;
}
}
namespace infini {
#define CASE(T) \
_pad_slice_kernel<DT_CUDA<T>::t><<<gridSize, blockSize>>>( \
partData, wholeData, metadata, nDims, num, isPad);
#define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \
case 1: \
CASE(1) \
break; \
case 2: \
CASE(2) \
break; \
case 3: \
CASE(3) \
break; \
case 4: \
CASE(4) \
break; \
case 5: \
CASE(5) \
break; \
case 6: \
CASE(6) \
break; \
case 7: \
CASE(7) \
break; \
case 10: \
CASE(10) \
break; \
case 11: \
CASE(11) \
break; \
case 12: \
CASE(12) \
break; \
case 13: \
CASE(13) \
break; \
case 16: \
CASE(16) \
break; \
default: \
IT_TODO_HALT(); \
}
void pad_slice_kernel(void *partData, void *wholeData,
const TransMetaData &metadata, int nDims, int num,
bool isPad) {
int blockSize = 32 * 16;
int gridSize = (num + blockSize - 1) / blockSize;
if (metadata.DType == DataType::Int64.getIndex()) {
_pad_slice_kernel<int64_t>
<<<gridSize, blockSize>>>((int64_t *)partData, (int64_t *)wholeData,
metadata, nDims, num, isPad);
} else if (metadata.DType == DataType::Float32.getIndex()) {
_pad_slice_kernel<float><<<gridSize, blockSize>>>(
(float *)partData, (float *)wholeData, metadata, nDims, num, isPad);
}
int dType = metadata.DType;
SWITCH_DTYPE(dType)
}
} // namespace infini