From 1c08ba200c245761586c7d772dd1ff3e1502632e Mon Sep 17 00:00:00 2001 From: xiaonans <51065160+xiaonans@users.noreply.github.com> Date: Wed, 21 Feb 2024 14:00:25 +0800 Subject: [PATCH] [feature] add cudagraph support (#215) * [feature] add cudagraph support * modify code to pass the cuda_all_reduce test --- include/core/graph_handler.h | 10 +++ include/cuda/cuda_common.h | 17 ++++ include/cuda/cuda_runtime.h | 12 +++ pyinfinitensor/src/pyinfinitensor/onnx.py | 3 + src/cuda/cuda_runtime.cc | 23 ++++- src/cuda/cuda_utility.cu | 3 +- src/ffi/ffi_infinitensor.cc | 4 + src/kernels/cuda/all_reduce.cc | 5 +- src/kernels/cuda/attention_kvcache.cu | 18 ++-- src/kernels/cuda/clip.cu | 5 +- src/kernels/cuda/element_wise.cu | 20 +++-- src/kernels/cuda/expand.cu | 3 +- src/kernels/cuda/extend.cu | 5 +- src/kernels/cuda/gather.cu | 7 +- src/kernels/cuda/gather_elements.cu | 12 ++- src/kernels/cuda/layer_norm.cu | 100 +++++++++++++--------- src/kernels/cuda/pad_slice.cu | 5 +- src/kernels/cuda/reshape.cc | 3 +- src/kernels/cuda/resize.cu | 15 ++-- src/kernels/cuda/rope.cu | 11 ++- src/kernels/cuda/softmax.cu | 60 ++++++++----- src/kernels/cuda/split_concat.cc | 3 +- src/kernels/cuda/split_concat.cu | 10 ++- src/kernels/cuda/transpose.cu | 5 +- src/kernels/cuda/unary.cu | 56 +++++++++--- src/kernels/cuda/where.cu | 6 +- test/cuda/test_cudagraph.cc | 70 +++++++++++++++ 27 files changed, 362 insertions(+), 129 deletions(-) create mode 100644 test/cuda/test_cudagraph.cc diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index 36486e36..ce455d62 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -5,6 +5,10 @@ #include #include +#ifdef USE_CUDA +#include "cuda/cuda_runtime.h" +#endif + namespace infini { class GraphHandlerObj { @@ -137,6 +141,12 @@ class GraphHandlerObj { inline void run() { g->getRuntime()->run(g); } inline double get_perf_time() { return g->getRuntime()->getPerfTime(g); } + +#ifdef USE_CUDA + inline void run_with_cudagraph() { + (as(g->getRuntime()))->runWithCudaGraph(g); + } +#endif }; } // namespace infini diff --git a/include/cuda/cuda_common.h b/include/cuda/cuda_common.h index 4eb75f27..a024d4f3 100644 --- a/include/cuda/cuda_common.h +++ b/include/cuda/cuda_common.h @@ -5,6 +5,7 @@ #include #include #include +#include #define checkCudaError(call) \ if (auto err = call; err != cudaSuccess) \ @@ -111,4 +112,20 @@ inline const char *curandGetErrorString(curandStatus_t error) { using CudaPtr = void *; +class CUDAStream { + public: + CUDAStream(const CUDAStream &) = delete; + CUDAStream(CUDAStream &&) = delete; + void operator=(const CUDAStream &) = delete; + void operator=(CUDAStream &&) = delete; + static cudaStream_t getCurrentStream() { return _stream; } + static void Init() { CUDAStream::_stream = 0; }; + static void createStream() { checkCudaError(cudaStreamCreate(&_stream)); } + static void destroyStream() { checkCudaError(cudaStreamDestroy(_stream)); } + + private: + CUDAStream(){}; + static cudaStream_t _stream; +}; + } // namespace infini diff --git a/include/cuda/cuda_runtime.h b/include/cuda/cuda_runtime.h index 19fd9fc8..7d48c019 100644 --- a/include/cuda/cuda_runtime.h +++ b/include/cuda/cuda_runtime.h @@ -14,6 +14,9 @@ class CudaRuntimeObj : public RuntimeObj { std::unique_ptr comm; CudaPtr workspace; size_t workspaceSize; + bool isCudaGraphCreated; + cudaGraph_t cudaGraph; + cudaGraphExec_t cudaGraphInstance; public: explicit CudaRuntimeObj(int deviceId = 0) @@ -26,9 +29,16 @@ class CudaRuntimeObj : public RuntimeObj { // size_t longformerNum = 3lu * (1 << 30); workspaceSize = 7ll << 30; // 7 GB workspace = alloc(workspaceSize); + isCudaGraphCreated = false; + CUDAStream::Init(); } virtual ~CudaRuntimeObj() { try { + if (isCudaGraphCreated) { + checkCudaError(cudaGraphExecDestroy(cudaGraphInstance)); + checkCudaError(cudaGraphDestroy(cudaGraph)); + CUDAStream::destroyStream(); + } dealloc(workspace); checkCudnnError(cudnnDestroy(cudnn)); checkCublasError(cublasDestroy(cublas)); @@ -75,6 +85,8 @@ class CudaRuntimeObj : public RuntimeObj { void runWithoutSync(const Graph &graph) const; + void runWithCudaGraph(const Graph &graph); + // init communicator void initComm(const string &name, int worldSize, int rank) final; diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 58993519..1a2e28a7 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -1376,6 +1376,9 @@ class OnnxStub: def run(self) -> None: self.handler.run() + def run_with_cudagraph(self) -> None: + self.handler.run_with_cudagraph() + def get_perf_time(self) -> float: self.handler.get_perf_time() diff --git a/src/cuda/cuda_runtime.cc b/src/cuda/cuda_runtime.cc index b92cb18f..9bab5018 100644 --- a/src/cuda/cuda_runtime.cc +++ b/src/cuda/cuda_runtime.cc @@ -19,7 +19,6 @@ void CHECK_CUDA_KERNEL_ERROR(infini::Operator op) { } namespace infini { - void CudaRuntimeObj::runWithoutSync(const Graph &graph) const { const auto &kernelRegistry = KernelRegistry::getInstance(); auto &perfEngine = PerfEngine::getInstance(); @@ -39,6 +38,27 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const { } } +void CudaRuntimeObj::runWithCudaGraph(const Graph &graph) { + if (!isCudaGraphCreated) { + CUDAStream::createStream(); + checkCudnnError(cudnnSetStream(cudnn, CUDAStream::getCurrentStream())); + checkCublasError( + cublasSetStream(cublas, CUDAStream::getCurrentStream())); + checkCudaError(cudaStreamBeginCapture(CUDAStream::getCurrentStream(), + cudaStreamCaptureModeGlobal)); + runWithoutSync(graph); + checkCudaError( + cudaStreamEndCapture(CUDAStream::getCurrentStream(), &cudaGraph)); + checkCudaError( + cudaGraphInstantiate(&cudaGraphInstance, cudaGraph, NULL, NULL, 0)); + isCudaGraphCreated = true; + } else { + checkCudaError( + cudaGraphLaunch(cudaGraphInstance, CUDAStream::getCurrentStream())); + } + checkCudaError(cudaStreamSynchronize(CUDAStream::getCurrentStream())); +} + void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const { const auto &kernelRegistry = KernelRegistry::getInstance(); auto &perfEngine = PerfEngine::getInstance(); @@ -102,4 +122,5 @@ void CudaRuntimeObj::initComm(const string &name, int worldSize, int rank) { #endif } +cudaStream_t CUDAStream::_stream = 0; } // namespace infini diff --git a/src/cuda/cuda_utility.cu b/src/cuda/cuda_utility.cu index e38910b9..83cee26c 100644 --- a/src/cuda/cuda_utility.cu +++ b/src/cuda/cuda_utility.cu @@ -16,7 +16,8 @@ __global__ void cudaPrintFloatImpl(float *x, int len) { namespace infini { void cudaPrintFloat(float *x, int len) { - cudaPrintFloatImpl<<<1, 1>>>(x, len); + cudaPrintFloatImpl + <<<1, 1, 0, CUDAStream::getCurrentStream()>>>(x, len); cudaDeviceSynchronize(); } diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 41200933..9dc43510 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -571,6 +571,10 @@ void init_graph_builder(py::module &m) { .def("get_perf_time", &Handler::get_perf_time, policy::automatic) .def("tune", &Handler::tune, policy::automatic) .def("run", &Handler::run, policy::automatic) +#ifdef USE_CUDA + .def("run_with_cudagraph", &Handler::run_with_cudagraph, + policy::automatic) +#endif .def("shape_infer", &Handler::shape_infer, policy::automatic) .def("change_shape", &Handler::change_shape, policy::automatic) .def("getDims", &Handler::getDims, policy::automatic) diff --git a/src/kernels/cuda/all_reduce.cc b/src/kernels/cuda/all_reduce.cc index 8b64d2ab..e77c98b6 100644 --- a/src/kernels/cuda/all_reduce.cc +++ b/src/kernels/cuda/all_reduce.cc @@ -28,9 +28,8 @@ class AllReduceNCCL : public CudaKernelWithoutConfig { ncclComm_t comm = dynamic_cast(context->getCommunicator()) .getNcclComm(); - // TODO: Using default stream 0 for now. - checkNcclError( - ncclAllReduce(input, output, count, ncclType, getRedOp(), comm, 0)); + checkNcclError(ncclAllReduce(input, output, count, ncclType, getRedOp(), + comm, CUDAStream::getCurrentStream())); } virtual ncclRedOp_t getRedOp() const = 0; diff --git a/src/kernels/cuda/attention_kvcache.cu b/src/kernels/cuda/attention_kvcache.cu index f169a4b1..476220db 100644 --- a/src/kernels/cuda/attention_kvcache.cu +++ b/src/kernels/cuda/attention_kvcache.cu @@ -2,7 +2,7 @@ #include "cuda/cuda_attention_kvcache.h" #define WARP_SIZE 32 #define BLOCKSIZE WARP_SIZE -#define SEQ_UNIT 32 +#define SEQ_UNIT 16 // ASSUME SEQ_LEN OF Q IS 1 __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache, @@ -103,7 +103,7 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache, ptr_O[i] /= ptr_sum[0]; (float4 &)output_O_temp[(lane_id * 4) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = (float4 &)ptr_O[0]; - if(threadIdx.x == 0){ + if(lane_id == 0){ output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0]; } @@ -157,13 +157,15 @@ void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache, dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), gridsize_y); dim3 blockDim(BLOCKSIZE, 1); - assert(compMeta.dimSize[3] == 128); - _attention_kvcache_kernel_128_1<<>>( - input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, + _attention_kvcache_kernel_128_1 + <<>> + (input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, compMeta, output_O_temp, output_sum_temp); - _attention_kvcache_kernel_128_2<<>>( - position_id, output_matmul, compMeta, output_O_temp, output_sum_temp); - + + _attention_kvcache_kernel_128_2 + <<>> + (position_id, output_matmul, compMeta, output_O_temp, output_sum_temp); } } // namespace infini diff --git a/src/kernels/cuda/clip.cu b/src/kernels/cuda/clip.cu index 7d3e97bd..85b096b0 100644 --- a/src/kernels/cuda/clip.cu +++ b/src/kernels/cuda/clip.cu @@ -25,8 +25,9 @@ void clip_kernel(float *input, float *output, int num, float minValue, float maxValue) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _clip_kernel<<>>(input, output, num, minValue, - maxValue); + _clip_kernel + <<>>( + input, output, num, minValue, maxValue); } }; // namespace infini diff --git a/src/kernels/cuda/element_wise.cu b/src/kernels/cuda/element_wise.cu index 98a12571..e1b68699 100644 --- a/src/kernels/cuda/element_wise.cu +++ b/src/kernels/cuda/element_wise.cu @@ -131,8 +131,9 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2, } #define CASE(OP, T) \ - _##OP##_kernel::t><<>>( \ - a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3); + _##OP##_kernel::t> \ + <<>> \ + (a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3); #define SWITCH_DTYPE(OP, DTYPE) \ switch (DTYPE) { \ @@ -202,11 +203,13 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2, int num = c0 * c1 * c2 * c3; int gridsize = (num + block_work_size() - 1) / block_work_size(); if (dType == 1) { - _pow_kernel<<>>(a, b, c, a0, a1, a2, a3, b0, - b1, b2, b3, c0, c1, c2, c3); + _pow_kernel + <<>> + (a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3); } else if (dType == 3) { - _pow_kernel<<>>( - a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3); + _pow_kernel + <<>> + (a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3); } else if (dType == 10) { int a_size = a0 * a1 * a2 * a3; int b_size = b0 * b1 * b2 * b3; @@ -220,8 +223,9 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2, for (int i = 0; i < b_size; ++i) { b_float[i] = __half2float(((half *)b)[i]); } - _pow_kernel<<>>( - a_float.data(), b_float.data(), c_float.data(), a0, a1, a2, a3, b0, + _pow_kernel + <<>> + (a_float.data(), b_float.data(), c_float.data(), a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3); for (int i = 0; i < c_size; ++i) { ((half *)c)[i] = __float2half(c_float[i]); diff --git a/src/kernels/cuda/expand.cu b/src/kernels/cuda/expand.cu index af92b9ce..5e22be44 100644 --- a/src/kernels/cuda/expand.cu +++ b/src/kernels/cuda/expand.cu @@ -42,7 +42,8 @@ __global__ void _expandKernel(void *input, void *output, int nDims, namespace infini { #define CASE(T) \ - _expandKernel::t><<>>( \ + _expandKernel::t><<>>( \ input, output, nDims, outputsize, inputShape, outputShape); #define SWITCH_DTYPE(DTYPE) \ diff --git a/src/kernels/cuda/extend.cu b/src/kernels/cuda/extend.cu index f6879105..3fce9922 100644 --- a/src/kernels/cuda/extend.cu +++ b/src/kernels/cuda/extend.cu @@ -19,7 +19,8 @@ void extend_kernel(float *in, float *out, int blockSize, int blockSizeOuter, int oSize) { int blocksize = 32 * 16; int gridsize = (oSize + blocksize - 1) / blocksize; - _extend_kernel<<>>(in, out, blockSize, blockSizeOuter, - oSize); + _extend_kernel + <<>>( + in, out, blockSize, blockSizeOuter, oSize); } } // namespace infini diff --git a/src/kernels/cuda/gather.cu b/src/kernels/cuda/gather.cu index c9dedd95..7b2d9dbf 100644 --- a/src/kernels/cuda/gather.cu +++ b/src/kernels/cuda/gather.cu @@ -45,9 +45,12 @@ void gather_kernel(T *in, T *out, GatherMetaData metaData, size_t num) { int gridSize = (num + blockSize - 1) / blockSize; if (metaData.indexType == DataType::Int64) { _gather_kernel - <<>>(in, out, metaData, num); + <<>> + (in, out, metaData, num); } else { - _gather_kernel<<>>(in, out, metaData, num); + _gather_kernel + <<>> + (in, out, metaData, num); } } template void gather_kernel(float *in, float *out, diff --git a/src/kernels/cuda/gather_elements.cu b/src/kernels/cuda/gather_elements.cu index 0b7817eb..545820b4 100644 --- a/src/kernels/cuda/gather_elements.cu +++ b/src/kernels/cuda/gather_elements.cu @@ -40,22 +40,26 @@ void gather_elements_kernel(void *in, void *out, GatherMetaData metaData, int gridSize = (num + blockSize - 1) / blockSize; if (metaData.dataType == DataType::Float32 && metaData.indexType == DataType::Int64) { - _gather_elements_kernel<<>>( + _gather_elements_kernel + <<>>( reinterpret_cast(in), reinterpret_cast(out), metaData, num); } else if (metaData.dataType == DataType::Int32 && metaData.indexType == DataType::Int64) { - _gather_elements_kernel<<>>( + _gather_elements_kernel + <<>>( reinterpret_cast(in), reinterpret_cast(out), metaData, num); } else if (metaData.dataType == DataType::Float32 && metaData.indexType == DataType::Int32) { - _gather_elements_kernel<<>>( + _gather_elements_kernel + <<>>( reinterpret_cast(in), reinterpret_cast(out), metaData, num); } else if (metaData.dataType == DataType::Int32 && metaData.indexType == DataType::Int32) { - _gather_elements_kernel<<>>( + _gather_elements_kernel + <<>>( reinterpret_cast(in), reinterpret_cast(out), metaData, num); } else { diff --git a/src/kernels/cuda/layer_norm.cu b/src/kernels/cuda/layer_norm.cu index 26f06e28..b3d74c77 100644 --- a/src/kernels/cuda/layer_norm.cu +++ b/src/kernels/cuda/layer_norm.cu @@ -344,8 +344,8 @@ void LaynormKernel(const float *input, const float *scale, const float eps, int BLOCK_DIM = 1024; blockLaynormKernel - <<>>(input, scale, dimsize, stride, output, - eps, scaleSize, bias, biasSize); + <<>> + (input, scale, dimsize, stride, output, eps, scaleSize, bias, biasSize); } else if (dimsize > 31) { int BLOCK_DIM_x = 32; int BLOCK_DIM_y = 32; @@ -353,9 +353,10 @@ void LaynormKernel(const float *input, const float *scale, const float eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<<>>( - input, scale, dimsize, stride, output, eps, scaleSize, num_block, - bias, biasSize); + warpLaynormKernel + <<>> + (input, scale, dimsize, stride, output, eps, scaleSize, num_block, + bias, biasSize); } else if (dimsize > 15) { int BLOCK_DIM_x = 16; int BLOCK_DIM_y = 64; @@ -363,8 +364,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<<>>( - input, scale, dimsize, stride, output, eps, scaleSize, num_block, + warpLaynormKernel + <<>> + (input, scale, dimsize, stride, output, eps, scaleSize, num_block, bias, biasSize); } else if (dimsize > 7) { int BLOCK_DIM_x = 8; @@ -373,8 +375,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<<>>( - input, scale, dimsize, stride, output, eps, scaleSize, num_block, + warpLaynormKernel + <<>> + (input, scale, dimsize, stride, output, eps, scaleSize, num_block, bias, biasSize); } else { int BLOCK_DIM_x = 4; @@ -383,8 +386,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<<>>( - input, scale, dimsize, stride, output, eps, scaleSize, num_block, + warpLaynormKernel + <<>> + (input, scale, dimsize, stride, output, eps, scaleSize, num_block, bias, biasSize); } } @@ -396,8 +400,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps, if (dimsize > 1024) { int BLOCK_DIM = 1024; - blockLaynormKernel<<>>( - input, scale, dimsize, stride, output, eps, scaleSize); + blockLaynormKernel + <<>> + (input, scale, dimsize, stride, output, eps, scaleSize); } else if (dimsize > 31) { int BLOCK_DIM_x = 32; int BLOCK_DIM_y = 32; @@ -405,8 +410,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<<>>( - input, scale, dimsize, stride, output, eps, scaleSize, num_block); + warpLaynormKernel + <<>> + (input, scale, dimsize, stride, output, eps, scaleSize, num_block); } else if (dimsize > 15) { int BLOCK_DIM_x = 16; int BLOCK_DIM_y = 64; @@ -414,8 +420,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<<>>( - input, scale, dimsize, stride, output, eps, scaleSize, num_block); + warpLaynormKernel + <<>> + (input, scale, dimsize, stride, output, eps, scaleSize, num_block); } else if (dimsize > 7) { int BLOCK_DIM_x = 8; int BLOCK_DIM_y = 128; @@ -423,8 +430,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<<>>( - input, scale, dimsize, stride, output, eps, scaleSize, num_block); + warpLaynormKernel + <<>> + (input, scale, dimsize, stride, output, eps, scaleSize, num_block); } else { int BLOCK_DIM_x = 4; int BLOCK_DIM_y = 256; @@ -432,8 +440,9 @@ void LaynormKernel(const float *input, const float *scale, const float eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<<>>( - input, scale, dimsize, stride, output, eps, scaleSize, num_block); + warpLaynormKernel + <<>> + (input, scale, dimsize, stride, output, eps, scaleSize, num_block); } } //----------------- @@ -445,8 +454,8 @@ void LaynormKernel(const half *input, const half *scale, const half eps, int BLOCK_DIM = 1024; blockLaynormKernel - <<>>(input, scale, dimsize, stride, output, - eps, scaleSize, bias, biasSize); + <<>> + (input, scale, dimsize, stride, output, eps, scaleSize, bias, biasSize); } else if (dimsize > 31) { int BLOCK_DIM_x = 32; int BLOCK_DIM_y = 32; @@ -454,8 +463,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<<>>( - input, scale, dimsize, stride, output, eps, scaleSize, num_block, + warpLaynormKernel + <<>> + (input, scale, dimsize, stride, output, eps, scaleSize, num_block, bias, biasSize); } else if (dimsize > 15) { int BLOCK_DIM_x = 16; @@ -464,8 +474,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<<>>( - input, scale, dimsize, stride, output, eps, scaleSize, num_block, + warpLaynormKernel + <<>> + (input, scale, dimsize, stride, output, eps, scaleSize, num_block, bias, biasSize); } else if (dimsize > 7) { int BLOCK_DIM_x = 8; @@ -474,8 +485,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<<>>( - input, scale, dimsize, stride, output, eps, scaleSize, num_block, + warpLaynormKernel + <<>> + (input, scale, dimsize, stride, output, eps, scaleSize, num_block, bias, biasSize); } else { int BLOCK_DIM_x = 4; @@ -484,8 +496,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<<>>( - input, scale, dimsize, stride, output, eps, scaleSize, num_block, + warpLaynormKernel + <<>> + (input, scale, dimsize, stride, output, eps, scaleSize, num_block, bias, biasSize); } } @@ -497,8 +510,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps, if (dimsize > 1024) { int BLOCK_DIM = 1024; - blockLaynormKernel<<>>( - input, scale, dimsize, stride, output, eps, scaleSize); + blockLaynormKernel + <<>> + (input, scale, dimsize, stride, output, eps, scaleSize); } else if (dimsize > 31) { int BLOCK_DIM_x = 32; int BLOCK_DIM_y = 32; @@ -506,8 +520,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<<>>( - input, scale, dimsize, stride, output, eps, scaleSize, num_block); + warpLaynormKernel + <<>> + (input, scale, dimsize, stride, output, eps, scaleSize, num_block); } else if (dimsize > 15) { int BLOCK_DIM_x = 16; int BLOCK_DIM_y = 64; @@ -515,8 +530,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<<>>( - input, scale, dimsize, stride, output, eps, scaleSize, num_block); + warpLaynormKernel + <<>> + (input, scale, dimsize, stride, output, eps, scaleSize, num_block); } else if (dimsize > 7) { int BLOCK_DIM_x = 8; int BLOCK_DIM_y = 128; @@ -524,8 +540,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<<>>( - input, scale, dimsize, stride, output, eps, scaleSize, num_block); + warpLaynormKernel + <<>> + (input, scale, dimsize, stride, output, eps, scaleSize, num_block); } else { int BLOCK_DIM_x = 4; int BLOCK_DIM_y = 256; @@ -533,8 +550,9 @@ void LaynormKernel(const half *input, const half *scale, const half eps, dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, 1, 1); - warpLaynormKernel<<>>( - input, scale, dimsize, stride, output, eps, scaleSize, num_block); + warpLaynormKernel + <<>> + (input, scale, dimsize, stride, output, eps, scaleSize, num_block); } } } // namespace infini diff --git a/src/kernels/cuda/pad_slice.cu b/src/kernels/cuda/pad_slice.cu index ccf85748..331f8e0d 100644 --- a/src/kernels/cuda/pad_slice.cu +++ b/src/kernels/cuda/pad_slice.cu @@ -48,8 +48,9 @@ __global__ void _pad_slice_kernel(void *part, void *whole, namespace infini { #define CASE(T) \ - _pad_slice_kernel::t><<>>( \ - partData, wholeData, metadata, nDims, num, isPad); + _pad_slice_kernel::t> \ + <<>> \ + (partData, wholeData, metadata, nDims, num, isPad); #define SWITCH_DTYPE(DTYPE) \ switch (DTYPE) { \ diff --git a/src/kernels/cuda/reshape.cc b/src/kernels/cuda/reshape.cc index 450105b0..bbce222c 100644 --- a/src/kernels/cuda/reshape.cc +++ b/src/kernels/cuda/reshape.cc @@ -7,7 +7,8 @@ class CopyCuda : public CudaKernelWithoutConfig { auto inData = op->getInputs(0)->getRawDataPtr(); auto outData = op->getOutputs()[0]->getRawDataPtr(); cudaMemcpyAsync(outData, inData, op->getInputs(0)->getBytes(), - cudaMemcpyDeviceToDevice); + cudaMemcpyDeviceToDevice, + CUDAStream::getCurrentStream()); } }; // reshape/flatten/identity all act as copying from input to output. diff --git a/src/kernels/cuda/resize.cu b/src/kernels/cuda/resize.cu index 3f985dde..947ee7ce 100644 --- a/src/kernels/cuda/resize.cu +++ b/src/kernels/cuda/resize.cu @@ -213,8 +213,9 @@ void resize_kernel_nearest(float *in, float *out, const MetaData &metaData, sizeof(p_cooridnate_trans_mode_func[0])); IT_ASSERT(nearestMode < sizeof(p_nearest_mode_fun) / sizeof(p_nearest_mode_fun[0])); - _resize_kernel_nearest<<>>( - in, out, metaData, num, coordinateMode, nearestMode); + _resize_kernel_nearest + <<>> + (in, out, metaData, num, coordinateMode, nearestMode); } void resize_kernel_linear(float *in, float *out, const MetaData &metaData, @@ -223,8 +224,9 @@ void resize_kernel_linear(float *in, float *out, const MetaData &metaData, auto gridsize = (num + blocksize - 1) / blocksize; IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) / sizeof(p_cooridnate_trans_mode_func[0])); - _resize_kernel_linear_coeff<<>>(in, out, metaData, num, - coordinateMode); + _resize_kernel_linear_coeff + <<>> + (in, out, metaData, num, coordinateMode); } void resize_kernel_cubic(float *in, float *out, const MetaData &metaData, @@ -233,7 +235,8 @@ void resize_kernel_cubic(float *in, float *out, const MetaData &metaData, auto gridsize = (num + blocksize - 1) / blocksize; IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) / sizeof(p_cooridnate_trans_mode_func[0])); - _resize_kernel_cubic_coeff<<>>(in, out, metaData, num, - coordinateMode); + _resize_kernel_cubic_coeff + <<>> + (in, out, metaData, num, coordinateMode); } } // namespace infini diff --git a/src/kernels/cuda/rope.cu b/src/kernels/cuda/rope.cu index 9b1bec54..8d35026f 100644 --- a/src/kernels/cuda/rope.cu +++ b/src/kernels/cuda/rope.cu @@ -9,7 +9,8 @@ constexpr int block_work_size() { return thread_work_size() * num_threads(); } // gridDim (batch, seq_len, dim_model / 1024), blockDim (1024, 1, 1) template -__global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model, int dim_head, int hidden_stride, int pos_stride) { +__global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model, + int dim_head, int hidden_stride, int pos_stride) { int batch_id = blockIdx.x; int target_pos = pos[batch_id * pos_stride + blockIdx.y]; int ith = blockIdx.z * blockDim.x + threadIdx.x; @@ -36,8 +37,9 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo #define CASE(T) \ - _rope_kernel::t><<>>( \ - pos, input, output, size, dim_model, dim_head, hidden_stride, pos_stride); + _rope_kernel::t> \ + <<>> \ + (pos, input, output, size, dim_model, dim_head, hidden_stride, pos_stride); #define SWITCH_DTYPE(DTYPE) \ switch (DTYPE) { \ @@ -82,7 +84,8 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo } namespace infini { -void rope_kernel(int dType, int * pos, void *input, void *output, int size, int dim_model, int dim_head, int hidden_stride, int pos_stride) { +void rope_kernel(int dType, int * pos, void *input, void *output, int size, + int dim_model, int dim_head, int hidden_stride, int pos_stride) { dim3 blocksize = dim3(1024,1,1); dim3 gridsize = dim3(1, 1, 4); SWITCH_DTYPE(dType) diff --git a/src/kernels/cuda/softmax.cu b/src/kernels/cuda/softmax.cu index 69334d50..2fe2d8a5 100644 --- a/src/kernels/cuda/softmax.cu +++ b/src/kernels/cuda/softmax.cu @@ -246,32 +246,38 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size, int BLOCK_DIM = 1024; _blockSoftmaxKernel - <<>>(input, output, size, dimsize, stride); + <<>> + (input, output, size, dimsize, stride); } else if (dimsize > 1024 * 64) { int BLOCK_DIM = 1024; _blockSoftmaxKernel - <<>>(input, output, size, dimsize, stride); + <<>> + (input, output, size, dimsize, stride); } else if (dimsize > 1024 * 32) { int BLOCK_DIM = 1024; _blockSoftmaxKernel - <<>>(input, output, size, dimsize, stride); + <<>> + (input, output, size, dimsize, stride); } else if (dimsize > 1024 * 16) { int BLOCK_DIM = 1024; _blockSoftmaxKernel - <<>>(input, output, size, dimsize, stride); + <<>> + (input, output, size, dimsize, stride); } else if (dimsize > 1024 * 4) { int BLOCK_DIM = 1024; _blockSoftmaxKernel - <<>>(input, output, size, dimsize, stride); + <<>> + (input, output, size, dimsize, stride); } else if (dimsize > 1024) { int BLOCK_DIM = 1024; _blockSoftmaxKernel - <<>>(input, output, size, dimsize, stride); + <<>> + (input, output, size, dimsize, stride); } else if (dimsize > 31) { int BLOCK_DIM_x = 32; int BLOCK_DIM_y = 32; @@ -280,7 +286,8 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size, dim3 grid_dim(num_block_x, 1, 1); _warpSoftmaxKernel - <<>>(input, output, size, dimsize, stride); + <<>> + (input, output, size, dimsize, stride); } else if (dimsize > 15) { int BLOCK_DIM_x = 16; int BLOCK_DIM_y = 64; @@ -289,7 +296,8 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size, dim3 grid_dim(num_block_x, 1, 1); _warpSoftmaxKernel - <<>>(input, output, size, dimsize, stride); + <<>> + (input, output, size, dimsize, stride); } else if (dimsize > 7) { int BLOCK_DIM_x = 8; int BLOCK_DIM_y = 128; @@ -298,7 +306,8 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size, dim3 grid_dim(num_block_x, 1, 1); _warpSoftmaxKernel - <<>>(input, output, size, dimsize, stride); + <<>> + (input, output, size, dimsize, stride); } else { int BLOCK_DIM_x = 4; int BLOCK_DIM_y = 256; @@ -307,7 +316,8 @@ void softmax_kernel(int num_blocks, float *input, float *output, int size, dim3 grid_dim(num_block_x, 1, 1); _warpSoftmaxKernel - <<>>(input, output, size, dimsize, stride); + <<>> + (input, output, size, dimsize, stride); } } //------------------ @@ -318,32 +328,38 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size, int BLOCK_DIM = 1024; _blockSoftmaxKernel - <<>>(input, output, size, dimsize, stride); + <<>> + (input, output, size, dimsize, stride); } else if (dimsize > 1024 * 64) { int BLOCK_DIM = 1024; _blockSoftmaxKernel - <<>>(input, output, size, dimsize, stride); + <<>> + (input, output, size, dimsize, stride); } else if (dimsize > 1024 * 32) { int BLOCK_DIM = 1024; _blockSoftmaxKernel - <<>>(input, output, size, dimsize, stride); + <<>> + (input, output, size, dimsize, stride); } else if (dimsize > 1024 * 16) { int BLOCK_DIM = 1024; _blockSoftmaxKernel - <<>>(input, output, size, dimsize, stride); + <<>> + (input, output, size, dimsize, stride); } else if (dimsize > 1024 * 4) { int BLOCK_DIM = 1024; _blockSoftmaxKernel - <<>>(input, output, size, dimsize, stride); + <<>> + (input, output, size, dimsize, stride); } else if (dimsize > 1024) { int BLOCK_DIM = 1024; _blockSoftmaxKernel - <<>>(input, output, size, dimsize, stride); + <<>> + (input, output, size, dimsize, stride); } else if (dimsize > 31) { int BLOCK_DIM_x = 32; int BLOCK_DIM_y = 32; @@ -352,7 +368,8 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size, dim3 grid_dim(num_block_x, 1, 1); _warpSoftmaxKernel - <<>>(input, output, size, dimsize, stride); + <<>> + (input, output, size, dimsize, stride); } else if (dimsize > 15) { int BLOCK_DIM_x = 16; int BLOCK_DIM_y = 64; @@ -361,7 +378,8 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size, dim3 grid_dim(num_block_x, 1, 1); _warpSoftmaxKernel - <<>>(input, output, size, dimsize, stride); + <<>> + (input, output, size, dimsize, stride); } else if (dimsize > 7) { int BLOCK_DIM_x = 8; int BLOCK_DIM_y = 128; @@ -370,7 +388,8 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size, dim3 grid_dim(num_block_x, 1, 1); _warpSoftmaxKernel - <<>>(input, output, size, dimsize, stride); + <<>> + (input, output, size, dimsize, stride); } else { int BLOCK_DIM_x = 4; int BLOCK_DIM_y = 256; @@ -379,7 +398,8 @@ void softmax_kernel(int num_blocks, half *input, half *output, int size, dim3 grid_dim(num_block_x, 1, 1); _warpSoftmaxKernel - <<>>(input, output, size, dimsize, stride); + <<>> + (input, output, size, dimsize, stride); } } } // namespace infini diff --git a/src/kernels/cuda/split_concat.cc b/src/kernels/cuda/split_concat.cc index e06ef731..df9dadfe 100644 --- a/src/kernels/cuda/split_concat.cc +++ b/src/kernels/cuda/split_concat.cc @@ -70,7 +70,8 @@ class ConcatCuda : private CudaCompute, public CudaKernelWithoutConfig { _op->getOutputs()[0]->getRawDataPtr(); cudaMemcpyAsync(outData, inData, _op->getInputs(1 - i)->getBytes(), - cudaMemcpyDeviceToDevice); + cudaMemcpyDeviceToDevice, + CUDAStream::getCurrentStream()); return; } } diff --git a/src/kernels/cuda/split_concat.cu b/src/kernels/cuda/split_concat.cu index fdb5f18c..f14dc973 100644 --- a/src/kernels/cuda/split_concat.cu +++ b/src/kernels/cuda/split_concat.cu @@ -63,8 +63,9 @@ void split_concat_kernel(const ElementTensorMetadata &eleMeta, // each y is a split among the batch dim3 gridSize(gridDimX, batchSize); - _split_concat_kernel<<>>(eleMeta, compMeta, dim, nDims, - isSplit); + _split_concat_kernel + <<>> + (eleMeta, compMeta, dim, nDims, isSplit); } void split_concat_kernel(const ElementTensorMetadata &eleMeta, const ComposedTensorMetadata &compMeta, int dim, @@ -77,8 +78,9 @@ void split_concat_kernel(const ElementTensorMetadata &eleMeta, // each y is a split among the batch dim3 gridSize(gridDimX, batchSize); - _split_concat_kernel<<>>(eleMeta, compMeta, dim, nDims, - isSplit); + _split_concat_kernel + <<>> + (eleMeta, compMeta, dim, nDims, isSplit); } } // namespace infini diff --git a/src/kernels/cuda/transpose.cu b/src/kernels/cuda/transpose.cu index 917afde3..833c1154 100644 --- a/src/kernels/cuda/transpose.cu +++ b/src/kernels/cuda/transpose.cu @@ -23,8 +23,9 @@ __global__ void _transpose_kernel(void *input, void *output, int nDims, } } #define CASE(T) \ - _transpose_kernel::t><<>>( \ - input, output, nDims, size, strides, outputShape); + _transpose_kernel::t> \ + <<>> \ + (input, output, nDims, size, strides, outputShape); #define SWITCH_DTYPE(DTYPE) \ switch (DTYPE) { \ diff --git a/src/kernels/cuda/unary.cu b/src/kernels/cuda/unary.cu index 98f1ed9f..93a3cf6c 100644 --- a/src/kernels/cuda/unary.cu +++ b/src/kernels/cuda/unary.cu @@ -148,78 +148,104 @@ template void softmax_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _softmax_kernel1<<<1, 1>>>(input, output, num); - _softmax_kernel2<<>>(input, output, num); + _softmax_kernel1 + <<<1, 1, 0, CUDAStream::getCurrentStream()>>> + (input, output, num); + _softmax_kernel2 + <<>> + (input, output, num); } template void relu_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _relu_kernel<<>>(input, output, num); + _relu_kernel + <<>> + (input, output, num); } template void sigmoid_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _sigmoid_kernel<<>>(input, output, num); + _sigmoid_kernel + <<>> + (input, output, num); } template void hard_sigmoid_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _hard_sigmoid_kernel<<>>(input, output, num); + _hard_sigmoid_kernel + <<>> + (input, output, num); } template void hard_swish_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _hard_swish_kernel<<>>(input, output, num); + _hard_swish_kernel + <<>> + (input, output, num); } template void tanh_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _tanh_kernel<<>>(input, output, num); + _tanh_kernel + <<>> + (input, output, num); } template void abs_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _abs_kernel<<>>(input, output, num); + _abs_kernel + <<>> + (input, output, num); } template void sqrt_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _sqrt_kernel<<>>((T *)input, (T *)output, num); + _sqrt_kernel + <<>> + ((T *)input, (T *)output, num); } template void gelu_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _gelu_kernel<<>>(input, output, num); + _gelu_kernel + <<>> + (input, output, num); } template void silu_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _silu_kernel<<>>(input, output, num); + _silu_kernel + <<>> + (input, output, num); } template void erf_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _erf_kernel<<>>(input, output, num); + _erf_kernel + <<>> + (input, output, num); } template void neg_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _neg_kernel<<>>(input, output, num); + _neg_kernel + <<>> + (input, output, num); } void unary_kernel(const Operator &_op) { @@ -317,7 +343,9 @@ void cast_kernel(INPUT *input, OUTPUT *output, size_t num) { int blocksize = block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size(); - _cast_kernel<<>>(input, output, num); + _cast_kernel + <<>> + (input, output, num); } template void cast_kernel(float *input, half *output, size_t num); diff --git a/src/kernels/cuda/where.cu b/src/kernels/cuda/where.cu index e92a5e9f..f3c2459f 100644 --- a/src/kernels/cuda/where.cu +++ b/src/kernels/cuda/where.cu @@ -61,7 +61,8 @@ void whereKernel(const float *inputX, const float *inputY, blocksize = 32; } int gridsize = (outputsize + blocksize - 1) / blocksize; - _whereKernel<<>>( + _whereKernel + <<>>( inputX, inputY, condition, output, nDims, outputsize, inputXShape, inputYShape, conditionShape, outputShape, xSize, ySize, cSize); } @@ -85,7 +86,8 @@ void whereKernel(const half *inputX, const half *inputY, blocksize = 32; } int gridsize = (outputsize + blocksize - 1) / blocksize; - _whereKernel<<>>( + _whereKernel + <<>>( inputX, inputY, condition, output, nDims, outputsize, inputXShape, inputYShape, conditionShape, outputShape, xSize, ySize, cSize); } diff --git a/test/cuda/test_cudagraph.cc b/test/cuda/test_cudagraph.cc new file mode 100644 index 00000000..8f2ac3d0 --- /dev/null +++ b/test/cuda/test_cudagraph.cc @@ -0,0 +1,70 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/attention_kvcache.h" + +#include "test.h" + +namespace infini { + +TEST(TestCudaRuntime, CudaGraph) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + + Graph gCpu = make_ref(runtime); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto input_k_cache_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32); + auto input_v_cache_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32); + auto input_q_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32); + auto input_k_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32); + auto input_v_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32); + auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32); + + auto op = gCuda->addOp( + input_k_cache_d, input_v_cache_d, input_q_d, input_k_d, input_v_d, + position_id_d, nullptr); + auto op1 = gCuda->addOp( + input_k_cache_d, input_v_cache_d, op->getOutputs()[0], input_k_d, + input_v_d, position_id_d, nullptr); + auto op2 = gCuda->addOp( + input_k_cache_d, input_v_cache_d, op1->getOutputs()[0], input_k_d, + input_v_d, position_id_d, nullptr); + gCuda->dataMalloc(); + + input_q_d->setData(OneGenerator()); + input_k_d->setData(OneGenerator()); + input_v_d->setData(OneGenerator()); + position_id_d->setData(IncrementalGenerator()); + + cudaRuntime->run(gCuda); + + cudaEvent_t start, stop; + float milliseconds_1 = 0, milliseconds_2 = 0; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + cudaDeviceSynchronize(); + cudaEventRecord(start); + cudaRuntime->run(gCuda); + cudaEventRecord(stop); + cudaEventSynchronize(stop); + cudaEventElapsedTime(&milliseconds_1, start, stop); + printf("without cudaGraph, latency: %f ms\n", milliseconds_1); + + cudaRuntime->runWithCudaGraph(gCuda); + cudaRuntime->runWithCudaGraph(gCuda); + + cudaDeviceSynchronize(); + cudaEventRecord(start); + cudaRuntime->runWithCudaGraph(gCuda); + cudaEventRecord(stop); + cudaEventSynchronize(stop); + cudaEventElapsedTime(&milliseconds_2, start, stop); + printf("with cudaGraph, latency: %f ms\n", milliseconds_2); + EXPECT_GE(milliseconds_1, milliseconds_2); +} + +} // namespace infini