2D block, share S

This commit is contained in:
xgqdut2016 2023-09-25 13:02:25 +08:00
parent f1dc440a3c
commit ec391674ac
1 changed files with 83 additions and 83 deletions

View File

@ -1,104 +1,104 @@
#include "cuda/cuda_common.h"
#include <cub/cub.cuh>
struct __align__(8) MD {
float max_tmp;
float sum_tmp;
};
__device__ __forceinline__ MD reduce_md_op(MD a, MD b) {
bool a_bigger = (a.max_tmp > b.max_tmp);
MD bigger = a_bigger ? a : b;
MD smaller = a_bigger ? b : a;
MD res;
res.sum_tmp = bigger.sum_tmp +
smaller.sum_tmp * __expf(smaller.max_tmp - bigger.max_tmp);
res.max_tmp = bigger.max_tmp;
return res;
}
#define BLOCK_DIM_x 2
#define BLOCK_DIM_y 2
#define max_function(a, b) ((a) > (b) ? (a) : (b))
template <int BLOCK_DIM>
__launch_bounds__(BLOCK_DIM) __global__
void _attentionKernel(const float *inputQ, const float *inputK,
const float *inputV, float *inputS, int N, int d,
float *output) {
MD md_partial;
md_partial.max_tmp = -__FLT_MAX__;
md_partial.sum_tmp = 0.0f;
int phNumN = (N + BLOCK_DIM - 1) / BLOCK_DIM;
int i = blockIdx.x; // i must < N
for (int phn = 0; phn < phNumN; phn++) {
__global__ void _attentionKernel(const float *inputQ, const float *inputK,
const float *inputV, int N, int d,
float *output) {
int i = blockIdx.x; // i must < N,Q[i]
int phd = threadIdx.y + blockIdx.y * blockDim.y; // V[:,d]
int phNumN = (N + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
__shared__ float old_max;
__shared__ float new_max;
__shared__ float new_sum;
old_max = -__FLT_MAX__;
new_max = -__FLT_MAX__;
new_sum = 0.0f;
__shared__ float block_sum[BLOCK_DIM_x];
__shared__ float block_max[BLOCK_DIM_x];
block_max[threadIdx.x] = -__FLT_MAX__;
block_sum[threadIdx.x] = 0.0f;
int j = threadIdx.x + phn * BLOCK_DIM;
MD md_input;
__shared__ float inputS[BLOCK_DIM_x];
output[i * d + phd] = 0.0f;
for (int phn = 0; phn < phNumN; phn++) {
int j = threadIdx.x + phn * BLOCK_DIM_x;
if (j < N) {
float sum_s = 0.0f;
float sum_s = 0;
for (int index = 0; index < d; index++) {
sum_s += inputQ[i * d + index] * inputK[j * d + index];
}
inputS[i * N + j] = sum_s;
// printf("S--%d:%.4e\n",i * N + j,inputS[i * N + j]);
md_input.max_tmp = sum_s;
md_input.sum_tmp = 1.0f;
inputS[threadIdx.x] = sum_s;
block_max[threadIdx.x] = sum_s;
block_sum[threadIdx.x] = 1.0f;
} else {
md_input.max_tmp = -__FLT_MAX__;
md_input.sum_tmp = 0.0f;
inputS[threadIdx.x] = 0.0f;
block_max[threadIdx.x] = -__FLT_MAX__;
block_sum[threadIdx.x] = 0.0f;
}
md_partial = reduce_md_op(md_partial, md_input);
}
typedef cub::BlockReduce<MD, BLOCK_DIM> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ MD md_total;
MD md_block = BlockReduce(temp_storage).Reduce(md_partial, reduce_md_op);
if (threadIdx.x ==
0) { // must set threadIdx.x = 0 write the output to memory
md_total = md_block;
}
__syncthreads();
// printf("max:%.4e\n",md_total.max_tmp);
for (int phn = 0; threadIdx.x + phn * BLOCK_DIM < N; phn++) {
int j = threadIdx.x + phn * BLOCK_DIM;
inputS[i * N + j] = __expf(inputS[i * N + j] - md_total.max_tmp) *
__fdividef(1.0F, md_total.sum_tmp);
// printf("S:%.4e\n",inputS[i * N + j]);
}
__syncthreads();
for (int phd = 0; threadIdx.x + phd * BLOCK_DIM < d; phd++) {
int j = threadIdx.x + phd * BLOCK_DIM;
__syncthreads();
for (int strip = BLOCK_DIM_x / 2; strip > 0; strip = strip / 2) {
if (threadIdx.x < strip) {
if (block_max[threadIdx.x] > block_max[threadIdx.x + strip]) {
block_sum[threadIdx.x] =
block_sum[threadIdx.x] +
block_sum[threadIdx.x + strip] *
__expf(block_max[threadIdx.x + strip] -
block_max[threadIdx.x]);
} else {
block_sum[threadIdx.x] =
block_sum[threadIdx.x + strip] +
block_sum[threadIdx.x] *
__expf(block_max[threadIdx.x] -
block_max[threadIdx.x + strip]);
block_max[threadIdx.x] = block_max[threadIdx.x + strip];
}
}
}
__syncthreads();
if (threadIdx.x == 0) {
if (new_max > block_max[0]) {
new_sum =
new_sum + block_sum[0] * __expf(block_max[0] - new_max);
} else {
new_sum =
block_sum[0] + new_sum * __expf(new_max - block_max[0]);
new_max = block_max[0];
}
}
__syncthreads();
inputS[threadIdx.x] = __expf(inputS[threadIdx.x] - new_max);
__syncthreads();
float sum_o = 0;
for (int index = 0; index < N; index++) {
sum_o += inputS[i * N + index] * inputV[index * d + j];
if (phd < d) {
for (int index = 0; index < BLOCK_DIM_x; index++) {
if (index + phn * BLOCK_DIM_x < N) {
sum_o += inputS[index] *
inputV[(index + phn * BLOCK_DIM_x) * d + phd];
}
}
output[i * d + phd] =
__expf(old_max - new_max) * output[i * d + phd] + sum_o;
old_max = new_max;
}
output[i * d + j] = sum_o;
//__syncthreads();
}
if (phd < d)
output[i * d + phd] = output[i * d + phd] * __fdividef(1.0F, new_sum);
}
namespace infini {
void attentionKernel(const float *inputQ, const float *inputK,
const float *inputV, int N, int d, float *output) {
float *inputS;
cudaMalloc((void **)&inputS, N * N * sizeof(float));
int nd = max_function(N, d);
if (nd > 1023) {
_attentionKernel<1024>
<<<N, 1024>>>(inputQ, inputK, inputV, inputS, N, d, output);
} else if (nd > 511) {
_attentionKernel<512>
<<<N, 512>>>(inputQ, inputK, inputV, inputS, N, d, output);
} else if (nd > 255) {
_attentionKernel<256>
<<<N, 256>>>(inputQ, inputK, inputV, inputS, N, d, output);
} else if (nd > 63) {
_attentionKernel<64>
<<<N, 64>>>(inputQ, inputK, inputV, inputS, N, d, output);
} else if (nd > 15) {
_attentionKernel<16>
<<<N, 16>>>(inputQ, inputK, inputV, inputS, N, d, output);
} else {
_attentionKernel<8>
<<<N, 8>>>(inputQ, inputK, inputV, inputS, N, d, output);
}
int num_block_x = N;
int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
dim3 grid_dim(num_block_x, num_block_y, 1);
int share_mem = (3 * BLOCK_DIM_x + 3) * sizeof(float);
_attentionKernel<<<grid_dim, block_dim, share_mem>>>(inputQ, inputK, inputV,
N, d, output);
}
} // namespace infini