modified attention.cu: BLOCK_DIM_x must be <= 32

This commit is contained in:
xgqdut2016 2023-09-26 14:53:02 +08:00
parent ec391674ac
commit b640ab1689
1 changed file with 84 insertions and 50 deletions


@@ -1,7 +1,7 @@
 #include "cuda/cuda_common.h"
-#define BLOCK_DIM_x 2
-#define BLOCK_DIM_y 2
+#define BLOCK_DIM_x 8 // BLOCK_DIM_x must <= 32
+#define BLOCK_DIM_y 128
 #define max_function(a, b) ((a) > (b) ? (a) : (b))
 __global__ void _attentionKernel(const float *inputQ, const float *inputK,
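The new tile shape is easy to sanity-check: 8 x 128 = 1024 threads, exactly the CUDA per-block thread limit, and BLOCK_DIM_x = 8 keeps the x-slice well inside one 32-thread warp, which appears to be what the "must <= 32" note is guarding. A compile-time sketch of those two constraints (not part of the commit; only the second assert is a hard hardware limit):

    // Illustrative guards for the tile-shape assumptions above.
    static_assert(BLOCK_DIM_x <= 32,
                  "author's stated invariant: x-dim at most one warp");
    static_assert(BLOCK_DIM_x * BLOCK_DIM_y <= 1024,
                  "CUDA allows at most 1024 threads per block");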
@@ -10,84 +10,117 @@ __global__ void _attentionKernel(const float *inputQ, const float *inputK,
     int i = blockIdx.x; // i must < N, Q[i]
     int phd = threadIdx.y + blockIdx.y * blockDim.y; // V[:,d]
     int phNumN = (N + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
-    __shared__ float old_max;
-    __shared__ float new_max;
-    __shared__ float new_sum;
-    old_max = -__FLT_MAX__;
-    new_max = -__FLT_MAX__;
-    new_sum = 0.0f;
-    __shared__ float block_sum[BLOCK_DIM_x];
-    __shared__ float block_max[BLOCK_DIM_x];
-    block_max[threadIdx.x] = -__FLT_MAX__;
-    block_sum[threadIdx.x] = 0.0f;
+    __shared__ float old_max[BLOCK_DIM_x][BLOCK_DIM_y];
+    __shared__ float new_max[BLOCK_DIM_x][BLOCK_DIM_y];
+    __shared__ float new_sum[BLOCK_DIM_x][BLOCK_DIM_y];
+    old_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
+    new_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
+    new_sum[threadIdx.x][threadIdx.y] = 0.0f;
+    __shared__ float block_sum[BLOCK_DIM_x][BLOCK_DIM_y];
+    __shared__ float block_max[BLOCK_DIM_x][BLOCK_DIM_y];
+    block_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
+    block_sum[threadIdx.x][threadIdx.y] = 0.0f;
-    __shared__ float inputS[BLOCK_DIM_x];
-    output[i * d + phd] = 0.0f;
+    __shared__ float inputS[BLOCK_DIM_x][BLOCK_DIM_y];
     __syncthreads();
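Promoting every per-block scalar and 1-D array to a [BLOCK_DIM_x][BLOCK_DIM_y] tile gives each thread its own running state, but it also multiplies the static shared-memory footprint. A back-of-the-envelope total for the six tiles above, assuming the 8 x 128 shape:

    // old_max, new_max, new_sum, block_sum, block_max, inputS:
    // 6 tiles * 8 * 128 floats * 4 bytes = 24576 bytes = 24 KB per block,
    // within the 48 KB of statically declared shared memory a block may use.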
     for (int phn = 0; phn < phNumN; phn++) {
         int j = threadIdx.x + phn * BLOCK_DIM_x;
-        if (j < N) {
+        inputS[threadIdx.x][threadIdx.y] = 0.0f;
+        block_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
+        block_sum[threadIdx.x][threadIdx.y] = 0.0f;
+        if (j < N && phd < d) {
             float sum_s = 0;
             for (int index = 0; index < d; index++) {
                 sum_s += inputQ[i * d + index] * inputK[j * d + index];
             }
-            inputS[threadIdx.x] = sum_s;
-            block_max[threadIdx.x] = sum_s;
-            block_sum[threadIdx.x] = 1.0f;
-        } else {
-            inputS[threadIdx.x] = 0.0f;
-            block_max[threadIdx.x] = -__FLT_MAX__;
-            block_sum[threadIdx.x] = 0.0f;
+            inputS[threadIdx.x][threadIdx.y] = sum_s;
+            block_max[threadIdx.x][threadIdx.y] = sum_s;
+            block_sum[threadIdx.x][threadIdx.y] = 1.0f;
         }
         __syncthreads();
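Reading the new guard j < N && phd < d: each thread now owns one (key row, output column) pair per tile, and out-of-range threads simply keep the zeroed defaults written just above the if. Two illustrative helpers (not in the commit) spelling out that mapping:

    // Key row scanned in tile phn, and the output/V column this thread owns.
    __device__ int keyRow(int tx, int phn) { return tx + phn * BLOCK_DIM_x; }
    __device__ int outCol(int ty, int by) { return ty + by * BLOCK_DIM_y; }
    // Note sum_s = Q[i].K[j] depends only on (i, j), so the BLOCK_DIM_y
    // threads sharing a threadIdx.x all recompute the same dot product.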
         for (int strip = BLOCK_DIM_x / 2; strip > 0; strip = strip / 2) {
             if (threadIdx.x < strip) {
-                if (block_max[threadIdx.x] > block_max[threadIdx.x + strip]) {
-                    block_sum[threadIdx.x] =
-                        block_sum[threadIdx.x] +
-                        block_sum[threadIdx.x + strip] *
-                            __expf(block_max[threadIdx.x + strip] -
-                                   block_max[threadIdx.x]);
+                if (block_max[threadIdx.x][threadIdx.y] >
+                    block_max[threadIdx.x + strip][threadIdx.y]) {
+                    block_sum[threadIdx.x][threadIdx.y] =
+                        block_sum[threadIdx.x][threadIdx.y] +
+                        block_sum[threadIdx.x + strip][threadIdx.y] *
+                            __expf(block_max[threadIdx.x + strip][threadIdx.y] -
+                                   block_max[threadIdx.x][threadIdx.y]);
                 } else {
-                    block_sum[threadIdx.x] =
-                        block_sum[threadIdx.x + strip] +
-                        block_sum[threadIdx.x] *
-                            __expf(block_max[threadIdx.x] -
-                                   block_max[threadIdx.x + strip]);
-                    block_max[threadIdx.x] = block_max[threadIdx.x + strip];
+                    block_sum[threadIdx.x][threadIdx.y] =
+                        block_sum[threadIdx.x + strip][threadIdx.y] +
+                        block_sum[threadIdx.x][threadIdx.y] *
+                            __expf(block_max[threadIdx.x][threadIdx.y] -
+                                   block_max[threadIdx.x + strip][threadIdx.y]);
+                    block_max[threadIdx.x][threadIdx.y] =
+                        block_max[threadIdx.x + strip][threadIdx.y];
                 }
             }
             __syncthreads();
         }
         __syncthreads();
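The strip loop is an ordinary tree reduction, except that it folds (max, sum) pairs instead of plain sums. Factored out as a host/device sketch with illustrative names, the combine step it applies is:

    // Merge two partial softmax statistics:
    //   m = max(m_a, m_b),  s = s_a * exp(m_a - m) + s_b * exp(m_b - m).
    struct MaxSum { float m, s; };
    __host__ __device__ inline MaxSum merge(MaxSum a, MaxSum b) {
        if (a.m > b.m)
            return {a.m, a.s + b.s * expf(b.m - a.m)};
        return {b.m, b.s + a.s * expf(a.m - b.m)};
    }

Each score enters the reduction as the pair (sum_s, 1.0f), and because merge is associative and commutative the loop collapses BLOCK_DIM_x pairs in log2(BLOCK_DIM_x) rounds, which is also why BLOCK_DIM_x is kept a power of two.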
-        if (threadIdx.x == 0) {
-            if (new_max > block_max[0]) {
-                new_sum =
-                    new_sum + block_sum[0] * __expf(block_max[0] - new_max);
+        if (j < N && phd < d) {
+            if (new_max[threadIdx.x][threadIdx.y] > block_max[0][threadIdx.y]) {
+                new_sum[threadIdx.x][threadIdx.y] =
+                    new_sum[threadIdx.x][threadIdx.y] +
+                    block_sum[0][threadIdx.y] *
+                        __expf(block_max[0][threadIdx.y] -
+                               new_max[threadIdx.x][threadIdx.y]);
             } else {
-                new_sum =
-                    block_sum[0] + new_sum * __expf(new_max - block_max[0]);
-                new_max = block_max[0];
+                new_sum[threadIdx.x][threadIdx.y] =
+                    block_sum[0][threadIdx.y] +
+                    new_sum[threadIdx.x][threadIdx.y] *
+                        __expf(new_max[threadIdx.x][threadIdx.y] -
+                               block_max[0][threadIdx.y]);
+                new_max[threadIdx.x][threadIdx.y] = block_max[0][threadIdx.y];
             }
         }
         __syncthreads();
-        inputS[threadIdx.x] = __expf(inputS[threadIdx.x] - new_max);
+        if (j < N && phd < d) {
+            inputS[threadIdx.x][threadIdx.y] =
+                __expf(inputS[threadIdx.x][threadIdx.y] -
+                       new_max[threadIdx.x][threadIdx.y]);
+        } else {
+            inputS[threadIdx.x][threadIdx.y] = 0.0f;
+        }
         __syncthreads();
-        float sum_o = 0;
+        if (phd < d) {
+            float sum_o = 0.0f;
             for (int index = 0; index < BLOCK_DIM_x; index++) {
                 if (index + phn * BLOCK_DIM_x < N) {
-                    sum_o += inputS[index] *
+                    sum_o += inputS[index][threadIdx.y] *
                              inputV[(index + phn * BLOCK_DIM_x) * d + phd];
                 }
             }
-            output[i * d + phd] =
-                __expf(old_max - new_max) * output[i * d + phd] + sum_o;
-            old_max = new_max;
+            if (phn == 0) {
+                output[i * d + phd] = sum_o;
+            } else {
+                output[i * d + phd] =
+                    __expf(old_max[threadIdx.x][threadIdx.y] -
+                           new_max[threadIdx.x][threadIdx.y]) *
+                        output[i * d + phd] +
+                    sum_o;
+            }
+            old_max[threadIdx.x][threadIdx.y] =
+                new_max[threadIdx.x][threadIdx.y];
+        } else {
+            old_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
+        }
-        //__syncthreads();
+        __syncthreads();
     }
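The phn == 0 branch replaces the old unconditional zero-initialization of output. For later tiles, scaling the accumulated partial output by __expf(old_max - new_max) is the usual streaming-softmax correction: every term e^(s_j - old_max) * V[j] already in output becomes e^(s_j - new_max) * V[j] once the running max grows, because

    e^(s_j - old_max) * e^(old_max - new_max) = e^(s_j - new_max)

so after the final tile output holds sum_j e^(s_j - new_max) * V[j, phd], and the single division by new_sum below completes softmax(Q K^T) V for row i.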
     __syncthreads();
-    output[i * d + phd] = output[i * d + phd] * __fdividef(1.0F, new_sum);
+    if (phd < d)
+        output[i * d + phd] =
+            output[i * d + phd] *
+            __fdividef(1.0F, new_sum[threadIdx.x][threadIdx.y]);
 }
 namespace infini {
 void attentionKernel(const float *inputQ, const float *inputK,
@@ -97,7 +130,8 @@ void attentionKernel(const float *inputQ, const float *inputK,
     int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
     dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
     dim3 grid_dim(num_block_x, num_block_y, 1);
-    int share_mem = (3 * BLOCK_DIM_x + 3) * sizeof(float);
+    int share_mem =
+        (3 * BLOCK_DIM_x + 3 * BLOCK_DIM_x) * BLOCK_DIM_y * sizeof(float);
     _attentionKernel<<<grid_dim, block_dim, share_mem>>>(inputQ, inputK, inputV,
                                                          N, d, output);
 }
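Two readings of the launch code, stated as observations rather than facts from the commit message: the updated share_mem now totals the same 6 * BLOCK_DIM_x * BLOCK_DIM_y floats (24 KB) as the six tiles declared in the kernel, but since those tiles are static __shared__ arrays, the dynamic third launch parameter is not actually consumed by the kernel. A hypothetical call site, assuming device-resident row-major Q, K, V of shape [N, d] and the wrapper signature suggested by the hunk above:

    // Hypothetical usage; allocation, upload, and error checks omitted.
    float *Q, *K, *V, *O; // device pointers, each holding N * d floats
    int N = 1024, d = 64; // illustrative sizes
    infini::attentionKernel(Q, K, V, N, d, O);
    cudaDeviceSynchronize(); // wait for the kernel before reading O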