From b640ab1689e4a9066ae7ea2e620fd5fa0f63ad2d Mon Sep 17 00:00:00 2001
From: xgqdut2016
Date: Tue, 26 Sep 2023 14:53:02 +0800
Subject: [PATCH] modified attention.cu, BLOCK_DIM_x must be <= 32

---
 src/kernels/cuda/attention.cu | 134 +++++++++++++++++++++-------------
 1 file changed, 84 insertions(+), 50 deletions(-)

diff --git a/src/kernels/cuda/attention.cu b/src/kernels/cuda/attention.cu
index d66a87c9..be958522 100644
--- a/src/kernels/cuda/attention.cu
+++ b/src/kernels/cuda/attention.cu
@@ -1,7 +1,7 @@
 #include "cuda/cuda_common.h"
-#define BLOCK_DIM_x 2
-#define BLOCK_DIM_y 2
+#define BLOCK_DIM_x 8 // BLOCK_DIM_x must be <= 32
+#define BLOCK_DIM_y 128
 
 #define max_function(a, b) ((a) > (b) ? (a) : (b))
 
 __global__ void _attentionKernel(const float *inputQ, const float *inputK,
@@ -10,84 +10,117 @@ __global__ void _attentionKernel(const float *inputQ, const float *inputK,
     int i = blockIdx.x;                              // i must be < N, Q[i]
     int phd = threadIdx.y + blockIdx.y * blockDim.y; // V[:,d]
     int phNumN = (N + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
-    __shared__ float old_max;
-    __shared__ float new_max;
-    __shared__ float new_sum;
-    old_max = -__FLT_MAX__;
-    new_max = -__FLT_MAX__;
-    new_sum = 0.0f;
-    __shared__ float block_sum[BLOCK_DIM_x];
-    __shared__ float block_max[BLOCK_DIM_x];
-    block_max[threadIdx.x] = -__FLT_MAX__;
-    block_sum[threadIdx.x] = 0.0f;
+    __shared__ float old_max[BLOCK_DIM_x][BLOCK_DIM_y];
+    __shared__ float new_max[BLOCK_DIM_x][BLOCK_DIM_y];
+    __shared__ float new_sum[BLOCK_DIM_x][BLOCK_DIM_y];
+    old_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
+    new_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
+    new_sum[threadIdx.x][threadIdx.y] = 0.0f;
+    __shared__ float block_sum[BLOCK_DIM_x][BLOCK_DIM_y];
+    __shared__ float block_max[BLOCK_DIM_x][BLOCK_DIM_y];
+    block_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
+    block_sum[threadIdx.x][threadIdx.y] = 0.0f;
 
-    __shared__ float inputS[BLOCK_DIM_x];
+    __shared__ float inputS[BLOCK_DIM_x][BLOCK_DIM_y];
 
-    output[i * d + phd] = 0.0f;
+    __syncthreads();
     for (int phn = 0; phn < phNumN; phn++) {
         int j = threadIdx.x + phn * BLOCK_DIM_x;
-        if (j < N) {
+        inputS[threadIdx.x][threadIdx.y] = 0.0f;
+        block_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
+        block_sum[threadIdx.x][threadIdx.y] = 0.0f;
+
+        if (j < N && phd < d) {
             float sum_s = 0;
             for (int index = 0; index < d; index++) {
                 sum_s += inputQ[i * d + index] * inputK[j * d + index];
             }
-            inputS[threadIdx.x] = sum_s;
-            block_max[threadIdx.x] = sum_s;
-            block_sum[threadIdx.x] = 1.0f;
-        } else {
-            inputS[threadIdx.x] = 0.0f;
-            block_max[threadIdx.x] = -__FLT_MAX__;
-            block_sum[threadIdx.x] = 0.0f;
+            inputS[threadIdx.x][threadIdx.y] = sum_s;
+            block_max[threadIdx.x][threadIdx.y] = sum_s;
+            block_sum[threadIdx.x][threadIdx.y] = 1.0f;
         }
+        __syncthreads();
         for (int strip = BLOCK_DIM_x / 2; strip > 0; strip = strip / 2) {
             if (threadIdx.x < strip) {
-                if (block_max[threadIdx.x] > block_max[threadIdx.x + strip]) {
-                    block_sum[threadIdx.x] =
-                        block_sum[threadIdx.x] +
-                        block_sum[threadIdx.x + strip] *
-                            __expf(block_max[threadIdx.x + strip] -
-                                   block_max[threadIdx.x]);
+                if (block_max[threadIdx.x][threadIdx.y] >
+                    block_max[threadIdx.x + strip][threadIdx.y]) {
+                    block_sum[threadIdx.x][threadIdx.y] =
+                        block_sum[threadIdx.x][threadIdx.y] +
+                        block_sum[threadIdx.x + strip][threadIdx.y] *
+                            __expf(block_max[threadIdx.x + strip][threadIdx.y] -
+                                   block_max[threadIdx.x][threadIdx.y]);
                 } else {
-                    block_sum[threadIdx.x] =
-                        block_sum[threadIdx.x + strip] +
-                        block_sum[threadIdx.x] *
-                            __expf(block_max[threadIdx.x] -
-                                   block_max[threadIdx.x + strip]);
-                    block_max[threadIdx.x] = block_max[threadIdx.x + strip];
+                    block_sum[threadIdx.x][threadIdx.y] =
+                        block_sum[threadIdx.x + strip][threadIdx.y] +
+                        block_sum[threadIdx.x][threadIdx.y] *
+                            __expf(block_max[threadIdx.x][threadIdx.y] -
+                                   block_max[threadIdx.x + strip][threadIdx.y]);
+                    block_max[threadIdx.x][threadIdx.y] =
+                        block_max[threadIdx.x + strip][threadIdx.y];
                 }
             }
+            __syncthreads();
         }
         __syncthreads();
-        if (threadIdx.x == 0) {
-            if (new_max > block_max[0]) {
-                new_sum =
-                    new_sum + block_sum[0] * __expf(block_max[0] - new_max);
+        if (j < N && phd < d) {
+            if (new_max[threadIdx.x][threadIdx.y] > block_max[0][threadIdx.y]) {
+                new_sum[threadIdx.x][threadIdx.y] =
+                    new_sum[threadIdx.x][threadIdx.y] +
+                    block_sum[0][threadIdx.y] *
+                        __expf(block_max[0][threadIdx.y] -
+                               new_max[threadIdx.x][threadIdx.y]);
             } else {
-                new_sum =
-                    block_sum[0] + new_sum * __expf(new_max - block_max[0]);
-                new_max = block_max[0];
+                new_sum[threadIdx.x][threadIdx.y] =
+                    block_sum[0][threadIdx.y] +
+                    new_sum[threadIdx.x][threadIdx.y] *
+                        __expf(new_max[threadIdx.x][threadIdx.y] -
+                               block_max[0][threadIdx.y]);
+                new_max[threadIdx.x][threadIdx.y] = block_max[0][threadIdx.y];
             }
         }
+        __syncthreads();
-        inputS[threadIdx.x] = __expf(inputS[threadIdx.x] - new_max);
+
+        if (j < N && phd < d) {
+            inputS[threadIdx.x][threadIdx.y] =
+                __expf(inputS[threadIdx.x][threadIdx.y] -
+                       new_max[threadIdx.x][threadIdx.y]);
+        } else {
+            inputS[threadIdx.x][threadIdx.y] = 0.0f;
+        }
         __syncthreads();
-        float sum_o = 0;
+        if (phd < d) {
+            float sum_o = 0.0f;
             for (int index = 0; index < BLOCK_DIM_x; index++) {
                 if (index + phn * BLOCK_DIM_x < N) {
-                    sum_o += inputS[index] *
+                    sum_o += inputS[index][threadIdx.y] *
                              inputV[(index + phn * BLOCK_DIM_x) * d + phd];
                 }
             }
-        output[i * d + phd] =
-            __expf(old_max - new_max) * output[i * d + phd] + sum_o;
-        old_max = new_max;
+            if (phn == 0) {
+                output[i * d + phd] = sum_o;
+            } else {
+                output[i * d + phd] =
+                    __expf(old_max[threadIdx.x][threadIdx.y] -
+                           new_max[threadIdx.x][threadIdx.y]) *
+                        output[i * d + phd] +
+                    sum_o;
+            }
+
+            old_max[threadIdx.x][threadIdx.y] =
+                new_max[threadIdx.x][threadIdx.y];
+        } else {
+            old_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
+        }
-        //__syncthreads();
+        __syncthreads();
     }
+    __syncthreads();
     if (phd < d)
-        output[i * d + phd] = output[i * d + phd] * __fdividef(1.0F, new_sum);
+        output[i * d + phd] =
+            output[i * d + phd] *
+            __fdividef(1.0F, new_sum[threadIdx.x][threadIdx.y]);
 }
 namespace infini {
 void attentionKernel(const float *inputQ, const float *inputK,
@@ -97,7 +130,8 @@ void attentionKernel(const float *inputQ, const float *inputK,
     int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
     dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
     dim3 grid_dim(num_block_x, num_block_y, 1);
-    int share_mem = (3 * BLOCK_DIM_x + 3) * sizeof(float);
+    int share_mem =
+        (3 * BLOCK_DIM_x + 3 * BLOCK_DIM_x) * BLOCK_DIM_y * sizeof(float);
     _attentionKernel<<<grid_dim, block_dim, share_mem>>>(inputQ, inputK,
                                                          inputV, N, d, output);
 }
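
Note: both the strip reduction and the old_max/new_max/new_sum updates in the
kernel apply the same online-softmax merge rule. A partial state (m, s), where
m is the running max of the logits seen so far and s is the sum of
expf(logit - m) over those logits, can be merged pairwise in any order; each
scalar logit enters as the state (logit, 1.0f), which is why block_sum is
initialized to 1.0f for in-range threads. A minimal standalone sketch of that
rule follows; it is not part of the patch, and the names and test values are
illustrative:

#include <cmath>
#include <cstdio>

struct SoftmaxState {
    float m; // running max of the logits seen so far
    float s; // sum of expf(logit - m) over those logits
};

// Merge two partial states; the branch mirrors the kernel's
// "if (block_max[...] > block_max[... + strip])" reduction step.
SoftmaxState merge(SoftmaxState a, SoftmaxState b) {
    if (a.m > b.m) {
        return {a.m, a.s + b.s * expf(b.m - a.m)};
    }
    return {b.m, b.s + a.s * expf(a.m - b.m)};
}

int main() {
    // Logits {1, 3, 2, 5} processed as two chunks, then merged online.
    SoftmaxState left = merge({1.0f, 1.0f}, {3.0f, 1.0f});
    SoftmaxState right = merge({2.0f, 1.0f}, {5.0f, 1.0f});
    SoftmaxState all = merge(left, right);
    // all.m == 5 and all.s == exp(1-5) + exp(2-5) + exp(3-5) + exp(5-5),
    // the same state a single pass over all four logits would produce.
    printf("max = %f, sum = %f\n", all.m, all.s);
    return 0;
}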
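
Note: a hypothetical host-side driver for exercising the launcher. The sizes
(N = 1000, d = 512) and the fill value are made up for illustration; only the
infini::attentionKernel signature comes from the patch:

#include <cuda_runtime.h>
#include <vector>

namespace infini {
void attentionKernel(const float *inputQ, const float *inputK,
                     const float *inputV, int N, int d, float *output);
}

int main() {
    const int N = 1000, d = 512; // assumed sequence length and feature dim
    const size_t bytes = (size_t)N * d * sizeof(float);
    std::vector<float> host(N * d, 0.01f); // arbitrary input data

    float *Q, *K, *V, *O;
    cudaMalloc(&Q, bytes);
    cudaMalloc(&K, bytes);
    cudaMalloc(&V, bytes);
    cudaMalloc(&O, bytes);
    cudaMemcpy(Q, host.data(), bytes, cudaMemcpyHostToDevice);
    cudaMemcpy(K, host.data(), bytes, cudaMemcpyHostToDevice);
    cudaMemcpy(V, host.data(), bytes, cudaMemcpyHostToDevice);

    // Configures grid/block from BLOCK_DIM_x/BLOCK_DIM_y and launches
    // _attentionKernel, as defined in the patched attention.cu.
    infini::attentionKernel(Q, K, V, N, d, O);
    cudaDeviceSynchronize(); // wait for the kernel to finish

    cudaFree(Q);
    cudaFree(K);
    cudaFree(V);
    cudaFree(O);
    return 0;
}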