From 56e2c87c9bbb9f880beb8e4c178a1d7843375caf Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Sat, 7 Oct 2023 18:14:12 +0800 Subject: [PATCH] modified reduce,8ms --- src/kernels/cuda/attention.cu | 137 ++++++++++++---------------------- 1 file changed, 48 insertions(+), 89 deletions(-) diff --git a/src/kernels/cuda/attention.cu b/src/kernels/cuda/attention.cu index 6ff9f95f..4fe1cb65 100644 --- a/src/kernels/cuda/attention.cu +++ b/src/kernels/cuda/attention.cu @@ -11,88 +11,75 @@ __launch_bounds__(BLOCK_DIM_y) __global__ int i = blockIdx.x; // i must < N,Q[i] int phd = threadIdx.y + blockIdx.y * blockDim.y; // V[:,d] - __shared__ float old_max[BLOCK_DIM_y]; - __shared__ float new_max[BLOCK_DIM_y]; - __shared__ float new_sum[BLOCK_DIM_y]; - old_max[threadIdx.y] = -__FLT_MAX__; - new_max[threadIdx.y] = -__FLT_MAX__; - new_sum[threadIdx.y] = 0.0f; + float old_max = -__FLT_MAX__; + float new_max = -__FLT_MAX__; + float new_sum = 0.0f; - __shared__ float shareV[BLOCK_DIM_y]; __shared__ float out[BLOCK_DIM_y]; - int phNumD = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y; + int extra = d % BLOCK_DIM_y; + int step = (d - extra) / BLOCK_DIM_y; __shared__ float shareQ_times_K[BLOCK_DIM_y]; for (int phn = 0; phn < N; phn++) { - shareV[threadIdx.y] = 0.0f; - + shareQ_times_K[threadIdx.y] = 0.0f; float sum_s = 0.0f; - for (int ind = 0; ind < phNumD; ind++) { - if (threadIdx.y + ind * BLOCK_DIM_y < d) { - shareQ_times_K[threadIdx.y] = - inputQ[i * d + threadIdx.y + ind * BLOCK_DIM_y] * - inputK[phn * d + threadIdx.y + ind * BLOCK_DIM_y]; - - } else { - shareQ_times_K[threadIdx.y] = 0.0f; + if (threadIdx.y < extra) { + for (int ind = threadIdx.y * (step + 1); + ind < (threadIdx.y + 1) * (step + 1); ind++) { + shareQ_times_K[threadIdx.y] += + inputQ[i * d + ind] * inputK[phn * d + ind]; } - __syncthreads(); - for (int strip = BLOCK_DIM_y / 2; strip > 0; strip = strip / 2) { - if (threadIdx.y < strip) { - shareQ_times_K[threadIdx.y] += - shareQ_times_K[threadIdx.y + strip]; - } - __syncthreads(); - } - sum_s += shareQ_times_K[0]; - __syncthreads(); - } - - shareQ_times_K[threadIdx.y] = sum_s; - - if (phd < d) { - shareV[threadIdx.y] = inputV[phn * d + phd]; - } - - __syncthreads(); - - if (new_max[threadIdx.y] > sum_s) { - new_sum[threadIdx.y] = - new_sum[threadIdx.y] + __expf(sum_s - new_max[threadIdx.y]); } else { - new_sum[threadIdx.y] = - 1.0f + - new_sum[threadIdx.y] * __expf(new_max[threadIdx.y] - sum_s); - new_max[threadIdx.y] = sum_s; + for (int ind = extra * (step + 1) + (threadIdx.y - extra) * step; + ind < extra * (step + 1) + (threadIdx.y - extra + 1) * step; + ind++) { + shareQ_times_K[threadIdx.y] += + inputQ[i * d + ind] * inputK[phn * d + ind]; + } } __syncthreads(); + for (int strip = BLOCK_DIM_y / 8; strip > 0; strip = strip / 8) { + if (threadIdx.y < strip) { + for (int id = 1; id < 8; id++) { + shareQ_times_K[threadIdx.y] += + shareQ_times_K[threadIdx.y + id * strip]; + } + } + __syncthreads(); + } + sum_s = shareQ_times_K[0] + shareQ_times_K[1]; + //__syncthreads(); - shareQ_times_K[threadIdx.y] = - __expf(shareQ_times_K[threadIdx.y] - new_max[threadIdx.y]); + if (new_max > sum_s) { + new_sum = new_sum + __expf(sum_s - new_max); + } else { + new_sum = 1.0f + new_sum * __expf(new_max - sum_s); + new_max = sum_s; + } - __syncthreads(); + //__syncthreads(); + + sum_s = __expf(sum_s - new_max); + + //__syncthreads(); if (phn == 0) { - out[threadIdx.y] = - shareQ_times_K[threadIdx.y] * shareV[threadIdx.y]; + out[threadIdx.y] = sum_s * inputV[phn * d + phd]; } else { - out[threadIdx.y] = - __expf(old_max[threadIdx.y] - new_max[threadIdx.y]) * - out[threadIdx.y] + - shareQ_times_K[threadIdx.y] * shareV[threadIdx.y]; + out[threadIdx.y] = __expf(old_max - new_max) * out[threadIdx.y] + + sum_s * inputV[phn * d + phd]; } - old_max[threadIdx.y] = new_max[threadIdx.y]; + old_max = new_max; - __syncthreads(); + //__syncthreads(); } - __syncthreads(); + //__syncthreads(); if (phd < d) - output[i * d + phd] = - out[threadIdx.y] * __fdividef(1.0F, new_sum[threadIdx.y]); + output[i * d + phd] = out[threadIdx.y] * __fdividef(1.0F, new_sum); } namespace infini { void attentionKernel(const float *inputQ, const float *inputK, @@ -100,48 +87,20 @@ void attentionKernel(const float *inputQ, const float *inputK, int num_block_x = N; - if (d > 1023) { + if (d > 128) { int BLOCK_DIM_y = 1024; int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y; dim3 block_dim(1, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, num_block_y, 1); _attentionKernel<1024> <<>>(inputQ, inputK, inputV, N, d, output); - } else if (d > 511) { - int BLOCK_DIM_y = 512; - int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y; - dim3 block_dim(1, BLOCK_DIM_y, 1); - dim3 grid_dim(num_block_x, num_block_y, 1); - _attentionKernel<512> - <<>>(inputQ, inputK, inputV, N, d, output); - } else if (d > 255) { - int BLOCK_DIM_y = 256; - int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y; - dim3 block_dim(1, BLOCK_DIM_y, 1); - dim3 grid_dim(num_block_x, num_block_y, 1); - _attentionKernel<256> - <<>>(inputQ, inputK, inputV, N, d, output); - } else if (d > 127) { + } else if (d > 16) { int BLOCK_DIM_y = 128; int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y; dim3 block_dim(1, BLOCK_DIM_y, 1); dim3 grid_dim(num_block_x, num_block_y, 1); _attentionKernel<128> <<>>(inputQ, inputK, inputV, N, d, output); - } else if (d > 63) { - int BLOCK_DIM_y = 64; - int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y; - dim3 block_dim(1, BLOCK_DIM_y, 1); - dim3 grid_dim(num_block_x, num_block_y, 1); - _attentionKernel<64> - <<>>(inputQ, inputK, inputV, N, d, output); - } else if (d > 31) { - int BLOCK_DIM_y = 32; - int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y; - dim3 block_dim(1, BLOCK_DIM_y, 1); - dim3 grid_dim(num_block_x, num_block_y, 1); - _attentionKernel<32> - <<>>(inputQ, inputK, inputV, N, d, output); } else { int BLOCK_DIM_y = 16; int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;