modified attention.cu, BLOCK_DIM_x must be <= 32

xgqdut2016 2023-09-26 14:53:02 +08:00
parent ec391674ac
commit b640ab1689
1 changed file with 84 additions and 50 deletions
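
The change converts the kernel's softmax bookkeeping from shared scalars to per-thread state. The kernel streams over K and V in tiles of BLOCK_DIM_x rows and maintains a running maximum and a running sum of exponentials (an online softmax), rescaling the partial output whenever the running maximum changes. Previously old_max, new_max, and new_sum were single __shared__ scalars, updated only on the threadIdx.x == 0 path; with BLOCK_DIM_y > 1 every threadIdx.y lane in the block shared those scalars, so lanes handling different output columns could clobber each other's running state. The commit promotes every shared buffer to a [BLOCK_DIM_x][BLOCK_DIM_y] array so each (threadIdx.x, threadIdx.y) thread keeps its own slot, guards each phase with j < N && phd < d, and adds a __syncthreads() inside the reduction loop.

The merge rule that appears twice in the diff (once in the strip reduction, once in the running update) combines two partial states (m, s), where s is the sum of exp(x - m) over the scores seen so far, without overflowing. A minimal scalar sketch, not part of the commit, with hypothetical names (SoftmaxState, merge):

    // Combine two partial softmax states; mirrors the branch on block_max
    // in the kernel. (m, s) means: s = sum over seen scores x of exp(x - m),
    // with m the largest score seen so far.
    #include <cmath>

    struct SoftmaxState {
        float m; // running maximum
        float s; // running sum of exponentials, relative to m
    };

    static SoftmaxState merge(SoftmaxState a, SoftmaxState b) {
        SoftmaxState r;
        if (a.m > b.m) { // keep a's max, rescale b's sum into a's frame
            r.m = a.m;
            r.s = a.s + b.s * std::exp(b.m - a.m);
        } else {         // keep b's max, rescale a's sum into b's frame
            r.m = b.m;
            r.s = b.s + a.s * std::exp(a.m - b.m);
        }
        return r;
    }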

attention.cu

@@ -1,7 +1,7 @@
 #include "cuda/cuda_common.h"
-#define BLOCK_DIM_x 2
-#define BLOCK_DIM_y 2
+#define BLOCK_DIM_x 8 // BLOCK_DIM_x must <= 32
+#define BLOCK_DIM_y 128
 #define max_function(a, b) ((a) > (b) ? (a) : (b))
 __global__ void _attentionKernel(const float *inputQ, const float *inputK,
@@ -10,84 +10,117 @@ __global__ void _attentionKernel(const float *inputQ, const float *inputK,
     int i = blockIdx.x; // i must < N,Q[i]
     int phd = threadIdx.y + blockIdx.y * blockDim.y; // V[:,d]
     int phNumN = (N + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
-    __shared__ float old_max;
-    __shared__ float new_max;
-    __shared__ float new_sum;
-    old_max = -__FLT_MAX__;
-    new_max = -__FLT_MAX__;
-    new_sum = 0.0f;
-    __shared__ float block_sum[BLOCK_DIM_x];
-    __shared__ float block_max[BLOCK_DIM_x];
-    block_max[threadIdx.x] = -__FLT_MAX__;
-    block_sum[threadIdx.x] = 0.0f;
-    __shared__ float inputS[BLOCK_DIM_x];
-    output[i * d + phd] = 0.0f;
+    __shared__ float old_max[BLOCK_DIM_x][BLOCK_DIM_y];
+    __shared__ float new_max[BLOCK_DIM_x][BLOCK_DIM_y];
+    __shared__ float new_sum[BLOCK_DIM_x][BLOCK_DIM_y];
+    old_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
+    new_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
+    new_sum[threadIdx.x][threadIdx.y] = 0.0f;
+    __shared__ float block_sum[BLOCK_DIM_x][BLOCK_DIM_y];
+    __shared__ float block_max[BLOCK_DIM_x][BLOCK_DIM_y];
+    block_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
+    block_sum[threadIdx.x][threadIdx.y] = 0.0f;
+    __shared__ float inputS[BLOCK_DIM_x][BLOCK_DIM_y];
+    __syncthreads();
     for (int phn = 0; phn < phNumN; phn++) {
         int j = threadIdx.x + phn * BLOCK_DIM_x;
-        if (j < N) {
+        inputS[threadIdx.x][threadIdx.y] = 0.0f;
+        block_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
+        block_sum[threadIdx.x][threadIdx.y] = 0.0f;
+        if (j < N && phd < d) {
             float sum_s = 0;
             for (int index = 0; index < d; index++) {
                 sum_s += inputQ[i * d + index] * inputK[j * d + index];
             }
-            inputS[threadIdx.x] = sum_s;
-            block_max[threadIdx.x] = sum_s;
-            block_sum[threadIdx.x] = 1.0f;
-        } else {
-            inputS[threadIdx.x] = 0.0f;
-            block_max[threadIdx.x] = -__FLT_MAX__;
-            block_sum[threadIdx.x] = 0.0f;
+            inputS[threadIdx.x][threadIdx.y] = sum_s;
+            block_max[threadIdx.x][threadIdx.y] = sum_s;
+            block_sum[threadIdx.x][threadIdx.y] = 1.0f;
         }
         __syncthreads();
         for (int strip = BLOCK_DIM_x / 2; strip > 0; strip = strip / 2) {
             if (threadIdx.x < strip) {
-                if (block_max[threadIdx.x] > block_max[threadIdx.x + strip]) {
-                    block_sum[threadIdx.x] =
-                        block_sum[threadIdx.x] +
-                        block_sum[threadIdx.x + strip] *
-                            __expf(block_max[threadIdx.x + strip] -
-                                   block_max[threadIdx.x]);
+                if (block_max[threadIdx.x][threadIdx.y] >
+                    block_max[threadIdx.x + strip][threadIdx.y]) {
+                    block_sum[threadIdx.x][threadIdx.y] =
+                        block_sum[threadIdx.x][threadIdx.y] +
+                        block_sum[threadIdx.x + strip][threadIdx.y] *
+                            __expf(block_max[threadIdx.x + strip][threadIdx.y] -
+                                   block_max[threadIdx.x][threadIdx.y]);
                 } else {
-                    block_sum[threadIdx.x] =
-                        block_sum[threadIdx.x + strip] +
-                        block_sum[threadIdx.x] *
-                            __expf(block_max[threadIdx.x] -
-                                   block_max[threadIdx.x + strip]);
-                    block_max[threadIdx.x] = block_max[threadIdx.x + strip];
+                    block_sum[threadIdx.x][threadIdx.y] =
+                        block_sum[threadIdx.x + strip][threadIdx.y] +
+                        block_sum[threadIdx.x][threadIdx.y] *
+                            __expf(block_max[threadIdx.x][threadIdx.y] -
+                                   block_max[threadIdx.x + strip][threadIdx.y]);
+                    block_max[threadIdx.x][threadIdx.y] =
+                        block_max[threadIdx.x + strip][threadIdx.y];
                 }
             }
+            __syncthreads();
         }
         __syncthreads();
-        if (threadIdx.x == 0) {
-            if (new_max > block_max[0]) {
-                new_sum =
-                    new_sum + block_sum[0] * __expf(block_max[0] - new_max);
+        if (j < N && phd < d) {
+            if (new_max[threadIdx.x][threadIdx.y] > block_max[0][threadIdx.y]) {
+                new_sum[threadIdx.x][threadIdx.y] =
+                    new_sum[threadIdx.x][threadIdx.y] +
+                    block_sum[0][threadIdx.y] *
+                        __expf(block_max[0][threadIdx.y] -
+                               new_max[threadIdx.x][threadIdx.y]);
             } else {
-                new_sum =
-                    block_sum[0] + new_sum * __expf(new_max - block_max[0]);
-                new_max = block_max[0];
+                new_sum[threadIdx.x][threadIdx.y] =
+                    block_sum[0][threadIdx.y] +
+                    new_sum[threadIdx.x][threadIdx.y] *
+                        __expf(new_max[threadIdx.x][threadIdx.y] -
+                               block_max[0][threadIdx.y]);
+                new_max[threadIdx.x][threadIdx.y] = block_max[0][threadIdx.y];
             }
         }
         __syncthreads();
-        inputS[threadIdx.x] = __expf(inputS[threadIdx.x] - new_max);
+        if (j < N && phd < d) {
+            inputS[threadIdx.x][threadIdx.y] =
+                __expf(inputS[threadIdx.x][threadIdx.y] -
+                       new_max[threadIdx.x][threadIdx.y]);
+        } else {
+            inputS[threadIdx.x][threadIdx.y] = 0.0f;
+        }
         __syncthreads();
-        float sum_o = 0;
         if (phd < d) {
+            float sum_o = 0.0f;
             for (int index = 0; index < BLOCK_DIM_x; index++) {
                 if (index + phn * BLOCK_DIM_x < N) {
-                    sum_o += inputS[index] *
+                    sum_o += inputS[index][threadIdx.y] *
                              inputV[(index + phn * BLOCK_DIM_x) * d + phd];
                }
            }
-            output[i * d + phd] =
-                __expf(old_max - new_max) * output[i * d + phd] + sum_o;
-            old_max = new_max;
+            if (phn == 0) {
+                output[i * d + phd] = sum_o;
+            } else {
+                output[i * d + phd] =
+                    __expf(old_max[threadIdx.x][threadIdx.y] -
+                           new_max[threadIdx.x][threadIdx.y]) *
+                        output[i * d + phd] +
+                    sum_o;
+            }
+            old_max[threadIdx.x][threadIdx.y] =
+                new_max[threadIdx.x][threadIdx.y];
+        } else {
+            old_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
        }
-        //__syncthreads();
+        __syncthreads();
    }
-    __syncthreads();
    if (phd < d)
-        output[i * d + phd] = output[i * d + phd] * __fdividef(1.0F, new_sum);
+        output[i * d + phd] =
+            output[i * d + phd] *
+            __fdividef(1.0F, new_sum[threadIdx.x][threadIdx.y]);
 }
 namespace infini {
 void attentionKernel(const float *inputQ, const float *inputK,
@@ -97,7 +130,8 @@ void attentionKernel(const float *inputQ, const float *inputK,
     int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
     dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
     dim3 grid_dim(num_block_x, num_block_y, 1);
-    int share_mem = (3 * BLOCK_DIM_x + 3) * sizeof(float);
+    int share_mem =
+        (3 * BLOCK_DIM_x + 3 * BLOCK_DIM_x) * BLOCK_DIM_y * sizeof(float);
     _attentionKernel<<<grid_dim, block_dim, share_mem>>>(inputQ, inputK, inputV,
                                                          N, d, output);
 }
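
For reference: with the new defines, the six [BLOCK_DIM_x][BLOCK_DIM_y] shared arrays occupy 6 * 8 * 128 * sizeof(float) = 24576 bytes per block, which is exactly what the updated share_mem expression (3 * BLOCK_DIM_x + 3 * BLOCK_DIM_x) * BLOCK_DIM_y * sizeof(float) computes. Note that the kernel declares these arrays statically, so the dynamic share_mem launch argument is in addition to, not the source of, that storage. A minimal sketch of compile-time guards for the constraints named in the commit message (hypothetical; the commit itself only records the limit in a comment):

    // Hypothetical guards, not part of the commit.
    static_assert(BLOCK_DIM_x <= 32, "BLOCK_DIM_x must be <= 32");
    static_assert(BLOCK_DIM_x * BLOCK_DIM_y <= 1024,
                  "a CUDA thread block may not exceed 1024 threads");
    static_assert(6 * BLOCK_DIM_x * BLOCK_DIM_y * sizeof(float) <= 48 * 1024,
                  "static shared memory exceeds the default 48 KiB per-block limit");

With BLOCK_DIM_x = 8 and BLOCK_DIM_y = 128 all three hold: 8 <= 32, 8 * 128 = 1024 threads, and 24576 bytes < 49152.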