modified reduce, 8ms

xgqdut2016 2023-10-07 18:14:12 +08:00
parent 819484eda2
commit 56e2c87c9b
1 changed file with 48 additions and 89 deletions
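
The reduction change named in the commit title is visible in the first hunk: the binary (stride-halving) shared-memory reduction becomes an eight-way fan-in, and each Q.K dot product is split into contiguous per-thread chunks via extra/step instead of strided passes. A minimal standalone sketch of those two building blocks (the helper name blockDotProduct8 and its parameters are illustrative, not part of the commit; BLOCK_DIM is assumed to have the form 2 * 8^k, i.e. 16, 128, or 1024, matching the dispatch in the second hunk):

// Sketch only: per-thread contiguous accumulation plus the 8-way tree
// reduce used by the new kernel. The first `extra` threads take step + 1
// elements, the rest take step, so all d elements are covered exactly once.
template <int BLOCK_DIM>
__device__ float blockDotProduct8(const float *q, const float *k, int d) {
    __shared__ float buf[BLOCK_DIM];
    int extra = d % BLOCK_DIM;
    int step = (d - extra) / BLOCK_DIM;
    int t = threadIdx.y; // the kernel launches blocks of shape (1, BLOCK_DIM, 1)
    int begin = t < extra ? t * (step + 1)
                          : extra * (step + 1) + (t - extra) * step;
    int end = begin + (t < extra ? step + 1 : step);
    float acc = 0.0f;
    for (int ind = begin; ind < end; ind++)
        acc += q[ind] * k[ind];
    buf[t] = acc;
    __syncthreads();
    // Fan in 8 partials per step: strip runs BLOCK_DIM/8, BLOCK_DIM/64, ...
    for (int strip = BLOCK_DIM / 8; strip > 0; strip = strip / 8) {
        if (t < strip)
            for (int id = 1; id < 8; id++)
                buf[t] += buf[t + id * strip];
        __syncthreads();
    }
    // With BLOCK_DIM = 2 * 8^k the strip sequence ends at 2, so exactly two
    // partial sums survive; hence shareQ_times_K[0] + shareQ_times_K[1]
    // in the diff below.
    return buf[0] + buf[1];
}

In the committed kernel this pattern runs once per phn iteration with q = inputQ + i * d and k = inputK + phn * d; the eight-way fan-in needs roughly log8(BLOCK_DIM) synchronization rounds instead of log2(BLOCK_DIM).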


@@ -11,88 +11,75 @@ __launch_bounds__(BLOCK_DIM_y) __global__
     int i = blockIdx.x; // i must < N,Q[i]
     int phd = threadIdx.y + blockIdx.y * blockDim.y; // V[:,d]
-    __shared__ float old_max[BLOCK_DIM_y];
-    __shared__ float new_max[BLOCK_DIM_y];
-    __shared__ float new_sum[BLOCK_DIM_y];
-    old_max[threadIdx.y] = -__FLT_MAX__;
-    new_max[threadIdx.y] = -__FLT_MAX__;
-    new_sum[threadIdx.y] = 0.0f;
-    __shared__ float shareV[BLOCK_DIM_y];
+    float old_max = -__FLT_MAX__;
+    float new_max = -__FLT_MAX__;
+    float new_sum = 0.0f;
     __shared__ float out[BLOCK_DIM_y];
-    int phNumD = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
+    int extra = d % BLOCK_DIM_y;
+    int step = (d - extra) / BLOCK_DIM_y;
     __shared__ float shareQ_times_K[BLOCK_DIM_y];
     for (int phn = 0; phn < N; phn++) {
-        shareV[threadIdx.y] = 0.0f;
         shareQ_times_K[threadIdx.y] = 0.0f;
         float sum_s = 0.0f;
-        for (int ind = 0; ind < phNumD; ind++) {
-            if (threadIdx.y + ind * BLOCK_DIM_y < d) {
-                shareQ_times_K[threadIdx.y] =
-                    inputQ[i * d + threadIdx.y + ind * BLOCK_DIM_y] *
-                    inputK[phn * d + threadIdx.y + ind * BLOCK_DIM_y];
-            } else {
-                shareQ_times_K[threadIdx.y] = 0.0f;
-            }
-            __syncthreads();
-            for (int strip = BLOCK_DIM_y / 2; strip > 0; strip = strip / 2) {
-                if (threadIdx.y < strip) {
-                    shareQ_times_K[threadIdx.y] +=
-                        shareQ_times_K[threadIdx.y + strip];
-                }
-                __syncthreads();
-            }
-            sum_s += shareQ_times_K[0];
-            __syncthreads();
-        }
-        shareQ_times_K[threadIdx.y] = sum_s;
-        if (phd < d) {
-            shareV[threadIdx.y] = inputV[phn * d + phd];
-        }
-        __syncthreads();
+        if (threadIdx.y < extra) {
+            for (int ind = threadIdx.y * (step + 1);
+                 ind < (threadIdx.y + 1) * (step + 1); ind++) {
+                shareQ_times_K[threadIdx.y] +=
+                    inputQ[i * d + ind] * inputK[phn * d + ind];
+            }
+        } else {
+            for (int ind = extra * (step + 1) + (threadIdx.y - extra) * step;
+                 ind < extra * (step + 1) + (threadIdx.y - extra + 1) * step;
+                 ind++) {
+                shareQ_times_K[threadIdx.y] +=
+                    inputQ[i * d + ind] * inputK[phn * d + ind];
+            }
+        }
+        __syncthreads();
+        for (int strip = BLOCK_DIM_y / 8; strip > 0; strip = strip / 8) {
+            if (threadIdx.y < strip) {
+                for (int id = 1; id < 8; id++) {
+                    shareQ_times_K[threadIdx.y] +=
+                        shareQ_times_K[threadIdx.y + id * strip];
+                }
+            }
+            __syncthreads();
+        }
+        sum_s = shareQ_times_K[0] + shareQ_times_K[1];
+        //__syncthreads();
-        if (new_max[threadIdx.y] > sum_s) {
-            new_sum[threadIdx.y] =
-                new_sum[threadIdx.y] + __expf(sum_s - new_max[threadIdx.y]);
-        } else {
-            new_sum[threadIdx.y] =
-                1.0f +
-                new_sum[threadIdx.y] * __expf(new_max[threadIdx.y] - sum_s);
-            new_max[threadIdx.y] = sum_s;
-        }
-        __syncthreads();
-        shareQ_times_K[threadIdx.y] =
-            __expf(shareQ_times_K[threadIdx.y] - new_max[threadIdx.y]);
+        if (new_max > sum_s) {
+            new_sum = new_sum + __expf(sum_s - new_max);
+        } else {
+            new_sum = 1.0f + new_sum * __expf(new_max - sum_s);
+            new_max = sum_s;
+        }
+        //__syncthreads();
+        sum_s = __expf(sum_s - new_max);
+        //__syncthreads();
         if (phn == 0) {
-            out[threadIdx.y] =
-                shareQ_times_K[threadIdx.y] * shareV[threadIdx.y];
+            out[threadIdx.y] = sum_s * inputV[phn * d + phd];
         } else {
-            out[threadIdx.y] =
-                __expf(old_max[threadIdx.y] - new_max[threadIdx.y]) *
-                    out[threadIdx.y] +
-                shareQ_times_K[threadIdx.y] * shareV[threadIdx.y];
+            out[threadIdx.y] = __expf(old_max - new_max) * out[threadIdx.y] +
+                               sum_s * inputV[phn * d + phd];
         }
-        old_max[threadIdx.y] = new_max[threadIdx.y];
-        __syncthreads();
+        old_max = new_max;
+        //__syncthreads();
     }
-    __syncthreads();
+    //__syncthreads();
     if (phd < d)
-        output[i * d + phd] =
-            out[threadIdx.y] * __fdividef(1.0F, new_sum[threadIdx.y]);
+        output[i * d + phd] = out[threadIdx.y] * __fdividef(1.0F, new_sum);
 }
namespace infini {
void attentionKernel(const float *inputQ, const float *inputK,
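
The old_max/new_max/new_sum registers in the hunk above implement a streaming (online) softmax: one pass over the N scores maintains the running maximum and the sum of exponentials rescaled to it. A scalar host-side sketch of the recurrence (function and variable names are illustrative, not part of the commit):

#include <cfloat>
#include <cmath>
#include <vector>

// One pass over scores s[0..N-1], mirroring the kernel's phn loop: if a new
// score exceeds the running max, the accumulated sum is rescaled to the new
// max before the implicit exp(x - x) = 1 term is added.
void onlineSoftmax(const std::vector<float> &s, std::vector<float> &p) {
    float max_v = -FLT_MAX, sum_v = 0.0f;
    for (float x : s) {
        if (max_v > x) {
            sum_v += std::exp(x - max_v);               // max unchanged
        } else {
            sum_v = 1.0f + sum_v * std::exp(max_v - x); // rescale old sum
            max_v = x;
        }
    }
    p.resize(s.size());
    for (size_t j = 0; j < s.size(); j++)
        p[j] = std::exp(s[j] - max_v) / sum_v;          // normalized weights
}

The kernel avoids even this second normalization pass by folding V in as it goes: whenever the running max changes, the partial output out[threadIdx.y] is rescaled by __expf(old_max - new_max) before the new term sum_s * inputV[phn * d + phd] is added, and the division by new_sum happens once at the very end.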
@@ -100,48 +87,20 @@ void attentionKernel(const float *inputQ, const float *inputK,
     int num_block_x = N;
-    if (d > 1023) {
+    if (d > 128) {
         int BLOCK_DIM_y = 1024;
         int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
         dim3 block_dim(1, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, num_block_y, 1);
         _attentionKernel<1024>
             <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
-    } else if (d > 511) {
-        int BLOCK_DIM_y = 512;
-        int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
-        dim3 block_dim(1, BLOCK_DIM_y, 1);
-        dim3 grid_dim(num_block_x, num_block_y, 1);
-        _attentionKernel<512>
-            <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
-    } else if (d > 255) {
-        int BLOCK_DIM_y = 256;
-        int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
-        dim3 block_dim(1, BLOCK_DIM_y, 1);
-        dim3 grid_dim(num_block_x, num_block_y, 1);
-        _attentionKernel<256>
-            <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
-    } else if (d > 127) {
+    } else if (d > 16) {
         int BLOCK_DIM_y = 128;
         int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
         dim3 block_dim(1, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, num_block_y, 1);
         _attentionKernel<128>
             <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
-    } else if (d > 63) {
-        int BLOCK_DIM_y = 64;
-        int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
-        dim3 block_dim(1, BLOCK_DIM_y, 1);
-        dim3 grid_dim(num_block_x, num_block_y, 1);
-        _attentionKernel<64>
-            <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
-    } else if (d > 31) {
-        int BLOCK_DIM_y = 32;
-        int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
-        dim3 block_dim(1, BLOCK_DIM_y, 1);
-        dim3 grid_dim(num_block_x, num_block_y, 1);
-        _attentionKernel<32>
-            <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
     } else {
         int BLOCK_DIM_y = 16;
         int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
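
The second hunk is truncated above, but the surviving branches show the new dispatch: BLOCK_DIM_y is chosen only from 1024, 128, and 16, each of the form 2 * 8^k, which is exactly the shape the eight-way reduction's final shareQ_times_K[0] + shareQ_times_K[1] step assumes. A condensed sketch of the equivalent selection logic (assuming the cut-off else branch mirrors the earlier ones; not the committed code verbatim):

// Sketch of the new launch selection under the assumption above.
void launchAttention(const float *inputQ, const float *inputK,
                     const float *inputV, int N, int d, float *output) {
    int num_block_x = N;
    if (d > 128) {        // large head dims: 1024 threads reduce each row
        dim3 block_dim(1, 1024, 1), grid_dim(num_block_x, (d + 1023) / 1024, 1);
        _attentionKernel<1024>
            <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
    } else if (d > 16) {  // medium head dims
        dim3 block_dim(1, 128, 1), grid_dim(num_block_x, (d + 127) / 128, 1);
        _attentionKernel<128>
            <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
    } else {              // tiny head dims
        dim3 block_dim(1, 16, 1), grid_dim(num_block_x, (d + 15) / 16, 1);
        _attentionKernel<16>
            <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
    }
}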