diff --git a/include/cuda/cuda_common.h b/include/cuda/cuda_common.h
index e77632bf..a024d4f3 100644
--- a/include/cuda/cuda_common.h
+++ b/include/cuda/cuda_common.h
@@ -118,15 +118,14 @@ class CUDAStream {
     CUDAStream(CUDAStream &&) = delete;
     void operator=(const CUDAStream &) = delete;
    void operator=(CUDAStream &&) = delete;
-    cudaStream_t getCurrentStream() { return _stream; }
-    static std::unique_ptr<CUDAStream> p_CUDAStream;
-    static void init() { p_CUDAStream.reset(new CUDAStream); }
-    void createStream() { checkCudaError(cudaStreamCreate(&_stream)); }
-    void destroyStream() { checkCudaError(cudaStreamDestroy(_stream)); }
+    static cudaStream_t getCurrentStream() { return _stream; }
+    static void Init() { CUDAStream::_stream = 0; };
+    static void createStream() { checkCudaError(cudaStreamCreate(&_stream)); }
+    static void destroyStream() { checkCudaError(cudaStreamDestroy(_stream)); }
  private:
-    CUDAStream();
-    cudaStream_t _stream;
+    CUDAStream(){};
+    static cudaStream_t _stream;
 };
 } // namespace infini
diff --git a/include/cuda/cuda_runtime.h b/include/cuda/cuda_runtime.h
index 5941d476..7d48c019 100644
--- a/include/cuda/cuda_runtime.h
+++ b/include/cuda/cuda_runtime.h
@@ -30,19 +30,14 @@ class CudaRuntimeObj : public RuntimeObj {
         workspaceSize = 7ll << 30; // 7 GB
         workspace = alloc(workspaceSize);
         isCudaGraphCreated = false;
-        CUDAStream::init();
-        CUDAStream::p_CUDAStream->createStream();
-        checkCudnnError(cudnnSetStream(
-            cudnn, CUDAStream::p_CUDAStream->getCurrentStream()));
-        checkCublasError(cublasSetStream(
-            cublas, CUDAStream::p_CUDAStream->getCurrentStream()));
+        CUDAStream::Init();
     }
     virtual ~CudaRuntimeObj() {
         try {
             if (isCudaGraphCreated) {
                 checkCudaError(cudaGraphExecDestroy(cudaGraphInstance));
                 checkCudaError(cudaGraphDestroy(cudaGraph));
-                CUDAStream::p_CUDAStream->destroyStream();
+                CUDAStream::destroyStream();
             }
             dealloc(workspace);
             checkCudnnError(cudnnDestroy(cudnn));
diff --git a/src/cuda/cuda_runtime.cc b/src/cuda/cuda_runtime.cc
index 944e3c4f..9bab5018 100644
--- a/src/cuda/cuda_runtime.cc
+++ b/src/cuda/cuda_runtime.cc
@@ -40,21 +40,23 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
 void CudaRuntimeObj::runWithCudaGraph(const Graph &graph) {
     if (!isCudaGraphCreated) {
-        checkCudaError(
-            cudaStreamBeginCapture(CUDAStream::p_CUDAStream->getCurrentStream(),
-                                   cudaStreamCaptureModeGlobal));
+        CUDAStream::createStream();
+        checkCudnnError(cudnnSetStream(cudnn, CUDAStream::getCurrentStream()));
+        checkCublasError(
+            cublasSetStream(cublas, CUDAStream::getCurrentStream()));
+        checkCudaError(cudaStreamBeginCapture(CUDAStream::getCurrentStream(),
+                                              cudaStreamCaptureModeGlobal));
         runWithoutSync(graph);
-        checkCudaError(cudaStreamEndCapture(
-            CUDAStream::p_CUDAStream->getCurrentStream(), &cudaGraph));
+        checkCudaError(
+            cudaStreamEndCapture(CUDAStream::getCurrentStream(), &cudaGraph));
         checkCudaError(
             cudaGraphInstantiate(&cudaGraphInstance, cudaGraph, NULL, NULL, 0));
         isCudaGraphCreated = true;
     } else {
-        checkCudaError(cudaGraphLaunch(
-            cudaGraphInstance, CUDAStream::p_CUDAStream->getCurrentStream()));
+        checkCudaError(
+            cudaGraphLaunch(cudaGraphInstance, CUDAStream::getCurrentStream()));
     }
-    checkCudaError(
-        cudaStreamSynchronize(CUDAStream::p_CUDAStream->getCurrentStream()));
+    checkCudaError(cudaStreamSynchronize(CUDAStream::getCurrentStream()));
 }
 void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const {
@@ -120,4 +122,5 @@ void CudaRuntimeObj::initComm(const string &name, int worldSize, int rank) {
 #endif
 }
+cudaStream_t CUDAStream::_stream = 0;
 } // namespace infini
diff --git a/src/cuda/cuda_stream.cc b/src/cuda/cuda_stream.cc
deleted file mode 100644
index 79e22838..00000000
--- a/src/cuda/cuda_stream.cc
+++ /dev/null
@@ -1,7 +0,0 @@
-#include "cuda/cuda_common.h"
-
-namespace infini {
-std::unique_ptr<CUDAStream> CUDAStream::p_CUDAStream;
-CUDAStream::CUDAStream() {}
-
-} // namespace infini
diff --git a/src/cuda/cuda_utility.cu b/src/cuda/cuda_utility.cu
index f67f970d..83cee26c 100644
--- a/src/cuda/cuda_utility.cu
+++ b/src/cuda/cuda_utility.cu
@@ -17,7 +17,7 @@ namespace infini {
 void cudaPrintFloat(float *x, int len) {
     cudaPrintFloatImpl
-        <<<1, 1, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>(x, len);
+        <<<1, 1, 0, CUDAStream::getCurrentStream()>>>(x, len);
     cudaDeviceSynchronize();
 }
diff --git a/src/kernels/cuda/all_reduce.cc b/src/kernels/cuda/all_reduce.cc
index 194462d4..e77c98b6 100644
--- a/src/kernels/cuda/all_reduce.cc
+++ b/src/kernels/cuda/all_reduce.cc
@@ -28,9 +28,8 @@ class AllReduceNCCL : public CudaKernelWithoutConfig {
         ncclComm_t comm =
             dynamic_cast(context->getCommunicator())
                 .getNcclComm();
-        checkNcclError(
-            ncclAllReduce(input, output, count, ncclType, getRedOp(), comm,
-                          CUDAStream::p_CUDAStream->getCurrentStream()));
+        checkNcclError(ncclAllReduce(input, output, count, ncclType, getRedOp(),
+                                     comm, CUDAStream::getCurrentStream()));
     }
     virtual ncclRedOp_t getRedOp() const = 0;
diff --git a/src/kernels/cuda/attention_kvcache.cu b/src/kernels/cuda/attention_kvcache.cu
index 39dd4393..476220db 100644
--- a/src/kernels/cuda/attention_kvcache.cu
+++ b/src/kernels/cuda/attention_kvcache.cu
@@ -158,13 +158,13 @@ void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
     dim3 blockDim(BLOCKSIZE, 1);
     _attention_kvcache_kernel_128_1
-        <<getCurrentStream()>>>
+        <<>>
         (input_k_cache, input_v_cache, input_q, input_k, input_v, position_id,
          compMeta, output_O_temp, output_sum_temp);
     _attention_kvcache_kernel_128_2
         <<getCurrentStream()>>>
+           0, CUDAStream::getCurrentStream()>>>
         (position_id, output_matmul, compMeta, output_O_temp, output_sum_temp);
 }
diff --git a/src/kernels/cuda/clip.cu b/src/kernels/cuda/clip.cu
index 9fbe80c4..85b096b0 100644
--- a/src/kernels/cuda/clip.cu
+++ b/src/kernels/cuda/clip.cu
@@ -26,7 +26,7 @@ void clip_kernel(float *input, float *output, int num, float minValue,
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _clip_kernel
-        <<getCurrentStream()>>>(
+        <<>>(
             input, output, num, minValue, maxValue);
 }
diff --git a/src/kernels/cuda/element_wise.cu b/src/kernels/cuda/element_wise.cu
index e4c75e0c..e1b68699 100644
--- a/src/kernels/cuda/element_wise.cu
+++ b/src/kernels/cuda/element_wise.cu
@@ -130,9 +130,9 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
     }
 }
-#define CASE(OP, T) \
-    _##OP##_kernel::t> \
-        <<getCurrentStream()>>> \
+#define CASE(OP, T) \
+    _##OP##_kernel::t> \
+        <<>> \
         (a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
 #define SWITCH_DTYPE(OP, DTYPE) \
@@ -204,11 +204,11 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     if (dType == 1) {
         _pow_kernel
-            <<getCurrentStream()>>>
+            <<>>
             (a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
     } else if (dType == 3) {
         _pow_kernel
-            <<getCurrentStream()>>>
+            <<>>
             (a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
     } else if (dType == 10) {
         int a_size = a0 * a1 * a2 * a3;
@@ -224,7 +224,7 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
             b_float[i] = __half2float(((half *)b)[i]);
         }
         _pow_kernel
-            <<getCurrentStream()>>>
+            <<>>
             (a_float.data(), b_float.data(), c_float.data(), a0, a1, a2, a3,
              b0, b1, b2, b3, c0, c1, c2, c3);
         for (int i = 0; i < c_size; ++i) {
diff --git a/src/kernels/cuda/expand.cu b/src/kernels/cuda/expand.cu
index a4143a8f..5e22be44 100644
--- a/src/kernels/cuda/expand.cu
+++ b/src/kernels/cuda/expand.cu
@@ -43,7 +43,7 @@ namespace infini {
 #define CASE(T) \
     _expandKernel::t><<getCurrentStream()>>>( \
+        0, CUDAStream::getCurrentStream()>>>( \
         input, output, nDims, outputsize, inputShape, outputShape);
 #define SWITCH_DTYPE(DTYPE) \
diff --git a/src/kernels/cuda/extend.cu b/src/kernels/cuda/extend.cu
index 8e90e021..3fce9922 100644
--- a/src/kernels/cuda/extend.cu
+++ b/src/kernels/cuda/extend.cu
@@ -20,7 +20,7 @@ void extend_kernel(float *in, float *out, int blockSize, int blockSizeOuter,
     int blocksize = 32 * 16;
     int gridsize = (oSize + blocksize - 1) / blocksize;
     _extend_kernel
-        <<getCurrentStream()>>>(
+        <<>>(
             in, out, blockSize, blockSizeOuter, oSize);
 }
 } // namespace infini
diff --git a/src/kernels/cuda/gather.cu b/src/kernels/cuda/gather.cu
index 207d3e7c..7b2d9dbf 100644
--- a/src/kernels/cuda/gather.cu
+++ b/src/kernels/cuda/gather.cu
@@ -45,11 +45,11 @@ void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num) {
     int gridSize = (num + blockSize - 1) / blockSize;
     if (metaData.indexType == DataType::Int64) {
         _gather_kernel
-            <<getCurrentStream()>>>
+            <<>>
             (in, out, metaData, num);
     } else {
         _gather_kernel
-            <<getCurrentStream()>>>
+            <<>>
             (in, out, metaData, num);
     }
 }
diff --git a/src/kernels/cuda/gather_elements.cu b/src/kernels/cuda/gather_elements.cu
index 506ada30..545820b4 100644
--- a/src/kernels/cuda/gather_elements.cu
+++ b/src/kernels/cuda/gather_elements.cu
@@ -41,25 +41,25 @@ void gather_elements_kernel(void *in, void *out, GatherMetaData metaData,
     if (metaData.dataType == DataType::Float32 &&
         metaData.indexType == DataType::Int64) {
         _gather_elements_kernel
-            <<getCurrentStream()>>>(
+            <<>>(
             reinterpret_cast(in), reinterpret_cast(out), metaData, num);
     } else if (metaData.dataType == DataType::Int32 &&
                metaData.indexType == DataType::Int64) {
         _gather_elements_kernel
-            <<getCurrentStream()>>>(
+            <<>>(
             reinterpret_cast(in), reinterpret_cast(out), metaData, num);
     } else if (metaData.dataType == DataType::Float32 &&
                metaData.indexType == DataType::Int32) {
         _gather_elements_kernel
-            <<getCurrentStream()>>>(
+            <<>>(
             reinterpret_cast(in), reinterpret_cast(out), metaData, num);
     } else if (metaData.dataType == DataType::Int32 &&
                metaData.indexType == DataType::Int32) {
         _gather_elements_kernel
-            <<getCurrentStream()>>>(
+            <<>>(
             reinterpret_cast(in), reinterpret_cast(out), metaData, num);
     } else {
diff --git a/src/kernels/cuda/layer_norm.cu b/src/kernels/cuda/layer_norm.cu
index 841757ea..b3d74c77 100644
--- a/src/kernels/cuda/layer_norm.cu
+++ b/src/kernels/cuda/layer_norm.cu
@@ -344,7 +344,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         int BLOCK_DIM = 1024;
         blockLaynormKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, scale, dimsize, stride, output, eps, scaleSize, bias, biasSize);
     } else if (dimsize > 31) {
         int BLOCK_DIM_x = 32;
@@ -354,7 +354,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 grid_dim(num_block_x, 1, 1);
         warpLaynormKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, scale, dimsize, stride, output, eps, scaleSize, num_block, bias,
              biasSize);
     } else if (dimsize > 15) {
@@ -365,7 +365,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 grid_dim(num_block_x, 1, 1);
         warpLaynormKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, scale, dimsize, stride, output, eps, scaleSize, num_block, bias,
              biasSize);
     } else if (dimsize > 7) {
@@ -376,7 +376,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 grid_dim(num_block_x, 1, 1);
         warpLaynormKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, scale, dimsize, stride, output, eps, scaleSize, num_block, bias,
              biasSize);
     } else {
@@ -387,7 +387,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 grid_dim(num_block_x, 1, 1);
         warpLaynormKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, scale, dimsize, stride, output, eps, scaleSize, num_block, bias,
              biasSize);
     }
@@ -401,7 +401,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         int BLOCK_DIM = 1024;
         blockLaynormKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, scale, dimsize, stride, output, eps, scaleSize);
     } else if (dimsize > 31) {
         int BLOCK_DIM_x = 32;
@@ -411,7 +411,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 grid_dim(num_block_x, 1, 1);
         warpLaynormKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, scale, dimsize, stride, output, eps, scaleSize, num_block);
     } else if (dimsize > 15) {
         int BLOCK_DIM_x = 16;
@@ -421,7 +421,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 grid_dim(num_block_x, 1, 1);
         warpLaynormKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, scale, dimsize, stride, output, eps, scaleSize, num_block);
     } else if (dimsize > 7) {
         int BLOCK_DIM_x = 8;
@@ -431,7 +431,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 grid_dim(num_block_x, 1, 1);
         warpLaynormKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, scale, dimsize, stride, output, eps, scaleSize, num_block);
     } else {
         int BLOCK_DIM_x = 4;
@@ -441,7 +441,7 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
         dim3 grid_dim(num_block_x, 1, 1);
         warpLaynormKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, scale, dimsize, stride, output, eps, scaleSize, num_block);
     }
 }
@@ -454,7 +454,7 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
         int BLOCK_DIM = 1024;
         blockLaynormKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, scale, dimsize, stride, output, eps, scaleSize, bias, biasSize);
     } else if (dimsize > 31) {
         int BLOCK_DIM_x = 32;
@@ -464,7 +464,7 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
         dim3 grid_dim(num_block_x, 1, 1);
         warpLaynormKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, scale, dimsize, stride, output, eps, scaleSize, num_block, bias,
              biasSize);
     } else if (dimsize > 15) {
@@ -475,7 +475,7 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
         dim3 grid_dim(num_block_x, 1, 1);
         warpLaynormKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, scale, dimsize, stride, output, eps, scaleSize, num_block, bias,
              biasSize);
     } else if (dimsize > 7) {
@@ -486,7 +486,7 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
         dim3 grid_dim(num_block_x, 1, 1);
         warpLaynormKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, scale, dimsize, stride, output, eps, scaleSize, num_block, bias,
              biasSize);
     } else {
@@ -497,7 +497,7 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
         dim3 grid_dim(num_block_x, 1, 1);
         warpLaynormKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, scale, dimsize, stride, output, eps, scaleSize, num_block, bias,
              biasSize);
     }
@@ -511,7 +511,7 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
         int BLOCK_DIM = 1024;
         blockLaynormKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, scale, dimsize, stride, output, eps, scaleSize);
     } else if (dimsize > 31) {
         int BLOCK_DIM_x = 32;
@@ -521,7 +521,7 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
         dim3 grid_dim(num_block_x, 1, 1);
         warpLaynormKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, scale, dimsize, stride, output, eps, scaleSize, num_block);
     } else if (dimsize > 15) {
         int BLOCK_DIM_x = 16;
@@ -531,7 +531,7 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
         dim3 grid_dim(num_block_x, 1, 1);
         warpLaynormKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, scale, dimsize, stride, output, eps, scaleSize, num_block);
     } else if (dimsize > 7) {
         int BLOCK_DIM_x = 8;
@@ -541,7 +541,7 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
         dim3 grid_dim(num_block_x, 1, 1);
         warpLaynormKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, scale, dimsize, stride, output, eps, scaleSize, num_block);
     } else {
         int BLOCK_DIM_x = 4;
@@ -551,7 +551,7 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
         dim3 grid_dim(num_block_x, 1, 1);
         warpLaynormKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, scale, dimsize, stride, output, eps, scaleSize, num_block);
     }
 }
diff --git a/src/kernels/cuda/pad_slice.cu b/src/kernels/cuda/pad_slice.cu
index 65c470dc..331f8e0d 100644
--- a/src/kernels/cuda/pad_slice.cu
+++ b/src/kernels/cuda/pad_slice.cu
@@ -49,7 +49,7 @@ __global__ void _pad_slice_kernel(void *part, void *whole,
 namespace infini {
 #define CASE(T) \
     _pad_slice_kernel::t> \
-        <<getCurrentStream()>>> \
+        <<>> \
         (partData, wholeData, metadata, nDims, num, isPad);
 #define SWITCH_DTYPE(DTYPE) \
diff --git a/src/kernels/cuda/reshape.cc b/src/kernels/cuda/reshape.cc
index a42e86df..bbce222c 100644
--- a/src/kernels/cuda/reshape.cc
+++ b/src/kernels/cuda/reshape.cc
@@ -8,7 +8,7 @@ class CopyCuda : public CudaKernelWithoutConfig {
         auto outData = op->getOutputs()[0]->getRawDataPtr();
         cudaMemcpyAsync(outData, inData, op->getInputs(0)->getBytes(),
                         cudaMemcpyDeviceToDevice,
-                        CUDAStream::p_CUDAStream->getCurrentStream());
+                        CUDAStream::getCurrentStream());
     }
 };
 // reshape/flatten/identity all act as copying from input to output.
diff --git a/src/kernels/cuda/resize.cu b/src/kernels/cuda/resize.cu
index b331578c..947ee7ce 100644
--- a/src/kernels/cuda/resize.cu
+++ b/src/kernels/cuda/resize.cu
@@ -214,7 +214,7 @@ void resize_kernel_nearest(float *in, float *out, const MetaData &metaData,
     IT_ASSERT(nearestMode <
               sizeof(p_nearest_mode_fun) / sizeof(p_nearest_mode_fun[0]));
     _resize_kernel_nearest
-        <<getCurrentStream()>>>
+        <<>>
         (in, out, metaData, num, coordinateMode, nearestMode);
 }
@@ -225,7 +225,7 @@ void resize_kernel_linear(float *in, float *out, const MetaData &metaData,
     IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) /
                                    sizeof(p_cooridnate_trans_mode_func[0]));
     _resize_kernel_linear_coeff
-        <<getCurrentStream()>>>
+        <<>>
         (in, out, metaData, num, coordinateMode);
 }
@@ -236,7 +236,7 @@ void resize_kernel_cubic(float *in, float *out, const MetaData &metaData,
     IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) /
                                    sizeof(p_cooridnate_trans_mode_func[0]));
     _resize_kernel_cubic_coeff
-        <<getCurrentStream()>>>
+        <<>>
         (in, out, metaData, num, coordinateMode);
 }
 } // namespace infini
diff --git a/src/kernels/cuda/rope.cu b/src/kernels/cuda/rope.cu
index a4701c77..8d35026f 100644
--- a/src/kernels/cuda/rope.cu
+++ b/src/kernels/cuda/rope.cu
@@ -36,9 +36,9 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
 }
-#define CASE(T) \
-    _rope_kernel::t> \
-        <<getCurrentStream()>>> \
+#define CASE(T) \
+    _rope_kernel::t> \
+        <<>> \
         (pos, input, output, size, dim_model, dim_head, hidden_stride, pos_stride);
 #define SWITCH_DTYPE(DTYPE) \
diff --git a/src/kernels/cuda/softmax.cu b/src/kernels/cuda/softmax.cu
index 8d24d393..2fe2d8a5 100644
--- a/src/kernels/cuda/softmax.cu
+++ b/src/kernels/cuda/softmax.cu
@@ -246,37 +246,37 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
         int BLOCK_DIM = 1024;
         _blockSoftmaxKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, output, size, dimsize, stride);
     } else if (dimsize > 1024 * 64) {
         int BLOCK_DIM = 1024;
         _blockSoftmaxKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, output, size, dimsize, stride);
     } else if (dimsize > 1024 * 32) {
         int BLOCK_DIM = 1024;
         _blockSoftmaxKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, output, size, dimsize, stride);
     } else if (dimsize > 1024 * 16) {
         int BLOCK_DIM = 1024;
         _blockSoftmaxKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, output, size, dimsize, stride);
     } else if (dimsize > 1024 * 4) {
         int BLOCK_DIM = 1024;
         _blockSoftmaxKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, output, size, dimsize, stride);
     } else if (dimsize > 1024) {
         int BLOCK_DIM = 1024;
         _blockSoftmaxKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, output, size, dimsize, stride);
     } else if (dimsize > 31) {
         int BLOCK_DIM_x = 32;
@@ -286,7 +286,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
         dim3 grid_dim(num_block_x, 1, 1);
         _warpSoftmaxKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, output, size, dimsize, stride);
     } else if (dimsize > 15) {
         int BLOCK_DIM_x = 16;
@@ -296,7 +296,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
         dim3 grid_dim(num_block_x, 1, 1);
         _warpSoftmaxKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, output, size, dimsize, stride);
     } else if (dimsize > 7) {
         int BLOCK_DIM_x = 8;
@@ -306,7 +306,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
         dim3 grid_dim(num_block_x, 1, 1);
         _warpSoftmaxKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, output, size, dimsize, stride);
     } else {
         int BLOCK_DIM_x = 4;
@@ -316,7 +316,7 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
         dim3 grid_dim(num_block_x, 1, 1);
         _warpSoftmaxKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, output, size, dimsize, stride);
     }
 }
@@ -328,37 +328,37 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size,
         int BLOCK_DIM = 1024;
         _blockSoftmaxKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, output, size, dimsize, stride);
     } else if (dimsize > 1024 * 64) {
         int BLOCK_DIM = 1024;
         _blockSoftmaxKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, output, size, dimsize, stride);
     } else if (dimsize > 1024 * 32) {
         int BLOCK_DIM = 1024;
         _blockSoftmaxKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, output, size, dimsize, stride);
     } else if (dimsize > 1024 * 16) {
         int BLOCK_DIM = 1024;
         _blockSoftmaxKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, output, size, dimsize, stride);
     } else if (dimsize > 1024 * 4) {
         int BLOCK_DIM = 1024;
         _blockSoftmaxKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, output, size, dimsize, stride);
     } else if (dimsize > 1024) {
         int BLOCK_DIM = 1024;
         _blockSoftmaxKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, output, size, dimsize, stride);
     } else if (dimsize > 31) {
         int BLOCK_DIM_x = 32;
@@ -368,7 +368,7 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size,
         dim3 grid_dim(num_block_x, 1, 1);
         _warpSoftmaxKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, output, size, dimsize, stride);
     } else if (dimsize > 15) {
         int BLOCK_DIM_x = 16;
@@ -378,7 +378,7 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size,
         dim3 grid_dim(num_block_x, 1, 1);
         _warpSoftmaxKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, output, size, dimsize, stride);
     } else if (dimsize > 7) {
         int BLOCK_DIM_x = 8;
@@ -388,7 +388,7 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size,
         dim3 grid_dim(num_block_x, 1, 1);
         _warpSoftmaxKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, output, size, dimsize, stride);
     } else {
         int BLOCK_DIM_x = 4;
@@ -398,7 +398,7 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size,
         dim3 grid_dim(num_block_x, 1, 1);
         _warpSoftmaxKernel
-            <<getCurrentStream()>>>
+            <<>>
             (input, output, size, dimsize, stride);
     }
 }
diff --git a/src/kernels/cuda/split_concat.cc b/src/kernels/cuda/split_concat.cc
index 94d946a7..df9dadfe 100644
--- a/src/kernels/cuda/split_concat.cc
+++ b/src/kernels/cuda/split_concat.cc
@@ -68,10 +68,10 @@ class ConcatCuda : private CudaCompute, public CudaKernelWithoutConfig {
                     _op->getInputs(1 - i)->getRawDataPtr();
                 auto outData =
                     _op->getOutputs()[0]->getRawDataPtr();
-                cudaMemcpyAsync(
-                    outData, inData, _op->getInputs(1 - i)->getBytes(),
-                    cudaMemcpyDeviceToDevice,
-                    CUDAStream::p_CUDAStream->getCurrentStream());
+                cudaMemcpyAsync(outData, inData,
+                                _op->getInputs(1 - i)->getBytes(),
+                                cudaMemcpyDeviceToDevice,
+                                CUDAStream::getCurrentStream());
                 return;
             }
         }
diff --git a/src/kernels/cuda/split_concat.cu b/src/kernels/cuda/split_concat.cu
index 0cbeb474..f14dc973 100644
--- a/src/kernels/cuda/split_concat.cu
+++ b/src/kernels/cuda/split_concat.cu
@@ -64,7 +64,7 @@ void split_concat_kernel(const ElementTensorMetadata &eleMeta,
     dim3 gridSize(gridDimX, batchSize);
     _split_concat_kernel
-        <<getCurrentStream()>>>
+        <<>>
         (eleMeta, compMeta, dim, nDims, isSplit);
 }
 void split_concat_kernel(const ElementTensorMetadata &eleMeta,
@@ -79,7 +79,7 @@ void split_concat_kernel(const ElementTensorMetadata &eleMeta,
     dim3 gridSize(gridDimX, batchSize);
     _split_concat_kernel
-        <<getCurrentStream()>>>
+        <<>>
         (eleMeta, compMeta, dim, nDims, isSplit);
 }
diff --git a/src/kernels/cuda/transpose.cu b/src/kernels/cuda/transpose.cu
index 16c4d66c..833c1154 100644
--- a/src/kernels/cuda/transpose.cu
+++ b/src/kernels/cuda/transpose.cu
@@ -22,9 +22,9 @@ __global__ void _transpose_kernel(void *input, void *output, int nDims,
         ((T *)output)[outputIdx] = ((T *)input)[inputIdx];
     }
 }
-#define CASE(T) \
-    _transpose_kernel::t> \
-        <<getCurrentStream()>>>\
+#define CASE(T) \
+    _transpose_kernel::t> \
+        <<>> \
         (input, output, nDims, size, strides, outputShape);
 #define SWITCH_DTYPE(DTYPE) \
diff --git a/src/kernels/cuda/unary.cu b/src/kernels/cuda/unary.cu
index 0234a822..93a3cf6c 100644
--- a/src/kernels/cuda/unary.cu
+++ b/src/kernels/cuda/unary.cu
@@ -149,10 +149,10 @@ template void softmax_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _softmax_kernel1
-        <<<1, 1, 0, CUDAStream::p_CUDAStream->getCurrentStream()>>>
+        <<<1, 1, 0, CUDAStream::getCurrentStream()>>>
         (input, output, num);
     _softmax_kernel2
-        <<getCurrentStream()>>>
+        <<>>
         (input, output, num);
 }
 template void relu_kernel(T *input, T *output, size_t num) {
@@ -160,7 +160,7 @@ template void relu_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _relu_kernel
-        <<getCurrentStream()>>>
+        <<>>
         (input, output, num);
 }
 template void sigmoid_kernel(T *input, T *output, size_t num) {
@@ -168,7 +168,7 @@ template void sigmoid_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _sigmoid_kernel
-        <<getCurrentStream()>>>
+        <<>>
         (input, output, num);
 }
 template
@@ -177,7 +177,7 @@ void hard_sigmoid_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _hard_sigmoid_kernel
-        <<getCurrentStream()>>>
+        <<>>
         (input, output, num);
 }
 template void hard_swish_kernel(T *input, T *output, size_t num) {
@@ -185,7 +185,7 @@ template void hard_swish_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _hard_swish_kernel
-        <<getCurrentStream()>>>
+        <<>>
         (input, output, num);
 }
 template void tanh_kernel(T *input, T *output, size_t num) {
@@ -193,7 +193,7 @@ template void tanh_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _tanh_kernel
-        <<getCurrentStream()>>>
+        <<>>
         (input, output, num);
 }
 template void abs_kernel(T *input, T *output, size_t num) {
@@ -201,7 +201,7 @@ template void abs_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _abs_kernel
-        <<getCurrentStream()>>>
+        <<>>
         (input, output, num);
 }
 template void sqrt_kernel(T *input, T *output, size_t num) {
@@ -209,7 +209,7 @@ template void sqrt_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _sqrt_kernel
-        <<getCurrentStream()>>>
+        <<>>
         ((T *)input, (T *)output, num);
 }
@@ -218,7 +218,7 @@ template void gelu_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _gelu_kernel
-        <<getCurrentStream()>>>
+        <<>>
         (input, output, num);
 }
@@ -227,7 +227,7 @@ template void silu_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _silu_kernel
-        <<getCurrentStream()>>>
+        <<>>
         (input, output, num);
 }
@@ -236,7 +236,7 @@ template void erf_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _erf_kernel
-        <<getCurrentStream()>>>
+        <<>>
         (input, output, num);
 }
 template void neg_kernel(T *input, T *output, size_t num) {
@@ -244,7 +244,7 @@ template void neg_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _neg_kernel
-        <<getCurrentStream()>>>
+        <<>>
         (input, output, num);
 }
@@ -344,7 +344,7 @@ void cast_kernel(INPUT *input, OUTPUT *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _cast_kernel
-        <<getCurrentStream()>>>
+        <<>>
         (input, output, num);
 }
diff --git a/src/kernels/cuda/where.cu b/src/kernels/cuda/where.cu
index 12af4aaf..f3c2459f 100644
--- a/src/kernels/cuda/where.cu
+++ b/src/kernels/cuda/where.cu
@@ -62,7 +62,7 @@ void whereKernel(const float *inputX, const float *inputY,
     }
     int gridsize = (outputsize + blocksize - 1) / blocksize;
     _whereKernel
-        <<getCurrentStream()>>>(
+        <<>>(
         inputX, inputY, condition, output, nDims, outputsize, inputXShape,
         inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
 }
@@ -87,7 +87,7 @@ void whereKernel(const half *inputX, const half *inputY,
     }
     int gridsize = (outputsize + blocksize - 1) / blocksize;
     _whereKernel
-        <<getCurrentStream()>>>(
+        <<>>(
         inputX, inputY, condition, output, nDims, outputsize, inputXShape,
         inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
 }
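Note for reviewers: below is a minimal usage sketch of the static CUDAStream interface this patch introduces, assembled from the calls the diff adds to CudaRuntimeObj::runWithCudaGraph. The captureGraphOnce wrapper, its handle parameters, and the run callback are illustrative names for this note only, not part of the patch.

    // Sketch only: mirrors the capture/replay sequence in cuda_runtime.cc.
    // CUDAStream, checkCudaError, checkCudnnError and checkCublasError come
    // from cuda/cuda_common.h; everything else is standard CUDA/cuDNN/cuBLAS.
    #include "cuda/cuda_common.h"
    #include <cublas_v2.h>
    #include <cudnn.h>

    template <typename RunFn>
    void captureGraphOnce(cudnnHandle_t cudnn, cublasHandle_t cublas,
                          cudaGraph_t &graph, cudaGraphExec_t &instance,
                          RunFn &&run) {
        // Replace the default stream (set by CUDAStream::Init()) with a real
        // stream and point the library handles at it before capture begins.
        CUDAStream::createStream();
        checkCudnnError(cudnnSetStream(cudnn, CUDAStream::getCurrentStream()));
        checkCublasError(cublasSetStream(cublas, CUDAStream::getCurrentStream()));

        // Record every kernel, memcpy and library call issued by run() into a graph.
        checkCudaError(cudaStreamBeginCapture(CUDAStream::getCurrentStream(),
                                              cudaStreamCaptureModeGlobal));
        run(); // all work must target CUDAStream::getCurrentStream()
        checkCudaError(cudaStreamEndCapture(CUDAStream::getCurrentStream(), &graph));
        checkCudaError(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));

        // Subsequent iterations replay the instantiated graph on the same stream.
        checkCudaError(cudaGraphLaunch(instance, CUDAStream::getCurrentStream()));
        checkCudaError(cudaStreamSynchronize(CUDAStream::getCurrentStream()));
    }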