forked from jiuyuan/InfiniTensor
modify code to pass the cuda_all_reduce test
This commit is contained in:
parent
c04910f118
commit
8cc6af0a83
|
@ -118,15 +118,14 @@ class CUDAStream {
|
|||
CUDAStream(CUDAStream &&) = delete;
|
||||
void operator=(const CUDAStream &) = delete;
|
||||
void operator=(CUDAStream &&) = delete;
|
||||
cudaStream_t getCurrentStream() { return _stream; }
|
||||
static std::unique_ptr<CUDAStream> p_CUDAStream;
|
||||
static void init() { p_CUDAStream.reset(new CUDAStream); }
|
||||
void createStream() { checkCudaError(cudaStreamCreate(&_stream)); }
|
||||
void destroyStream() { checkCudaError(cudaStreamDestroy(_stream)); }
|
||||
static cudaStream_t getCurrentStream() { return _stream; }
|
||||
static void Init() { CUDAStream::_stream = 0; };
|
||||
static void createStream() { checkCudaError(cudaStreamCreate(&_stream)); }
|
||||
static void destroyStream() { checkCudaError(cudaStreamDestroy(_stream)); }
|
||||
|
||||
private:
|
||||
CUDAStream();
|
||||
cudaStream_t _stream;
|
||||
CUDAStream(){};
|
||||
static cudaStream_t _stream;
|
||||
};
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -30,19 +30,14 @@ class CudaRuntimeObj : public RuntimeObj {
|
|||
workspaceSize = 7ll << 30; // 7 GB
|
||||
workspace = alloc(workspaceSize);
|
||||
isCudaGraphCreated = false;
|
||||
CUDAStream::init();
|
||||
CUDAStream::p_CUDAStream->createStream();
|
||||
checkCudnnError(cudnnSetStream(
|
||||
cudnn, CUDAStream::p_CUDAStream->getCurrentStream()));
|
||||
checkCublasError(cublasSetStream(
|
||||
cublas, CUDAStream::p_CUDAStream->getCurrentStream()));
|
||||
CUDAStream::Init();
|
||||
}
|
||||
virtual ~CudaRuntimeObj() {
|
||||
try {
|
||||
if (isCudaGraphCreated) {
|
||||
checkCudaError(cudaGraphExecDestroy(cudaGraphInstance));
|
||||
checkCudaError(cudaGraphDestroy(cudaGraph));
|
||||
CUDAStream::p_CUDAStream->destroyStream();
|
||||
CUDAStream::destroyStream();
|
||||
}
|
||||
dealloc(workspace);
|
||||
checkCudnnError(cudnnDestroy(cudnn));
|
||||
|
|
|
@ -40,21 +40,23 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
|
|||
|
||||
void CudaRuntimeObj::runWithCudaGraph(const Graph &graph) {
|
||||
if (!isCudaGraphCreated) {
|
||||
checkCudaError(
|
||||
cudaStreamBeginCapture(CUDAStream::p_CUDAStream->getCurrentStream(),
|
||||
cudaStreamCaptureModeGlobal));
|
||||
CUDAStream::createStream();
|
||||
checkCudnnError(cudnnSetStream(cudnn, CUDAStream::getCurrentStream()));
|
||||
checkCublasError(
|
||||
cublasSetStream(cublas, CUDAStream::getCurrentStream()));
|
||||
checkCudaError(cudaStreamBeginCapture(CUDAStream::getCurrentStream(),
|
||||
cudaStreamCaptureModeGlobal));
|
||||
runWithoutSync(graph);
|
||||
checkCudaError(cudaStreamEndCapture(
|
||||
CUDAStream::p_CUDAStream->getCurrentStream(), &cudaGraph));
|
||||
checkCudaError(
|
||||
cudaStreamEndCapture(CUDAStream::getCurrentStream(), &cudaGraph));
|
||||
checkCudaError(
|
||||
cudaGraphInstantiate(&cudaGraphInstance, cudaGraph, NULL, NULL, 0));
|
||||
isCudaGraphCreated = true;
|
||||
} else {
|
||||
checkCudaError(cudaGraphLaunch(
|
||||
cudaGraphInstance, CUDAStream::p_CUDAStream->getCurrentStream()));
|
||||
checkCudaError(
|
||||
cudaGraphLaunch(cudaGraphInstance, CUDAStream::getCurrentStream()));
|
||||
}
|
||||
checkCudaError(
|
||||
cudaStreamSynchronize(CUDAStream::p_CUDAStream->getCurrentStream()));
|
||||
checkCudaError(cudaStreamSynchronize(CUDAStream::getCurrentStream()));
|
||||
}
|
||||
|
||||
void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const {
|
||||
|
@ -120,4 +122,5 @@ void CudaRuntimeObj::initComm(const string &name, int worldSize, int rank) {
|
|||
#endif
|
||||
}
|
||||
|
||||
cudaStream_t CUDAStream::_stream = 0;
|
||||
} // namespace infini
|
||||
|
|
|
@ -1,7 +0,0 @@
|
|||
#include "cuda/cuda_common.h"
|
||||
|
||||
namespace infini {
|
||||
std::unique_ptr<CUDAStream> CUDAStream::p_CUDAStream;
|
||||
CUDAStream::CUDAStream() {}
|
||||
|
||||
} // namespace infini
|
|
@ -17,7 +17,7 @@ namespace infini {
|
|||
|
||||
void cudaPrintFloat(float *x, int len) {
|
||||
cudaPrintFloatImpl
|
||||
<<<1, 1, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>(x, len);
|
||||
<<<1, 1, 0, CUDAStream::getCurrentStream()>>>(x, len);
|
||||
cudaDeviceSynchronize();
|
||||
}
|
||||
|
||||
|
|
|
@ -28,9 +28,8 @@ class AllReduceNCCL : public CudaKernelWithoutConfig {
|
|||
ncclComm_t comm =
|
||||
dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator())
|
||||
.getNcclComm();
|
||||
checkNcclError(
|
||||
ncclAllReduce(input, output, count, ncclType, getRedOp(), comm,
|
||||
CUDAStream::p_CUDAStream->getCurrentStream()));
|
||||
checkNcclError(ncclAllReduce(input, output, count, ncclType, getRedOp(),
|
||||
comm, CUDAStream::getCurrentStream()));
|
||||
}
|
||||
|
||||
virtual ncclRedOp_t getRedOp() const = 0;
|
||||
|
|
|
@ -158,13 +158,13 @@ void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
|
|||
dim3 blockDim(BLOCKSIZE, 1);
|
||||
|
||||
_attention_kvcache_kernel_128_1
|
||||
<<<gridDim, blockDim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridDim, blockDim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input_k_cache, input_v_cache, input_q, input_k, input_v, position_id,
|
||||
compMeta, output_O_temp, output_sum_temp);
|
||||
|
||||
_attention_kvcache_kernel_128_2
|
||||
<<<compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), WARP_SIZE,
|
||||
0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
0, CUDAStream::getCurrentStream()>>>
|
||||
(position_id, output_matmul, compMeta, output_O_temp, output_sum_temp);
|
||||
}
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ void clip_kernel(float *input, float *output, int num, float minValue,
|
|||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_clip_kernel
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>(
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
input, output, num, minValue, maxValue);
|
||||
}
|
||||
|
||||
|
|
|
@ -130,9 +130,9 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
|
|||
}
|
||||
}
|
||||
|
||||
#define CASE(OP, T) \
|
||||
_##OP##_kernel<DT_CUDA<T>::t> \
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>> \
|
||||
#define CASE(OP, T) \
|
||||
_##OP##_kernel<DT_CUDA<T>::t> \
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
|
||||
(a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
|
||||
|
||||
#define SWITCH_DTYPE(OP, DTYPE) \
|
||||
|
@ -204,11 +204,11 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
|
|||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
if (dType == 1) {
|
||||
_pow_kernel<float>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
|
||||
} else if (dType == 3) {
|
||||
_pow_kernel<int8_t>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
|
||||
} else if (dType == 10) {
|
||||
int a_size = a0 * a1 * a2 * a3;
|
||||
|
@ -224,7 +224,7 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
|
|||
b_float[i] = __half2float(((half *)b)[i]);
|
||||
}
|
||||
_pow_kernel<float>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(a_float.data(), b_float.data(), c_float.data(), a0, a1, a2, a3, b0,
|
||||
b1, b2, b3, c0, c1, c2, c3);
|
||||
for (int i = 0; i < c_size; ++i) {
|
||||
|
|
|
@ -43,7 +43,7 @@ namespace infini {
|
|||
|
||||
#define CASE(T) \
|
||||
_expandKernel<DT_CUDA<T>::t><<<gridsize, blocksize, \
|
||||
0, CUDAStream::p_CUDAStream->getCurrentStream()>>>( \
|
||||
0, CUDAStream::getCurrentStream()>>>( \
|
||||
input, output, nDims, outputsize, inputShape, outputShape);
|
||||
|
||||
#define SWITCH_DTYPE(DTYPE) \
|
||||
|
|
|
@ -20,7 +20,7 @@ void extend_kernel(float *in, float *out, int blockSize, int blockSizeOuter,
|
|||
int blocksize = 32 * 16;
|
||||
int gridsize = (oSize + blocksize - 1) / blocksize;
|
||||
_extend_kernel
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>(
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
in, out, blockSize, blockSizeOuter, oSize);
|
||||
}
|
||||
} // namespace infini
|
||||
|
|
|
@ -45,11 +45,11 @@ void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num) {
|
|||
int gridSize = (num + blockSize - 1) / blockSize;
|
||||
if (metaData.indexType == DataType::Int64) {
|
||||
_gather_kernel<T, int64_t>
|
||||
<<<gridSize, blockSize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(in, out, metaData, num);
|
||||
} else {
|
||||
_gather_kernel<T, int>
|
||||
<<<gridSize, blockSize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(in, out, metaData, num);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -41,25 +41,25 @@ void gather_elements_kernel(void *in, void *out, GatherMetaData metaData,
|
|||
if (metaData.dataType == DataType::Float32 &&
|
||||
metaData.indexType == DataType::Int64) {
|
||||
_gather_elements_kernel<float, int64_t>
|
||||
<<<gridSize, blockSize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>(
|
||||
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
reinterpret_cast<float *>(in), reinterpret_cast<float *>(out),
|
||||
metaData, num);
|
||||
} else if (metaData.dataType == DataType::Int32 &&
|
||||
metaData.indexType == DataType::Int64) {
|
||||
_gather_elements_kernel<int, int64_t>
|
||||
<<<gridSize, blockSize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>(
|
||||
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
reinterpret_cast<int *>(in), reinterpret_cast<int *>(out), metaData,
|
||||
num);
|
||||
} else if (metaData.dataType == DataType::Float32 &&
|
||||
metaData.indexType == DataType::Int32) {
|
||||
_gather_elements_kernel<float, int>
|
||||
<<<gridSize, blockSize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>(
|
||||
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
reinterpret_cast<float *>(in), reinterpret_cast<float *>(out),
|
||||
metaData, num);
|
||||
} else if (metaData.dataType == DataType::Int32 &&
|
||||
metaData.indexType == DataType::Int32) {
|
||||
_gather_elements_kernel<int, int>
|
||||
<<<gridSize, blockSize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>(
|
||||
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
reinterpret_cast<int *>(in), reinterpret_cast<int *>(out), metaData,
|
||||
num);
|
||||
} else {
|
||||
|
|
|
@ -344,7 +344,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
|
|||
int BLOCK_DIM = 1024;
|
||||
|
||||
blockLaynormKernel<float, 1024>
|
||||
<<<num_block, BLOCK_DIM, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<num_block, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, scale, dimsize, stride, output, eps, scaleSize, bias, biasSize);
|
||||
} else if (dimsize > 31) {
|
||||
int BLOCK_DIM_x = 32;
|
||||
|
@ -354,7 +354,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<float, 32, 32>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
|
||||
bias, biasSize);
|
||||
} else if (dimsize > 15) {
|
||||
|
@ -365,7 +365,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<float, 16, 64>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
|
||||
bias, biasSize);
|
||||
} else if (dimsize > 7) {
|
||||
|
@ -376,7 +376,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<float, 8, 128>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
|
||||
bias, biasSize);
|
||||
} else {
|
||||
|
@ -387,7 +387,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<float, 4, 256>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
|
||||
bias, biasSize);
|
||||
}
|
||||
|
@ -401,7 +401,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
|
|||
int BLOCK_DIM = 1024;
|
||||
|
||||
blockLaynormKernel<float, 1024>
|
||||
<<<num_block, BLOCK_DIM, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<num_block, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, scale, dimsize, stride, output, eps, scaleSize);
|
||||
} else if (dimsize > 31) {
|
||||
int BLOCK_DIM_x = 32;
|
||||
|
@ -411,7 +411,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<float, 32, 32>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
|
||||
} else if (dimsize > 15) {
|
||||
int BLOCK_DIM_x = 16;
|
||||
|
@ -421,7 +421,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<float, 16, 64>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
|
||||
} else if (dimsize > 7) {
|
||||
int BLOCK_DIM_x = 8;
|
||||
|
@ -431,7 +431,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<float, 8, 128>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
|
||||
} else {
|
||||
int BLOCK_DIM_x = 4;
|
||||
|
@ -441,7 +441,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<float, 4, 256>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
|
||||
}
|
||||
}
|
||||
|
@ -454,7 +454,7 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
|
|||
int BLOCK_DIM = 1024;
|
||||
|
||||
blockLaynormKernel<half, 1024>
|
||||
<<<num_block, BLOCK_DIM, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<num_block, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, scale, dimsize, stride, output, eps, scaleSize, bias, biasSize);
|
||||
} else if (dimsize > 31) {
|
||||
int BLOCK_DIM_x = 32;
|
||||
|
@ -464,7 +464,7 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<half, 32, 32>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
|
||||
bias, biasSize);
|
||||
} else if (dimsize > 15) {
|
||||
|
@ -475,7 +475,7 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<half, 16, 64>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
|
||||
bias, biasSize);
|
||||
} else if (dimsize > 7) {
|
||||
|
@ -486,7 +486,7 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<half, 8, 128>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
|
||||
bias, biasSize);
|
||||
} else {
|
||||
|
@ -497,7 +497,7 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<half, 4, 256>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
|
||||
bias, biasSize);
|
||||
}
|
||||
|
@ -511,7 +511,7 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
|
|||
int BLOCK_DIM = 1024;
|
||||
|
||||
blockLaynormKernel<half, 1024>
|
||||
<<<num_block, BLOCK_DIM, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<num_block, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, scale, dimsize, stride, output, eps, scaleSize);
|
||||
} else if (dimsize > 31) {
|
||||
int BLOCK_DIM_x = 32;
|
||||
|
@ -521,7 +521,7 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<half, 32, 32>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
|
||||
} else if (dimsize > 15) {
|
||||
int BLOCK_DIM_x = 16;
|
||||
|
@ -531,7 +531,7 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<half, 16, 64>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
|
||||
} else if (dimsize > 7) {
|
||||
int BLOCK_DIM_x = 8;
|
||||
|
@ -541,7 +541,7 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<half, 8, 128>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
|
||||
} else {
|
||||
int BLOCK_DIM_x = 4;
|
||||
|
@ -551,7 +551,7 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
warpLaynormKernel<half, 4, 256>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -49,7 +49,7 @@ __global__ void _pad_slice_kernel(void *part, void *whole,
|
|||
namespace infini {
|
||||
#define CASE(T) \
|
||||
_pad_slice_kernel<DT_CUDA<T>::t> \
|
||||
<<<gridSize, blockSize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>> \
|
||||
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>> \
|
||||
(partData, wholeData, metadata, nDims, num, isPad);
|
||||
|
||||
#define SWITCH_DTYPE(DTYPE) \
|
||||
|
|
|
@ -8,7 +8,7 @@ class CopyCuda : public CudaKernelWithoutConfig {
|
|||
auto outData = op->getOutputs()[0]->getRawDataPtr<void *>();
|
||||
cudaMemcpyAsync(outData, inData, op->getInputs(0)->getBytes(),
|
||||
cudaMemcpyDeviceToDevice,
|
||||
CUDAStream::p_CUDAStream->getCurrentStream());
|
||||
CUDAStream::getCurrentStream());
|
||||
}
|
||||
};
|
||||
// reshape/flatten/identity all act as copying from input to output.
|
||||
|
|
|
@ -214,7 +214,7 @@ void resize_kernel_nearest(float *in, float *out, const MetaData &metaData,
|
|||
IT_ASSERT(nearestMode <
|
||||
sizeof(p_nearest_mode_fun) / sizeof(p_nearest_mode_fun[0]));
|
||||
_resize_kernel_nearest
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(in, out, metaData, num, coordinateMode, nearestMode);
|
||||
}
|
||||
|
||||
|
@ -225,7 +225,7 @@ void resize_kernel_linear(float *in, float *out, const MetaData &metaData,
|
|||
IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) /
|
||||
sizeof(p_cooridnate_trans_mode_func[0]));
|
||||
_resize_kernel_linear_coeff
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(in, out, metaData, num, coordinateMode);
|
||||
}
|
||||
|
||||
|
@ -236,7 +236,7 @@ void resize_kernel_cubic(float *in, float *out, const MetaData &metaData,
|
|||
IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) /
|
||||
sizeof(p_cooridnate_trans_mode_func[0]));
|
||||
_resize_kernel_cubic_coeff
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(in, out, metaData, num, coordinateMode);
|
||||
}
|
||||
} // namespace infini
|
||||
|
|
|
@ -36,9 +36,9 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
|
|||
}
|
||||
|
||||
|
||||
#define CASE(T) \
|
||||
_rope_kernel<DT_CUDA<T>::t> \
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>> \
|
||||
#define CASE(T) \
|
||||
_rope_kernel<DT_CUDA<T>::t> \
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
|
||||
(pos, input, output, size, dim_model, dim_head, hidden_stride, pos_stride);
|
||||
|
||||
#define SWITCH_DTYPE(DTYPE) \
|
||||
|
|
|
@ -246,37 +246,37 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
|
|||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<float, 1024>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 1024 * 64) {
|
||||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<float, 1024, 128>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 1024 * 32) {
|
||||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<float, 1024, 64>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 1024 * 16) {
|
||||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<float, 1024, 32>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 1024 * 4) {
|
||||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<float, 1024, 16>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 1024) {
|
||||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<float, 1024, 4>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 31) {
|
||||
int BLOCK_DIM_x = 32;
|
||||
|
@ -286,7 +286,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
_warpSoftmaxKernel<float, 32, 32, 32>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 15) {
|
||||
int BLOCK_DIM_x = 16;
|
||||
|
@ -296,7 +296,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
_warpSoftmaxKernel<float, 16, 64, 2>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 7) {
|
||||
int BLOCK_DIM_x = 8;
|
||||
|
@ -306,7 +306,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
_warpSoftmaxKernel<float, 8, 128, 2>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, size, dimsize, stride);
|
||||
} else {
|
||||
int BLOCK_DIM_x = 4;
|
||||
|
@ -316,7 +316,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
_warpSoftmaxKernel<float, 4, 256, 2>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, size, dimsize, stride);
|
||||
}
|
||||
}
|
||||
|
@ -328,37 +328,37 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size,
|
|||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<half, 1024>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 1024 * 64) {
|
||||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<half, 1024, 128>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 1024 * 32) {
|
||||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<half, 1024, 64>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 1024 * 16) {
|
||||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<half, 1024, 32>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 1024 * 4) {
|
||||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<half, 1024, 16>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 1024) {
|
||||
|
||||
int BLOCK_DIM = 1024;
|
||||
_blockSoftmaxKernel<half, 1024, 4>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 31) {
|
||||
int BLOCK_DIM_x = 32;
|
||||
|
@ -368,7 +368,7 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
_warpSoftmaxKernel<half, 32, 32, 32>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 15) {
|
||||
int BLOCK_DIM_x = 16;
|
||||
|
@ -378,7 +378,7 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
_warpSoftmaxKernel<half, 16, 64, 2>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, size, dimsize, stride);
|
||||
} else if (dimsize > 7) {
|
||||
int BLOCK_DIM_x = 8;
|
||||
|
@ -388,7 +388,7 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
_warpSoftmaxKernel<half, 8, 128, 2>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, size, dimsize, stride);
|
||||
} else {
|
||||
int BLOCK_DIM_x = 4;
|
||||
|
@ -398,7 +398,7 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size,
|
|||
dim3 grid_dim(num_block_x, 1, 1);
|
||||
|
||||
_warpSoftmaxKernel<half, 4, 256, 2>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, size, dimsize, stride);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -68,10 +68,10 @@ class ConcatCuda : private CudaCompute, public CudaKernelWithoutConfig {
|
|||
_op->getInputs(1 - i)->getRawDataPtr<void *>();
|
||||
auto outData =
|
||||
_op->getOutputs()[0]->getRawDataPtr<void *>();
|
||||
cudaMemcpyAsync(
|
||||
outData, inData, _op->getInputs(1 - i)->getBytes(),
|
||||
cudaMemcpyDeviceToDevice,
|
||||
CUDAStream::p_CUDAStream->getCurrentStream());
|
||||
cudaMemcpyAsync(outData, inData,
|
||||
_op->getInputs(1 - i)->getBytes(),
|
||||
cudaMemcpyDeviceToDevice,
|
||||
CUDAStream::getCurrentStream());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -64,7 +64,7 @@ void split_concat_kernel(const ElementTensorMetadata<float> &eleMeta,
|
|||
dim3 gridSize(gridDimX, batchSize);
|
||||
|
||||
_split_concat_kernel
|
||||
<<<gridSize, blockSize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(eleMeta, compMeta, dim, nDims, isSplit);
|
||||
}
|
||||
void split_concat_kernel(const ElementTensorMetadata<half> &eleMeta,
|
||||
|
@ -79,7 +79,7 @@ void split_concat_kernel(const ElementTensorMetadata<half> &eleMeta,
|
|||
dim3 gridSize(gridDimX, batchSize);
|
||||
|
||||
_split_concat_kernel
|
||||
<<<gridSize, blockSize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(eleMeta, compMeta, dim, nDims, isSplit);
|
||||
}
|
||||
|
||||
|
|
|
@ -22,9 +22,9 @@ __global__ void _transpose_kernel(void *input, void *output, int nDims,
|
|||
((T *)output)[outputIdx] = ((T *)input)[inputIdx];
|
||||
}
|
||||
}
|
||||
#define CASE(T) \
|
||||
_transpose_kernel<DT_CUDA<T>::t> \
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>\
|
||||
#define CASE(T) \
|
||||
_transpose_kernel<DT_CUDA<T>::t> \
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
|
||||
(input, output, nDims, size, strides, outputShape);
|
||||
|
||||
#define SWITCH_DTYPE(DTYPE) \
|
||||
|
|
|
@ -149,10 +149,10 @@ template <typename T> void softmax_kernel(T *input, T *output, size_t num) {
|
|||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_softmax_kernel1<T>
|
||||
<<<1, 1, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<1, 1, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
_softmax_kernel2<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
}
|
||||
template <typename T> void relu_kernel(T *input, T *output, size_t num) {
|
||||
|
@ -160,7 +160,7 @@ template <typename T> void relu_kernel(T *input, T *output, size_t num) {
|
|||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_relu_kernel<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
}
|
||||
template <typename T> void sigmoid_kernel(T *input, T *output, size_t num) {
|
||||
|
@ -168,7 +168,7 @@ template <typename T> void sigmoid_kernel(T *input, T *output, size_t num) {
|
|||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_sigmoid_kernel<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
}
|
||||
template <typename T>
|
||||
|
@ -177,7 +177,7 @@ void hard_sigmoid_kernel(T *input, T *output, size_t num) {
|
|||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_hard_sigmoid_kernel<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
}
|
||||
template <typename T> void hard_swish_kernel(T *input, T *output, size_t num) {
|
||||
|
@ -185,7 +185,7 @@ template <typename T> void hard_swish_kernel(T *input, T *output, size_t num) {
|
|||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_hard_swish_kernel<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
}
|
||||
template <typename T> void tanh_kernel(T *input, T *output, size_t num) {
|
||||
|
@ -193,7 +193,7 @@ template <typename T> void tanh_kernel(T *input, T *output, size_t num) {
|
|||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_tanh_kernel<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
}
|
||||
template <typename T> void abs_kernel(T *input, T *output, size_t num) {
|
||||
|
@ -201,7 +201,7 @@ template <typename T> void abs_kernel(T *input, T *output, size_t num) {
|
|||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_abs_kernel<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
}
|
||||
template <typename T> void sqrt_kernel(T *input, T *output, size_t num) {
|
||||
|
@ -209,7 +209,7 @@ template <typename T> void sqrt_kernel(T *input, T *output, size_t num) {
|
|||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_sqrt_kernel
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
((T *)input, (T *)output, num);
|
||||
}
|
||||
|
||||
|
@ -218,7 +218,7 @@ template <typename T> void gelu_kernel(T *input, T *output, size_t num) {
|
|||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_gelu_kernel<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
}
|
||||
|
||||
|
@ -227,7 +227,7 @@ template <typename T> void silu_kernel(T *input, T *output, size_t num) {
|
|||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_silu_kernel<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
}
|
||||
|
||||
|
@ -236,7 +236,7 @@ template <typename T> void erf_kernel(T *input, T *output, size_t num) {
|
|||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_erf_kernel<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
}
|
||||
template <typename T> void neg_kernel(T *input, T *output, size_t num) {
|
||||
|
@ -244,7 +244,7 @@ template <typename T> void neg_kernel(T *input, T *output, size_t num) {
|
|||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_neg_kernel<T>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
}
|
||||
|
||||
|
@ -344,7 +344,7 @@ void cast_kernel(INPUT *input, OUTPUT *output, size_t num) {
|
|||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_cast_kernel<INPUT, OUTPUT>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
|
||||
(input, output, num);
|
||||
}
|
||||
|
||||
|
|
|
@ -62,7 +62,7 @@ void whereKernel(const float *inputX, const float *inputY,
|
|||
}
|
||||
int gridsize = (outputsize + blocksize - 1) / blocksize;
|
||||
_whereKernel<float>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>(
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
inputX, inputY, condition, output, nDims, outputsize, inputXShape,
|
||||
inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
|
||||
}
|
||||
|
@ -87,7 +87,7 @@ void whereKernel(const half *inputX, const half *inputY,
|
|||
}
|
||||
int gridsize = (outputsize + blocksize - 1) / blocksize;
|
||||
_whereKernel<half>
|
||||
<<<gridsize, blocksize, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>(
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
|
||||
inputX, inputY, condition, output, nDims, outputsize, inputXShape,
|
||||
inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue