forked from jiuyuan/InfiniTensor
feat: support sliceOp fp16
This commit is contained in:
parent
d5e775397d
commit
ee4ecd27e2
|
@ -1,6 +1,7 @@
|
||||||
#include "core/data_type.h"
|
#include "core/data_type.h"
|
||||||
#include "cuda/cuda_common.h"
|
#include "cuda/cuda_common.h"
|
||||||
#include "cuda/cuda_pad_slice.h"
|
#include "cuda/cuda_pad_slice.h"
|
||||||
|
#include "cuda/cuda_utility.h"
|
||||||
|
|
||||||
__device__ int WholeTensorOffset2PartTensorOffset(int wholeOffset,
|
__device__ int WholeTensorOffset2PartTensorOffset(int wholeOffset,
|
||||||
TransMetaData metaData,
|
TransMetaData metaData,
|
||||||
|
@ -21,39 +22,83 @@ __device__ int WholeTensorOffset2PartTensorOffset(int wholeOffset,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void _pad_slice_kernel(T *part, T *whole, TransMetaData metaData,
|
__global__ void _pad_slice_kernel(void *part, void *whole,
|
||||||
int nDims, int num, bool isPad) {
|
TransMetaData metaData, int nDims, int num,
|
||||||
|
bool isPad) {
|
||||||
int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
if (tid >= num)
|
if (tid >= num) {
|
||||||
return;
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
int stride = blockDim.x * gridDim.x;
|
int stride = blockDim.x * gridDim.x;
|
||||||
while (tid < num) {
|
while (tid < num) {
|
||||||
int offset = WholeTensorOffset2PartTensorOffset(tid, metaData, nDims);
|
int offset = WholeTensorOffset2PartTensorOffset(tid, metaData, nDims);
|
||||||
if (isPad)
|
if (isPad) {
|
||||||
if (offset < 0)
|
if (offset < 0) {
|
||||||
whole[tid] = 0;
|
((T *)whole)[tid] = 0;
|
||||||
else
|
} else {
|
||||||
whole[tid] = part[offset];
|
((T *)whole)[tid] = ((T *)part)[offset];
|
||||||
else
|
}
|
||||||
part[offset] = whole[tid];
|
} else {
|
||||||
|
((T *)part)[offset] = ((T *)whole)[tid];
|
||||||
|
}
|
||||||
tid += stride;
|
tid += stride;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
#define CASE(T) \
|
||||||
|
_pad_slice_kernel<DT_CUDA<T>::t><<<gridSize, blockSize>>>( \
|
||||||
|
partData, wholeData, metadata, nDims, num, isPad);
|
||||||
|
|
||||||
|
#define SWITCH_DTYPE(DTYPE) \
|
||||||
|
switch (DTYPE) { \
|
||||||
|
case 1: \
|
||||||
|
CASE(1) \
|
||||||
|
break; \
|
||||||
|
case 2: \
|
||||||
|
CASE(2) \
|
||||||
|
break; \
|
||||||
|
case 3: \
|
||||||
|
CASE(3) \
|
||||||
|
break; \
|
||||||
|
case 4: \
|
||||||
|
CASE(4) \
|
||||||
|
break; \
|
||||||
|
case 5: \
|
||||||
|
CASE(5) \
|
||||||
|
break; \
|
||||||
|
case 6: \
|
||||||
|
CASE(6) \
|
||||||
|
break; \
|
||||||
|
case 7: \
|
||||||
|
CASE(7) \
|
||||||
|
break; \
|
||||||
|
case 10: \
|
||||||
|
CASE(10) \
|
||||||
|
break; \
|
||||||
|
case 11: \
|
||||||
|
CASE(11) \
|
||||||
|
break; \
|
||||||
|
case 12: \
|
||||||
|
CASE(12) \
|
||||||
|
break; \
|
||||||
|
case 13: \
|
||||||
|
CASE(13) \
|
||||||
|
break; \
|
||||||
|
case 16: \
|
||||||
|
CASE(16) \
|
||||||
|
break; \
|
||||||
|
default: \
|
||||||
|
IT_TODO_HALT(); \
|
||||||
|
}
|
||||||
|
|
||||||
void pad_slice_kernel(void *partData, void *wholeData,
|
void pad_slice_kernel(void *partData, void *wholeData,
|
||||||
const TransMetaData &metadata, int nDims, int num,
|
const TransMetaData &metadata, int nDims, int num,
|
||||||
bool isPad) {
|
bool isPad) {
|
||||||
int blockSize = 32 * 16;
|
int blockSize = 32 * 16;
|
||||||
int gridSize = (num + blockSize - 1) / blockSize;
|
int gridSize = (num + blockSize - 1) / blockSize;
|
||||||
if (metadata.DType == DataType::Int64.getIndex()) {
|
int dType = metadata.DType;
|
||||||
_pad_slice_kernel<int64_t>
|
SWITCH_DTYPE(dType)
|
||||||
<<<gridSize, blockSize>>>((int64_t *)partData, (int64_t *)wholeData,
|
|
||||||
metadata, nDims, num, isPad);
|
|
||||||
} else if (metadata.DType == DataType::Float32.getIndex()) {
|
|
||||||
_pad_slice_kernel<float><<<gridSize, blockSize>>>(
|
|
||||||
(float *)partData, (float *)wholeData, metadata, nDims, num, isPad);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
Loading…
Reference in New Issue