forked from jiuyuan/InfiniTensor
feat: support sliceOp fp16
This commit is contained in:
parent
d5e775397d
commit
ee4ecd27e2
|
@ -1,6 +1,7 @@
|
|||
#include "core/data_type.h"
|
||||
#include "cuda/cuda_common.h"
|
||||
#include "cuda/cuda_pad_slice.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
|
||||
__device__ int WholeTensorOffset2PartTensorOffset(int wholeOffset,
|
||||
TransMetaData metaData,
|
||||
|
@ -21,39 +22,83 @@ __device__ int WholeTensorOffset2PartTensorOffset(int wholeOffset,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void _pad_slice_kernel(T *part, T *whole, TransMetaData metaData,
|
||||
int nDims, int num, bool isPad) {
|
||||
__global__ void _pad_slice_kernel(void *part, void *whole,
|
||||
TransMetaData metaData, int nDims, int num,
|
||||
bool isPad) {
|
||||
int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (tid >= num)
|
||||
if (tid >= num) {
|
||||
return;
|
||||
}
|
||||
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
while (tid < num) {
|
||||
int offset = WholeTensorOffset2PartTensorOffset(tid, metaData, nDims);
|
||||
if (isPad)
|
||||
if (offset < 0)
|
||||
whole[tid] = 0;
|
||||
else
|
||||
whole[tid] = part[offset];
|
||||
else
|
||||
part[offset] = whole[tid];
|
||||
if (isPad) {
|
||||
if (offset < 0) {
|
||||
((T *)whole)[tid] = 0;
|
||||
} else {
|
||||
((T *)whole)[tid] = ((T *)part)[offset];
|
||||
}
|
||||
} else {
|
||||
((T *)part)[offset] = ((T *)whole)[tid];
|
||||
}
|
||||
tid += stride;
|
||||
}
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
#define CASE(T) \
|
||||
_pad_slice_kernel<DT_CUDA<T>::t><<<gridSize, blockSize>>>( \
|
||||
partData, wholeData, metadata, nDims, num, isPad);
|
||||
|
||||
#define SWITCH_DTYPE(DTYPE) \
|
||||
switch (DTYPE) { \
|
||||
case 1: \
|
||||
CASE(1) \
|
||||
break; \
|
||||
case 2: \
|
||||
CASE(2) \
|
||||
break; \
|
||||
case 3: \
|
||||
CASE(3) \
|
||||
break; \
|
||||
case 4: \
|
||||
CASE(4) \
|
||||
break; \
|
||||
case 5: \
|
||||
CASE(5) \
|
||||
break; \
|
||||
case 6: \
|
||||
CASE(6) \
|
||||
break; \
|
||||
case 7: \
|
||||
CASE(7) \
|
||||
break; \
|
||||
case 10: \
|
||||
CASE(10) \
|
||||
break; \
|
||||
case 11: \
|
||||
CASE(11) \
|
||||
break; \
|
||||
case 12: \
|
||||
CASE(12) \
|
||||
break; \
|
||||
case 13: \
|
||||
CASE(13) \
|
||||
break; \
|
||||
case 16: \
|
||||
CASE(16) \
|
||||
break; \
|
||||
default: \
|
||||
IT_TODO_HALT(); \
|
||||
}
|
||||
|
||||
void pad_slice_kernel(void *partData, void *wholeData,
|
||||
const TransMetaData &metadata, int nDims, int num,
|
||||
bool isPad) {
|
||||
int blockSize = 32 * 16;
|
||||
int gridSize = (num + blockSize - 1) / blockSize;
|
||||
if (metadata.DType == DataType::Int64.getIndex()) {
|
||||
_pad_slice_kernel<int64_t>
|
||||
<<<gridSize, blockSize>>>((int64_t *)partData, (int64_t *)wholeData,
|
||||
metadata, nDims, num, isPad);
|
||||
} else if (metadata.DType == DataType::Float32.getIndex()) {
|
||||
_pad_slice_kernel<float><<<gridSize, blockSize>>>(
|
||||
(float *)partData, (float *)wholeData, metadata, nDims, num, isPad);
|
||||
}
|
||||
int dType = metadata.DType;
|
||||
SWITCH_DTYPE(dType)
|
||||
}
|
||||
} // namespace infini
|
||||
|
|
Loading…
Reference in New Issue