[feature] add cudagraph support (#215)

* [feature] add cudagraph support

* modify code to pass the cuda_all_reduce test
This commit is contained in:
xiaonans 2024-02-21 14:00:25 +08:00 committed by GitHub
parent 900d8e58e3
commit 1c08ba200c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 362 additions and 129 deletions

View File

@ -5,6 +5,10 @@
#include <cstdint> #include <cstdint>
#include <iostream> #include <iostream>
#ifdef USE_CUDA
#include "cuda/cuda_runtime.h"
#endif
namespace infini { namespace infini {
class GraphHandlerObj { class GraphHandlerObj {
@ -137,6 +141,12 @@ class GraphHandlerObj {
inline void run() { g->getRuntime()->run(g); } inline void run() { g->getRuntime()->run(g); }
inline double get_perf_time() { return g->getRuntime()->getPerfTime(g); } inline double get_perf_time() { return g->getRuntime()->getPerfTime(g); }
#ifdef USE_CUDA
inline void run_with_cudagraph() {
(as<CudaRuntimeObj>(g->getRuntime()))->runWithCudaGraph(g);
}
#endif
}; };
} // namespace infini } // namespace infini

View File

@ -5,6 +5,7 @@
#include <cuda_profiler_api.h> #include <cuda_profiler_api.h>
#include <cudnn.h> #include <cudnn.h>
#include <curand.h> #include <curand.h>
#include <memory>
#define checkCudaError(call) \ #define checkCudaError(call) \
if (auto err = call; err != cudaSuccess) \ if (auto err = call; err != cudaSuccess) \
@ -111,4 +112,20 @@ inline const char *curandGetErrorString(curandStatus_t error) {
using CudaPtr = void *; using CudaPtr = void *;
class CUDAStream {
public:
CUDAStream(const CUDAStream &) = delete;
CUDAStream(CUDAStream &&) = delete;
void operator=(const CUDAStream &) = delete;
void operator=(CUDAStream &&) = delete;
static cudaStream_t getCurrentStream() { return _stream; }
static void Init() { CUDAStream::_stream = 0; };
static void createStream() { checkCudaError(cudaStreamCreate(&_stream)); }
static void destroyStream() { checkCudaError(cudaStreamDestroy(_stream)); }
private:
CUDAStream(){};
static cudaStream_t _stream;
};
} // namespace infini } // namespace infini

View File

@ -14,6 +14,9 @@ class CudaRuntimeObj : public RuntimeObj {
std::unique_ptr<CommunicatorObj> comm; std::unique_ptr<CommunicatorObj> comm;
CudaPtr workspace; CudaPtr workspace;
size_t workspaceSize; size_t workspaceSize;
bool isCudaGraphCreated;
cudaGraph_t cudaGraph;
cudaGraphExec_t cudaGraphInstance;
public: public:
explicit CudaRuntimeObj(int deviceId = 0) explicit CudaRuntimeObj(int deviceId = 0)
@ -26,9 +29,16 @@ class CudaRuntimeObj : public RuntimeObj {
// size_t longformerNum = 3lu * (1 << 30); // size_t longformerNum = 3lu * (1 << 30);
workspaceSize = 7ll << 30; // 7 GB workspaceSize = 7ll << 30; // 7 GB
workspace = alloc(workspaceSize); workspace = alloc(workspaceSize);
isCudaGraphCreated = false;
CUDAStream::Init();
} }
virtual ~CudaRuntimeObj() { virtual ~CudaRuntimeObj() {
try { try {
if (isCudaGraphCreated) {
checkCudaError(cudaGraphExecDestroy(cudaGraphInstance));
checkCudaError(cudaGraphDestroy(cudaGraph));
CUDAStream::destroyStream();
}
dealloc(workspace); dealloc(workspace);
checkCudnnError(cudnnDestroy(cudnn)); checkCudnnError(cudnnDestroy(cudnn));
checkCublasError(cublasDestroy(cublas)); checkCublasError(cublasDestroy(cublas));
@ -75,6 +85,8 @@ class CudaRuntimeObj : public RuntimeObj {
void runWithoutSync(const Graph &graph) const; void runWithoutSync(const Graph &graph) const;
void runWithCudaGraph(const Graph &graph);
// init communicator // init communicator
void initComm(const string &name, int worldSize, int rank) final; void initComm(const string &name, int worldSize, int rank) final;

View File

@ -1376,6 +1376,9 @@ class OnnxStub:
def run(self) -> None: def run(self) -> None:
self.handler.run() self.handler.run()
def run_with_cudagraph(self) -> None:
    """Forward to ``self.handler.run_with_cudagraph()``; returns nothing."""
    handler = self.handler
    handler.run_with_cudagraph()
def get_perf_time(self) -> float: def get_perf_time(self) -> float:
self.handler.get_perf_time() self.handler.get_perf_time()

View File

@ -19,7 +19,6 @@ void CHECK_CUDA_KERNEL_ERROR(infini::Operator op) {
} }
namespace infini { namespace infini {
void CudaRuntimeObj::runWithoutSync(const Graph &graph) const { void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
const auto &kernelRegistry = KernelRegistry::getInstance(); const auto &kernelRegistry = KernelRegistry::getInstance();
auto &perfEngine = PerfEngine::getInstance(); auto &perfEngine = PerfEngine::getInstance();
@ -39,6 +38,27 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
} }
} }
void CudaRuntimeObj::runWithCudaGraph(const Graph &graph) {
if (!isCudaGraphCreated) {
CUDAStream::createStream();
checkCudnnError(cudnnSetStream(cudnn, CUDAStream::getCurrentStream()));
checkCublasError(
cublasSetStream(cublas, CUDAStream::getCurrentStream()));
checkCudaError(cudaStreamBeginCapture(CUDAStream::getCurrentStream(),
cudaStreamCaptureModeGlobal));
runWithoutSync(graph);
checkCudaError(
cudaStreamEndCapture(CUDAStream::getCurrentStream(), &cudaGraph));
checkCudaError(
cudaGraphInstantiate(&cudaGraphInstance, cudaGraph, NULL, NULL, 0));
isCudaGraphCreated = true;
} else {
checkCudaError(
cudaGraphLaunch(cudaGraphInstance, CUDAStream::getCurrentStream()));
}
checkCudaError(cudaStreamSynchronize(CUDAStream::getCurrentStream()));
}
void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const { void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const {
const auto &kernelRegistry = KernelRegistry::getInstance(); const auto &kernelRegistry = KernelRegistry::getInstance();
auto &perfEngine = PerfEngine::getInstance(); auto &perfEngine = PerfEngine::getInstance();
@ -102,4 +122,5 @@ void CudaRuntimeObj::initComm(const string &name, int worldSize, int rank) {
#endif #endif
} }
cudaStream_t CUDAStream::_stream = 0;
} // namespace infini } // namespace infini

View File

@ -16,7 +16,8 @@ __global__ void cudaPrintFloatImpl(float *x, int len) {
namespace infini { namespace infini {
void cudaPrintFloat(float *x, int len) { void cudaPrintFloat(float *x, int len) {
cudaPrintFloatImpl<<<1, 1>>>(x, len); cudaPrintFloatImpl
<<<1, 1, 0, CUDAStream::getCurrentStream()>>>(x, len);
cudaDeviceSynchronize(); cudaDeviceSynchronize();
} }

View File

@ -571,6 +571,10 @@ void init_graph_builder(py::module &m) {
.def("get_perf_time", &Handler::get_perf_time, policy::automatic) .def("get_perf_time", &Handler::get_perf_time, policy::automatic)
.def("tune", &Handler::tune, policy::automatic) .def("tune", &Handler::tune, policy::automatic)
.def("run", &Handler::run, policy::automatic) .def("run", &Handler::run, policy::automatic)
#ifdef USE_CUDA
.def("run_with_cudagraph", &Handler::run_with_cudagraph,
policy::automatic)
#endif
.def("shape_infer", &Handler::shape_infer, policy::automatic) .def("shape_infer", &Handler::shape_infer, policy::automatic)
.def("change_shape", &Handler::change_shape, policy::automatic) .def("change_shape", &Handler::change_shape, policy::automatic)
.def("getDims", &Handler::getDims, policy::automatic) .def("getDims", &Handler::getDims, policy::automatic)

View File

@ -28,9 +28,8 @@ class AllReduceNCCL : public CudaKernelWithoutConfig {
ncclComm_t comm = ncclComm_t comm =
dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator()) dynamic_cast<NcclCommunicatorObj &>(context->getCommunicator())
.getNcclComm(); .getNcclComm();
// TODO: Using default stream 0 for now. checkNcclError(ncclAllReduce(input, output, count, ncclType, getRedOp(),
checkNcclError( comm, CUDAStream::getCurrentStream()));
ncclAllReduce(input, output, count, ncclType, getRedOp(), comm, 0));
} }
virtual ncclRedOp_t getRedOp() const = 0; virtual ncclRedOp_t getRedOp() const = 0;

View File

@ -2,7 +2,7 @@
#include "cuda/cuda_attention_kvcache.h" #include "cuda/cuda_attention_kvcache.h"
#define WARP_SIZE 32 #define WARP_SIZE 32
#define BLOCKSIZE WARP_SIZE #define BLOCKSIZE WARP_SIZE
#define SEQ_UNIT 32 #define SEQ_UNIT 16
// ASSUME SEQ_LEN OF Q IS 1 // ASSUME SEQ_LEN OF Q IS 1
__global__ void _attention_kvcache_kernel_128_1(float* input_k_cache, __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
@ -103,7 +103,7 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
ptr_O[i] /= ptr_sum[0]; ptr_O[i] /= ptr_sum[0];
(float4 &)output_O_temp[(lane_id * 4) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = (float4 &)ptr_O[0]; (float4 &)output_O_temp[(lane_id * 4) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = (float4 &)ptr_O[0];
if(threadIdx.x == 0){ if(lane_id == 0){
output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0]; output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0];
} }
@ -157,13 +157,15 @@ void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), gridsize_y); dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), gridsize_y);
dim3 blockDim(BLOCKSIZE, 1); dim3 blockDim(BLOCKSIZE, 1);
assert(compMeta.dimSize[3] == 128); _attention_kvcache_kernel_128_1
_attention_kvcache_kernel_128_1<<<gridDim, blockDim>>>( <<<gridDim, blockDim, 0, CUDAStream::getCurrentStream()>>>
input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, (input_k_cache, input_v_cache, input_q, input_k, input_v, position_id,
compMeta, output_O_temp, output_sum_temp); compMeta, output_O_temp, output_sum_temp);
_attention_kvcache_kernel_128_2<<<compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), WARP_SIZE>>>(
position_id, output_matmul, compMeta, output_O_temp, output_sum_temp); _attention_kvcache_kernel_128_2
<<<compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), WARP_SIZE,
0, CUDAStream::getCurrentStream()>>>
(position_id, output_matmul, compMeta, output_O_temp, output_sum_temp);
} }
} // namespace infini } // namespace infini

View File

@ -25,8 +25,9 @@ void clip_kernel(float *input, float *output, int num, float minValue,
float maxValue) { float maxValue) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_clip_kernel<<<gridsize, blocksize>>>(input, output, num, minValue, _clip_kernel
maxValue); <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
input, output, num, minValue, maxValue);
} }
}; // namespace infini }; // namespace infini

View File

@ -131,8 +131,9 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
} }
#define CASE(OP, T) \ #define CASE(OP, T) \
_##OP##_kernel<DT_CUDA<T>::t><<<gridsize, blocksize>>>( \ _##OP##_kernel<DT_CUDA<T>::t> \
a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3); <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
(a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
#define SWITCH_DTYPE(OP, DTYPE) \ #define SWITCH_DTYPE(OP, DTYPE) \
switch (DTYPE) { \ switch (DTYPE) { \
@ -202,11 +203,13 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
int num = c0 * c1 * c2 * c3; int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
if (dType == 1) { if (dType == 1) {
_pow_kernel<float><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, _pow_kernel<float>
b1, b2, b3, c0, c1, c2, c3); <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
} else if (dType == 3) { } else if (dType == 3) {
_pow_kernel<int8_t><<<gridsize, blocksize>>>( _pow_kernel<int8_t>
a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3); <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
} else if (dType == 10) { } else if (dType == 10) {
int a_size = a0 * a1 * a2 * a3; int a_size = a0 * a1 * a2 * a3;
int b_size = b0 * b1 * b2 * b3; int b_size = b0 * b1 * b2 * b3;
@ -220,8 +223,9 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
for (int i = 0; i < b_size; ++i) { for (int i = 0; i < b_size; ++i) {
b_float[i] = __half2float(((half *)b)[i]); b_float[i] = __half2float(((half *)b)[i]);
} }
_pow_kernel<float><<<gridsize, blocksize>>>( _pow_kernel<float>
a_float.data(), b_float.data(), c_float.data(), a0, a1, a2, a3, b0, <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(a_float.data(), b_float.data(), c_float.data(), a0, a1, a2, a3, b0,
b1, b2, b3, c0, c1, c2, c3); b1, b2, b3, c0, c1, c2, c3);
for (int i = 0; i < c_size; ++i) { for (int i = 0; i < c_size; ++i) {
((half *)c)[i] = __float2half(c_float[i]); ((half *)c)[i] = __float2half(c_float[i]);

View File

@ -42,7 +42,8 @@ __global__ void _expandKernel(void *input, void *output, int nDims,
namespace infini { namespace infini {
#define CASE(T) \ #define CASE(T) \
_expandKernel<DT_CUDA<T>::t><<<gridsize, blocksize>>>( \ _expandKernel<DT_CUDA<T>::t><<<gridsize, blocksize, \
0, CUDAStream::getCurrentStream()>>>( \
input, output, nDims, outputsize, inputShape, outputShape); input, output, nDims, outputsize, inputShape, outputShape);
#define SWITCH_DTYPE(DTYPE) \ #define SWITCH_DTYPE(DTYPE) \

View File

@ -19,7 +19,8 @@ void extend_kernel(float *in, float *out, int blockSize, int blockSizeOuter,
int oSize) { int oSize) {
int blocksize = 32 * 16; int blocksize = 32 * 16;
int gridsize = (oSize + blocksize - 1) / blocksize; int gridsize = (oSize + blocksize - 1) / blocksize;
_extend_kernel<<<gridsize, blocksize>>>(in, out, blockSize, blockSizeOuter, _extend_kernel
oSize); <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
in, out, blockSize, blockSizeOuter, oSize);
} }
} // namespace infini } // namespace infini

View File

@ -45,9 +45,12 @@ void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num) {
int gridSize = (num + blockSize - 1) / blockSize; int gridSize = (num + blockSize - 1) / blockSize;
if (metaData.indexType == DataType::Int64) { if (metaData.indexType == DataType::Int64) {
_gather_kernel<T, int64_t> _gather_kernel<T, int64_t>
<<<gridSize, blockSize>>>(in, out, metaData, num); <<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>
(in, out, metaData, num);
} else { } else {
_gather_kernel<T, int><<<gridSize, blockSize>>>(in, out, metaData, num); _gather_kernel<T, int>
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>
(in, out, metaData, num);
} }
} }
template void gather_kernel<float>(float *in, float *out, template void gather_kernel<float>(float *in, float *out,

View File

@ -40,22 +40,26 @@ void gather_elements_kernel(void *in, void *out, GatherMetaData metaData,
int gridSize = (num + blockSize - 1) / blockSize; int gridSize = (num + blockSize - 1) / blockSize;
if (metaData.dataType == DataType::Float32 && if (metaData.dataType == DataType::Float32 &&
metaData.indexType == DataType::Int64) { metaData.indexType == DataType::Int64) {
_gather_elements_kernel<float, int64_t><<<gridSize, blockSize>>>( _gather_elements_kernel<float, int64_t>
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>(
reinterpret_cast<float *>(in), reinterpret_cast<float *>(out), reinterpret_cast<float *>(in), reinterpret_cast<float *>(out),
metaData, num); metaData, num);
} else if (metaData.dataType == DataType::Int32 && } else if (metaData.dataType == DataType::Int32 &&
metaData.indexType == DataType::Int64) { metaData.indexType == DataType::Int64) {
_gather_elements_kernel<int, int64_t><<<gridSize, blockSize>>>( _gather_elements_kernel<int, int64_t>
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>(
reinterpret_cast<int *>(in), reinterpret_cast<int *>(out), metaData, reinterpret_cast<int *>(in), reinterpret_cast<int *>(out), metaData,
num); num);
} else if (metaData.dataType == DataType::Float32 && } else if (metaData.dataType == DataType::Float32 &&
metaData.indexType == DataType::Int32) { metaData.indexType == DataType::Int32) {
_gather_elements_kernel<float, int><<<gridSize, blockSize>>>( _gather_elements_kernel<float, int>
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>(
reinterpret_cast<float *>(in), reinterpret_cast<float *>(out), reinterpret_cast<float *>(in), reinterpret_cast<float *>(out),
metaData, num); metaData, num);
} else if (metaData.dataType == DataType::Int32 && } else if (metaData.dataType == DataType::Int32 &&
metaData.indexType == DataType::Int32) { metaData.indexType == DataType::Int32) {
_gather_elements_kernel<int, int><<<gridSize, blockSize>>>( _gather_elements_kernel<int, int>
<<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>(
reinterpret_cast<int *>(in), reinterpret_cast<int *>(out), metaData, reinterpret_cast<int *>(in), reinterpret_cast<int *>(out), metaData,
num); num);
} else { } else {

View File

@ -344,8 +344,8 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
int BLOCK_DIM = 1024; int BLOCK_DIM = 1024;
blockLaynormKernel<float, 1024> blockLaynormKernel<float, 1024>
<<<num_block, BLOCK_DIM>>>(input, scale, dimsize, stride, output, <<<num_block, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
eps, scaleSize, bias, biasSize); (input, scale, dimsize, stride, output, eps, scaleSize, bias, biasSize);
} else if (dimsize > 31) { } else if (dimsize > 31) {
int BLOCK_DIM_x = 32; int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32; int BLOCK_DIM_y = 32;
@ -353,9 +353,10 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<float, 32, 32><<<grid_dim, block_dim>>>( warpLaynormKernel<float, 32, 32>
input, scale, dimsize, stride, output, eps, scaleSize, num_block, <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
bias, biasSize); (input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize);
} else if (dimsize > 15) { } else if (dimsize > 15) {
int BLOCK_DIM_x = 16; int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64; int BLOCK_DIM_y = 64;
@ -363,8 +364,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<float, 16, 64><<<grid_dim, block_dim>>>( warpLaynormKernel<float, 16, 64>
input, scale, dimsize, stride, output, eps, scaleSize, num_block, <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize); bias, biasSize);
} else if (dimsize > 7) { } else if (dimsize > 7) {
int BLOCK_DIM_x = 8; int BLOCK_DIM_x = 8;
@ -373,8 +375,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<float, 8, 128><<<grid_dim, block_dim>>>( warpLaynormKernel<float, 8, 128>
input, scale, dimsize, stride, output, eps, scaleSize, num_block, <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize); bias, biasSize);
} else { } else {
int BLOCK_DIM_x = 4; int BLOCK_DIM_x = 4;
@ -383,8 +386,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<float, 4, 256><<<grid_dim, block_dim>>>( warpLaynormKernel<float, 4, 256>
input, scale, dimsize, stride, output, eps, scaleSize, num_block, <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize); bias, biasSize);
} }
} }
@ -396,8 +400,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
if (dimsize > 1024) { if (dimsize > 1024) {
int BLOCK_DIM = 1024; int BLOCK_DIM = 1024;
blockLaynormKernel<float, 1024><<<num_block, BLOCK_DIM>>>( blockLaynormKernel<float, 1024>
input, scale, dimsize, stride, output, eps, scaleSize); <<<num_block, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize);
} else if (dimsize > 31) { } else if (dimsize > 31) {
int BLOCK_DIM_x = 32; int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32; int BLOCK_DIM_y = 32;
@ -405,8 +410,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<float, 32, 32><<<grid_dim, block_dim>>>( warpLaynormKernel<float, 32, 32>
input, scale, dimsize, stride, output, eps, scaleSize, num_block); <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else if (dimsize > 15) { } else if (dimsize > 15) {
int BLOCK_DIM_x = 16; int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64; int BLOCK_DIM_y = 64;
@ -414,8 +420,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<float, 16, 64><<<grid_dim, block_dim>>>( warpLaynormKernel<float, 16, 64>
input, scale, dimsize, stride, output, eps, scaleSize, num_block); <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else if (dimsize > 7) { } else if (dimsize > 7) {
int BLOCK_DIM_x = 8; int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128; int BLOCK_DIM_y = 128;
@ -423,8 +430,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<float, 8, 128><<<grid_dim, block_dim>>>( warpLaynormKernel<float, 8, 128>
input, scale, dimsize, stride, output, eps, scaleSize, num_block); <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else { } else {
int BLOCK_DIM_x = 4; int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256; int BLOCK_DIM_y = 256;
@ -432,8 +440,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<float, 4, 256><<<grid_dim, block_dim>>>( warpLaynormKernel<float, 4, 256>
input, scale, dimsize, stride, output, eps, scaleSize, num_block); <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} }
} }
//----------------- //-----------------
@ -445,8 +454,8 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
int BLOCK_DIM = 1024; int BLOCK_DIM = 1024;
blockLaynormKernel<half, 1024> blockLaynormKernel<half, 1024>
<<<num_block, BLOCK_DIM>>>(input, scale, dimsize, stride, output, <<<num_block, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
eps, scaleSize, bias, biasSize); (input, scale, dimsize, stride, output, eps, scaleSize, bias, biasSize);
} else if (dimsize > 31) { } else if (dimsize > 31) {
int BLOCK_DIM_x = 32; int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32; int BLOCK_DIM_y = 32;
@ -454,8 +463,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 32, 32><<<grid_dim, block_dim>>>( warpLaynormKernel<half, 32, 32>
input, scale, dimsize, stride, output, eps, scaleSize, num_block, <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize); bias, biasSize);
} else if (dimsize > 15) { } else if (dimsize > 15) {
int BLOCK_DIM_x = 16; int BLOCK_DIM_x = 16;
@ -464,8 +474,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 16, 64><<<grid_dim, block_dim>>>( warpLaynormKernel<half, 16, 64>
input, scale, dimsize, stride, output, eps, scaleSize, num_block, <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize); bias, biasSize);
} else if (dimsize > 7) { } else if (dimsize > 7) {
int BLOCK_DIM_x = 8; int BLOCK_DIM_x = 8;
@ -474,8 +485,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 8, 128><<<grid_dim, block_dim>>>( warpLaynormKernel<half, 8, 128>
input, scale, dimsize, stride, output, eps, scaleSize, num_block, <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize); bias, biasSize);
} else { } else {
int BLOCK_DIM_x = 4; int BLOCK_DIM_x = 4;
@ -484,8 +496,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 4, 256><<<grid_dim, block_dim>>>( warpLaynormKernel<half, 4, 256>
input, scale, dimsize, stride, output, eps, scaleSize, num_block, <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block,
bias, biasSize); bias, biasSize);
} }
} }
@ -497,8 +510,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
if (dimsize > 1024) { if (dimsize > 1024) {
int BLOCK_DIM = 1024; int BLOCK_DIM = 1024;
blockLaynormKernel<half, 1024><<<num_block, BLOCK_DIM>>>( blockLaynormKernel<half, 1024>
input, scale, dimsize, stride, output, eps, scaleSize); <<<num_block, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize);
} else if (dimsize > 31) { } else if (dimsize > 31) {
int BLOCK_DIM_x = 32; int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32; int BLOCK_DIM_y = 32;
@ -506,8 +520,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 32, 32><<<grid_dim, block_dim>>>( warpLaynormKernel<half, 32, 32>
input, scale, dimsize, stride, output, eps, scaleSize, num_block); <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else if (dimsize > 15) { } else if (dimsize > 15) {
int BLOCK_DIM_x = 16; int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64; int BLOCK_DIM_y = 64;
@ -515,8 +530,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 16, 64><<<grid_dim, block_dim>>>( warpLaynormKernel<half, 16, 64>
input, scale, dimsize, stride, output, eps, scaleSize, num_block); <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else if (dimsize > 7) { } else if (dimsize > 7) {
int BLOCK_DIM_x = 8; int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128; int BLOCK_DIM_y = 128;
@ -524,8 +540,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 8, 128><<<grid_dim, block_dim>>>( warpLaynormKernel<half, 8, 128>
input, scale, dimsize, stride, output, eps, scaleSize, num_block); <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} else { } else {
int BLOCK_DIM_x = 4; int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256; int BLOCK_DIM_y = 256;
@ -533,8 +550,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps,
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
warpLaynormKernel<half, 4, 256><<<grid_dim, block_dim>>>( warpLaynormKernel<half, 4, 256>
input, scale, dimsize, stride, output, eps, scaleSize, num_block); <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, scale, dimsize, stride, output, eps, scaleSize, num_block);
} }
} }
} // namespace infini } // namespace infini

View File

@ -48,8 +48,9 @@ __global__ void _pad_slice_kernel(void *part, void *whole,
namespace infini { namespace infini {
#define CASE(T) \ #define CASE(T) \
_pad_slice_kernel<DT_CUDA<T>::t><<<gridSize, blockSize>>>( \ _pad_slice_kernel<DT_CUDA<T>::t> \
partData, wholeData, metadata, nDims, num, isPad); <<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>> \
(partData, wholeData, metadata, nDims, num, isPad);
#define SWITCH_DTYPE(DTYPE) \ #define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \ switch (DTYPE) { \

View File

@ -7,7 +7,8 @@ class CopyCuda : public CudaKernelWithoutConfig {
auto inData = op->getInputs(0)->getRawDataPtr<void *>(); auto inData = op->getInputs(0)->getRawDataPtr<void *>();
auto outData = op->getOutputs()[0]->getRawDataPtr<void *>(); auto outData = op->getOutputs()[0]->getRawDataPtr<void *>();
cudaMemcpyAsync(outData, inData, op->getInputs(0)->getBytes(), cudaMemcpyAsync(outData, inData, op->getInputs(0)->getBytes(),
cudaMemcpyDeviceToDevice); cudaMemcpyDeviceToDevice,
CUDAStream::getCurrentStream());
} }
}; };
// reshape/flatten/identity all act as copying from input to output. // reshape/flatten/identity all act as copying from input to output.

View File

@ -213,8 +213,9 @@ void resize_kernel_nearest(float *in, float *out, const MetaData &metaData,
sizeof(p_cooridnate_trans_mode_func[0])); sizeof(p_cooridnate_trans_mode_func[0]));
IT_ASSERT(nearestMode < IT_ASSERT(nearestMode <
sizeof(p_nearest_mode_fun) / sizeof(p_nearest_mode_fun[0])); sizeof(p_nearest_mode_fun) / sizeof(p_nearest_mode_fun[0]));
_resize_kernel_nearest<<<gridsize, blocksize>>>( _resize_kernel_nearest
in, out, metaData, num, coordinateMode, nearestMode); <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(in, out, metaData, num, coordinateMode, nearestMode);
} }
void resize_kernel_linear(float *in, float *out, const MetaData &metaData, void resize_kernel_linear(float *in, float *out, const MetaData &metaData,
@ -223,8 +224,9 @@ void resize_kernel_linear(float *in, float *out, const MetaData &metaData,
auto gridsize = (num + blocksize - 1) / blocksize; auto gridsize = (num + blocksize - 1) / blocksize;
IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) / IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) /
sizeof(p_cooridnate_trans_mode_func[0])); sizeof(p_cooridnate_trans_mode_func[0]));
_resize_kernel_linear_coeff<<<gridsize, blocksize>>>(in, out, metaData, num, _resize_kernel_linear_coeff
coordinateMode); <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(in, out, metaData, num, coordinateMode);
} }
void resize_kernel_cubic(float *in, float *out, const MetaData &metaData, void resize_kernel_cubic(float *in, float *out, const MetaData &metaData,
@ -233,7 +235,8 @@ void resize_kernel_cubic(float *in, float *out, const MetaData &metaData,
auto gridsize = (num + blocksize - 1) / blocksize; auto gridsize = (num + blocksize - 1) / blocksize;
IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) / IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) /
sizeof(p_cooridnate_trans_mode_func[0])); sizeof(p_cooridnate_trans_mode_func[0]));
_resize_kernel_cubic_coeff<<<gridsize, blocksize>>>(in, out, metaData, num, _resize_kernel_cubic_coeff
coordinateMode); <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(in, out, metaData, num, coordinateMode);
} }
} // namespace infini } // namespace infini

View File

@ -9,7 +9,8 @@ constexpr int block_work_size() { return thread_work_size() * num_threads(); }
// gridDim (batch, seq_len, dim_model / 1024), blockDim (1024, 1, 1) // gridDim (batch, seq_len, dim_model / 1024), blockDim (1024, 1, 1)
template <class T> template <class T>
__global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model, int dim_head, int hidden_stride, int pos_stride) { __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model,
int dim_head, int hidden_stride, int pos_stride) {
int batch_id = blockIdx.x; int batch_id = blockIdx.x;
int target_pos = pos[batch_id * pos_stride + blockIdx.y]; int target_pos = pos[batch_id * pos_stride + blockIdx.y];
int ith = blockIdx.z * blockDim.x + threadIdx.x; int ith = blockIdx.z * blockDim.x + threadIdx.x;
@ -36,8 +37,9 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
#define CASE(T) \ #define CASE(T) \
_rope_kernel<DT_CUDA<T>::t><<<gridsize, blocksize>>>( \ _rope_kernel<DT_CUDA<T>::t> \
pos, input, output, size, dim_model, dim_head, hidden_stride, pos_stride); <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
(pos, input, output, size, dim_model, dim_head, hidden_stride, pos_stride);
#define SWITCH_DTYPE(DTYPE) \ #define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \ switch (DTYPE) { \
@ -82,7 +84,8 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
} }
namespace infini { namespace infini {
void rope_kernel(int dType, int * pos, void *input, void *output, int size, int dim_model, int dim_head, int hidden_stride, int pos_stride) { void rope_kernel(int dType, int * pos, void *input, void *output, int size,
int dim_model, int dim_head, int hidden_stride, int pos_stride) {
dim3 blocksize = dim3(1024,1,1); dim3 blocksize = dim3(1024,1,1);
dim3 gridsize = dim3(1, 1, 4); dim3 gridsize = dim3(1, 1, 4);
SWITCH_DTYPE(dType) SWITCH_DTYPE(dType)

View File

@ -246,32 +246,38 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
int BLOCK_DIM = 1024; int BLOCK_DIM = 1024;
_blockSoftmaxKernel<float, 1024> _blockSoftmaxKernel<float, 1024>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride); <<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 64) { } else if (dimsize > 1024 * 64) {
int BLOCK_DIM = 1024; int BLOCK_DIM = 1024;
_blockSoftmaxKernel<float, 1024, 128> _blockSoftmaxKernel<float, 1024, 128>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride); <<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 32) { } else if (dimsize > 1024 * 32) {
int BLOCK_DIM = 1024; int BLOCK_DIM = 1024;
_blockSoftmaxKernel<float, 1024, 64> _blockSoftmaxKernel<float, 1024, 64>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride); <<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 16) { } else if (dimsize > 1024 * 16) {
int BLOCK_DIM = 1024; int BLOCK_DIM = 1024;
_blockSoftmaxKernel<float, 1024, 32> _blockSoftmaxKernel<float, 1024, 32>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride); <<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 4) { } else if (dimsize > 1024 * 4) {
int BLOCK_DIM = 1024; int BLOCK_DIM = 1024;
_blockSoftmaxKernel<float, 1024, 16> _blockSoftmaxKernel<float, 1024, 16>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride); <<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 1024) { } else if (dimsize > 1024) {
int BLOCK_DIM = 1024; int BLOCK_DIM = 1024;
_blockSoftmaxKernel<float, 1024, 4> _blockSoftmaxKernel<float, 1024, 4>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride); <<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 31) { } else if (dimsize > 31) {
int BLOCK_DIM_x = 32; int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32; int BLOCK_DIM_y = 32;
@ -280,7 +286,8 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<float, 32, 32, 32> _warpSoftmaxKernel<float, 32, 32, 32>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride); <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 15) { } else if (dimsize > 15) {
int BLOCK_DIM_x = 16; int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64; int BLOCK_DIM_y = 64;
@ -289,7 +296,8 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<float, 16, 64, 2> _warpSoftmaxKernel<float, 16, 64, 2>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride); <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 7) { } else if (dimsize > 7) {
int BLOCK_DIM_x = 8; int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128; int BLOCK_DIM_y = 128;
@ -298,7 +306,8 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<float, 8, 128, 2> _warpSoftmaxKernel<float, 8, 128, 2>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride); <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else { } else {
int BLOCK_DIM_x = 4; int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256; int BLOCK_DIM_y = 256;
@ -307,7 +316,8 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size,
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<float, 4, 256, 2> _warpSoftmaxKernel<float, 4, 256, 2>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride); <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} }
} }
//------------------ //------------------
@ -318,32 +328,38 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size,
int BLOCK_DIM = 1024; int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024> _blockSoftmaxKernel<half, 1024>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride); <<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 64) { } else if (dimsize > 1024 * 64) {
int BLOCK_DIM = 1024; int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024, 128> _blockSoftmaxKernel<half, 1024, 128>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride); <<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 32) { } else if (dimsize > 1024 * 32) {
int BLOCK_DIM = 1024; int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024, 64> _blockSoftmaxKernel<half, 1024, 64>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride); <<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 16) { } else if (dimsize > 1024 * 16) {
int BLOCK_DIM = 1024; int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024, 32> _blockSoftmaxKernel<half, 1024, 32>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride); <<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 1024 * 4) { } else if (dimsize > 1024 * 4) {
int BLOCK_DIM = 1024; int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024, 16> _blockSoftmaxKernel<half, 1024, 16>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride); <<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 1024) { } else if (dimsize > 1024) {
int BLOCK_DIM = 1024; int BLOCK_DIM = 1024;
_blockSoftmaxKernel<half, 1024, 4> _blockSoftmaxKernel<half, 1024, 4>
<<<num_blocks, BLOCK_DIM>>>(input, output, size, dimsize, stride); <<<num_blocks, BLOCK_DIM, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 31) { } else if (dimsize > 31) {
int BLOCK_DIM_x = 32; int BLOCK_DIM_x = 32;
int BLOCK_DIM_y = 32; int BLOCK_DIM_y = 32;
@ -352,7 +368,8 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size,
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<half, 32, 32, 32> _warpSoftmaxKernel<half, 32, 32, 32>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride); <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 15) { } else if (dimsize > 15) {
int BLOCK_DIM_x = 16; int BLOCK_DIM_x = 16;
int BLOCK_DIM_y = 64; int BLOCK_DIM_y = 64;
@ -361,7 +378,8 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size,
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<half, 16, 64, 2> _warpSoftmaxKernel<half, 16, 64, 2>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride); <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else if (dimsize > 7) { } else if (dimsize > 7) {
int BLOCK_DIM_x = 8; int BLOCK_DIM_x = 8;
int BLOCK_DIM_y = 128; int BLOCK_DIM_y = 128;
@ -370,7 +388,8 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size,
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<half, 8, 128, 2> _warpSoftmaxKernel<half, 8, 128, 2>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride); <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} else { } else {
int BLOCK_DIM_x = 4; int BLOCK_DIM_x = 4;
int BLOCK_DIM_y = 256; int BLOCK_DIM_y = 256;
@ -379,7 +398,8 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size,
dim3 grid_dim(num_block_x, 1, 1); dim3 grid_dim(num_block_x, 1, 1);
_warpSoftmaxKernel<half, 4, 256, 2> _warpSoftmaxKernel<half, 4, 256, 2>
<<<grid_dim, block_dim>>>(input, output, size, dimsize, stride); <<<grid_dim, block_dim, 0, CUDAStream::getCurrentStream()>>>
(input, output, size, dimsize, stride);
} }
} }
} // namespace infini } // namespace infini

View File

@ -70,7 +70,8 @@ class ConcatCuda : private CudaCompute, public CudaKernelWithoutConfig {
_op->getOutputs()[0]->getRawDataPtr<void *>(); _op->getOutputs()[0]->getRawDataPtr<void *>();
cudaMemcpyAsync(outData, inData, cudaMemcpyAsync(outData, inData,
_op->getInputs(1 - i)->getBytes(), _op->getInputs(1 - i)->getBytes(),
cudaMemcpyDeviceToDevice); cudaMemcpyDeviceToDevice,
CUDAStream::getCurrentStream());
return; return;
} }
} }

View File

@ -63,8 +63,9 @@ void split_concat_kernel(const ElementTensorMetadata<float> &eleMeta,
// each y is a split among the batch // each y is a split among the batch
dim3 gridSize(gridDimX, batchSize); dim3 gridSize(gridDimX, batchSize);
_split_concat_kernel<<<gridSize, blockSize>>>(eleMeta, compMeta, dim, nDims, _split_concat_kernel
isSplit); <<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>
(eleMeta, compMeta, dim, nDims, isSplit);
} }
void split_concat_kernel(const ElementTensorMetadata<half> &eleMeta, void split_concat_kernel(const ElementTensorMetadata<half> &eleMeta,
const ComposedTensorMetadata<half> &compMeta, int dim, const ComposedTensorMetadata<half> &compMeta, int dim,
@ -77,8 +78,9 @@ void split_concat_kernel(const ElementTensorMetadata<half> &eleMeta,
// each y is a split among the batch // each y is a split among the batch
dim3 gridSize(gridDimX, batchSize); dim3 gridSize(gridDimX, batchSize);
_split_concat_kernel<<<gridSize, blockSize>>>(eleMeta, compMeta, dim, nDims, _split_concat_kernel
isSplit); <<<gridSize, blockSize, 0, CUDAStream::getCurrentStream()>>>
(eleMeta, compMeta, dim, nDims, isSplit);
} }
} // namespace infini } // namespace infini

View File

@ -23,8 +23,9 @@ __global__ void _transpose_kernel(void *input, void *output, int nDims,
} }
} }
#define CASE(T) \ #define CASE(T) \
_transpose_kernel<DT_CUDA<T>::t><<<gridsize, blocksize>>>( \ _transpose_kernel<DT_CUDA<T>::t> \
input, output, nDims, size, strides, outputShape); <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
(input, output, nDims, size, strides, outputShape);
#define SWITCH_DTYPE(DTYPE) \ #define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \ switch (DTYPE) { \

View File

@ -148,78 +148,104 @@ template <typename T> void softmax_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_softmax_kernel1<T><<<1, 1>>>(input, output, num); _softmax_kernel1<T>
_softmax_kernel2<T><<<gridsize, blocksize>>>(input, output, num); <<<1, 1, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
_softmax_kernel2<T>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
} }
template <typename T> void relu_kernel(T *input, T *output, size_t num) { template <typename T> void relu_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_relu_kernel<T><<<gridsize, blocksize>>>(input, output, num); _relu_kernel<T>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
} }
template <typename T> void sigmoid_kernel(T *input, T *output, size_t num) { template <typename T> void sigmoid_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_sigmoid_kernel<T><<<gridsize, blocksize>>>(input, output, num); _sigmoid_kernel<T>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
} }
template <typename T> template <typename T>
void hard_sigmoid_kernel(T *input, T *output, size_t num) { void hard_sigmoid_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_hard_sigmoid_kernel<T><<<gridsize, blocksize>>>(input, output, num); _hard_sigmoid_kernel<T>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
} }
template <typename T> void hard_swish_kernel(T *input, T *output, size_t num) { template <typename T> void hard_swish_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_hard_swish_kernel<T><<<gridsize, blocksize>>>(input, output, num); _hard_swish_kernel<T>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
} }
template <typename T> void tanh_kernel(T *input, T *output, size_t num) { template <typename T> void tanh_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_tanh_kernel<T><<<gridsize, blocksize>>>(input, output, num); _tanh_kernel<T>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
} }
template <typename T> void abs_kernel(T *input, T *output, size_t num) { template <typename T> void abs_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_abs_kernel<T><<<gridsize, blocksize>>>(input, output, num); _abs_kernel<T>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
} }
template <typename T> void sqrt_kernel(T *input, T *output, size_t num) { template <typename T> void sqrt_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_sqrt_kernel<<<gridsize, blocksize>>>((T *)input, (T *)output, num); _sqrt_kernel
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
((T *)input, (T *)output, num);
} }
template <typename T> void gelu_kernel(T *input, T *output, size_t num) { template <typename T> void gelu_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_gelu_kernel<T><<<gridsize, blocksize>>>(input, output, num); _gelu_kernel<T>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
} }
template <typename T> void silu_kernel(T *input, T *output, size_t num) { template <typename T> void silu_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_silu_kernel<T><<<gridsize, blocksize>>>(input, output, num); _silu_kernel<T>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
} }
template <typename T> void erf_kernel(T *input, T *output, size_t num) { template <typename T> void erf_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_erf_kernel<T><<<gridsize, blocksize>>>(input, output, num); _erf_kernel<T>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
} }
template <typename T> void neg_kernel(T *input, T *output, size_t num) { template <typename T> void neg_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_neg_kernel<T><<<gridsize, blocksize>>>(input, output, num); _neg_kernel<T>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
} }
void unary_kernel(const Operator &_op) { void unary_kernel(const Operator &_op) {
@ -317,7 +343,9 @@ void cast_kernel(INPUT *input, OUTPUT *output, size_t num) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_cast_kernel<INPUT, OUTPUT><<<gridsize, blocksize>>>(input, output, num); _cast_kernel<INPUT, OUTPUT>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(input, output, num);
} }
template void cast_kernel<float, half>(float *input, half *output, size_t num); template void cast_kernel<float, half>(float *input, half *output, size_t num);

View File

@ -61,7 +61,8 @@ void whereKernel(const float *inputX, const float *inputY,
blocksize = 32; blocksize = 32;
} }
int gridsize = (outputsize + blocksize - 1) / blocksize; int gridsize = (outputsize + blocksize - 1) / blocksize;
_whereKernel<float><<<gridsize, blocksize>>>( _whereKernel<float>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
inputX, inputY, condition, output, nDims, outputsize, inputXShape, inputX, inputY, condition, output, nDims, outputsize, inputXShape,
inputYShape, conditionShape, outputShape, xSize, ySize, cSize); inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
} }
@ -85,7 +86,8 @@ void whereKernel(const half *inputX, const half *inputY,
blocksize = 32; blocksize = 32;
} }
int gridsize = (outputsize + blocksize - 1) / blocksize; int gridsize = (outputsize + blocksize - 1) / blocksize;
_whereKernel<half><<<gridsize, blocksize>>>( _whereKernel<half>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
inputX, inputY, condition, output, nDims, outputsize, inputXShape, inputX, inputY, condition, output, nDims, outputsize, inputXShape,
inputYShape, conditionShape, outputShape, xSize, ySize, cSize); inputYShape, conditionShape, outputShape, xSize, ySize, cSize);
} }

View File

@ -0,0 +1,70 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/attention_kvcache.h"
#include "test.h"
namespace infini {

// Builds a chain of three AttentionKVCache operators on a CUDA graph and
// compares the latency of a plain CudaRuntimeObj::run() against
// CudaRuntimeObj::runWithCudaGraph() for the same graph.
//
// NOTE(review): the final expectation compares two single-shot wall-clock
// measurements, so it is inherently sensitive to machine load.
TEST(TestCudaRuntime, CudaGraph) {
    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

    // Device-side inputs of the attention operators.
    auto input_k_cache_d =
        gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
    auto input_v_cache_d =
        gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
    auto input_q_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
    auto input_k_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
    auto input_v_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
    auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32);

    // Each operator consumes the output of the previous one as its query.
    auto op = gCuda->addOp<AttentionKVCacheObj>(
        input_k_cache_d, input_v_cache_d, input_q_d, input_k_d, input_v_d,
        position_id_d, nullptr);
    auto op1 = gCuda->addOp<AttentionKVCacheObj>(
        input_k_cache_d, input_v_cache_d, op->getOutputs()[0], input_k_d,
        input_v_d, position_id_d, nullptr);
    gCuda->addOp<AttentionKVCacheObj>(
        input_k_cache_d, input_v_cache_d, op1->getOutputs()[0], input_k_d,
        input_v_d, position_id_d, nullptr);

    gCuda->dataMalloc();
    input_q_d->setData(OneGenerator());
    input_k_d->setData(OneGenerator());
    input_v_d->setData(OneGenerator());
    position_id_d->setData(IncrementalGenerator());

    // Untimed warm-up run before the first measurement.
    cudaRuntime->run(gCuda);

    cudaEvent_t start, stop;
    ASSERT_EQ(cudaEventCreate(&start), cudaSuccess);
    ASSERT_EQ(cudaEventCreate(&stop), cudaSuccess);

    // Times a single invocation of `launch` with the start/stop event pair.
    // The device is drained first so that earlier work is not attributed to
    // the measured call. EXPECT_* is used because ASSERT_* cannot be placed
    // in a function that returns a value.
    auto elapsedMs = [&](auto &&launch) {
        float ms = 0;
        EXPECT_EQ(cudaDeviceSynchronize(), cudaSuccess);
        EXPECT_EQ(cudaEventRecord(start), cudaSuccess);
        launch();
        EXPECT_EQ(cudaEventRecord(stop), cudaSuccess);
        EXPECT_EQ(cudaEventSynchronize(stop), cudaSuccess);
        EXPECT_EQ(cudaEventElapsedTime(&ms, start, stop), cudaSuccess);
        return ms;
    };

    const float plainMs = elapsedMs([&] { cudaRuntime->run(gCuda); });
    printf("without cudaGraph, latency: %f ms\n", plainMs);

    // Two untimed calls before measuring.
    // NOTE(review): presumably the first call builds the CUDA graph and the
    // following ones replay it -- confirm against
    // CudaRuntimeObj::runWithCudaGraph.
    cudaRuntime->runWithCudaGraph(gCuda);
    cudaRuntime->runWithCudaGraph(gCuda);

    const float graphMs =
        elapsedMs([&] { cudaRuntime->runWithCudaGraph(gCuda); });
    printf("with cudaGraph, latency: %f ms\n", graphMs);

    // Release the events; the original test leaked them.
    EXPECT_EQ(cudaEventDestroy(start), cudaSuccess);
    EXPECT_EQ(cudaEventDestroy(stop), cudaSuccess);

    EXPECT_GE(plainMs, graphMs);
}

} // namespace infini