forked from jiuyuan/InfiniTensor
2D block, share S
This commit is contained in:
parent f1dc440a3c
commit ec391674ac
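In brief: this commit drops the cub::BlockReduce path, switches to a fixed 2D thread block (BLOCK_DIM_x score columns by BLOCK_DIM_y output columns), and keeps the current slice of the score row S in block shared memory instead of a global N x N scratch buffer, normalizing it with a streaming (online-softmax) pass over K and V. A host-side sketch of the launch arithmetic the new launcher implies; the sample N and d are made up for illustration:

// Illustrative only: mirrors the grid/block arithmetic of the new launcher.
#include <cstdio>

#define BLOCK_DIM_x 2
#define BLOCK_DIM_y 2

int main() {
    const int N = 5, d = 3; // sample sequence length and head dim (made up)
    int num_block_x = N;                                   // one column of blocks per query Q[i]
    int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y; // tiles over output columns
    int phNumN = (N + BLOCK_DIM_x - 1) / BLOCK_DIM_x;      // passes each block makes over K/V
    printf("grid=(%d,%d) block=(%d,%d) passes=%d\n",
           num_block_x, num_block_y, BLOCK_DIM_x, BLOCK_DIM_y, phNumN);
    // prints: grid=(5,2) block=(2,2) passes=3
    return 0;
}

One observation on the launch below: the dynamic shared-memory argument share_mem = (3 * BLOCK_DIM_x + 3) * sizeof(float) matches the total size of the kernel's __shared__ declarations, but since those are statically declared, the dynamic allocation appears to go unused.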
@@ -1,104 +1,104 @@
#include "cuda/cuda_common.h"
|
||||
#include <cub/cub.cuh>
|
||||
struct __align__(8) MD {
|
||||
float max_tmp;
|
||||
float sum_tmp;
|
||||
};
|
||||
__device__ __forceinline__ MD reduce_md_op(MD a, MD b) {
|
||||
bool a_bigger = (a.max_tmp > b.max_tmp);
|
||||
MD bigger = a_bigger ? a : b;
|
||||
MD smaller = a_bigger ? b : a;
|
||||
MD res;
|
||||
res.sum_tmp = bigger.sum_tmp +
|
||||
smaller.sum_tmp * __expf(smaller.max_tmp - bigger.max_tmp);
|
||||
res.max_tmp = bigger.max_tmp;
|
||||
return res;
|
||||
}
|
||||
|
||||
#define BLOCK_DIM_x 2
|
||||
#define BLOCK_DIM_y 2
|
||||
#define max_function(a, b) ((a) > (b) ? (a) : (b))
|
||||
|
||||
template <int BLOCK_DIM>
|
||||
__launch_bounds__(BLOCK_DIM) __global__
|
||||
void _attentionKernel(const float *inputQ, const float *inputK,
|
||||
const float *inputV, float *inputS, int N, int d,
|
||||
float *output) {
|
||||
MD md_partial;
|
||||
md_partial.max_tmp = -__FLT_MAX__;
|
||||
md_partial.sum_tmp = 0.0f;
|
||||
int phNumN = (N + BLOCK_DIM - 1) / BLOCK_DIM;
|
||||
int i = blockIdx.x; // i must < N
|
||||
for (int phn = 0; phn < phNumN; phn++) {
|
||||
__global__ void _attentionKernel(const float *inputQ, const float *inputK,
|
||||
const float *inputV, int N, int d,
|
||||
float *output) {
|
||||
int i = blockIdx.x; // i must < N,Q[i]
|
||||
int phd = threadIdx.y + blockIdx.y * blockDim.y; // V[:,d]
|
||||
int phNumN = (N + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
|
||||
__shared__ float old_max;
|
||||
__shared__ float new_max;
|
||||
__shared__ float new_sum;
|
||||
old_max = -__FLT_MAX__;
|
||||
new_max = -__FLT_MAX__;
|
||||
new_sum = 0.0f;
|
||||
__shared__ float block_sum[BLOCK_DIM_x];
|
||||
__shared__ float block_max[BLOCK_DIM_x];
|
||||
block_max[threadIdx.x] = -__FLT_MAX__;
|
||||
block_sum[threadIdx.x] = 0.0f;
|
||||
|
||||
int j = threadIdx.x + phn * BLOCK_DIM;
|
||||
MD md_input;
|
||||
__shared__ float inputS[BLOCK_DIM_x];
|
||||
|
||||
output[i * d + phd] = 0.0f;
|
||||
for (int phn = 0; phn < phNumN; phn++) {
|
||||
int j = threadIdx.x + phn * BLOCK_DIM_x;
|
||||
if (j < N) {
|
||||
float sum_s = 0.0f;
|
||||
float sum_s = 0;
|
||||
for (int index = 0; index < d; index++) {
|
||||
sum_s += inputQ[i * d + index] * inputK[j * d + index];
|
||||
}
|
||||
inputS[i * N + j] = sum_s;
|
||||
// printf("S--%d:%.4e\n",i * N + j,inputS[i * N + j]);
|
||||
md_input.max_tmp = sum_s;
|
||||
md_input.sum_tmp = 1.0f;
|
||||
inputS[threadIdx.x] = sum_s;
|
||||
block_max[threadIdx.x] = sum_s;
|
||||
block_sum[threadIdx.x] = 1.0f;
|
||||
} else {
|
||||
md_input.max_tmp = -__FLT_MAX__;
|
||||
md_input.sum_tmp = 0.0f;
|
||||
inputS[threadIdx.x] = 0.0f;
|
||||
block_max[threadIdx.x] = -__FLT_MAX__;
|
||||
block_sum[threadIdx.x] = 0.0f;
|
||||
}
|
||||
md_partial = reduce_md_op(md_partial, md_input);
|
||||
}
|
||||
typedef cub::BlockReduce<MD, BLOCK_DIM> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
|
||||
__shared__ MD md_total;
|
||||
MD md_block = BlockReduce(temp_storage).Reduce(md_partial, reduce_md_op);
|
||||
if (threadIdx.x ==
|
||||
0) { // must set threadIdx.x = 0 write the output to memory
|
||||
md_total = md_block;
|
||||
}
|
||||
__syncthreads();
|
||||
// printf("max:%.4e\n",md_total.max_tmp);
|
||||
for (int phn = 0; threadIdx.x + phn * BLOCK_DIM < N; phn++) {
|
||||
int j = threadIdx.x + phn * BLOCK_DIM;
|
||||
inputS[i * N + j] = __expf(inputS[i * N + j] - md_total.max_tmp) *
|
||||
__fdividef(1.0F, md_total.sum_tmp);
|
||||
// printf("S:%.4e\n",inputS[i * N + j]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int phd = 0; threadIdx.x + phd * BLOCK_DIM < d; phd++) {
|
||||
int j = threadIdx.x + phd * BLOCK_DIM;
|
||||
__syncthreads();
|
||||
for (int strip = BLOCK_DIM_x / 2; strip > 0; strip = strip / 2) {
|
||||
if (threadIdx.x < strip) {
|
||||
if (block_max[threadIdx.x] > block_max[threadIdx.x + strip]) {
|
||||
block_sum[threadIdx.x] =
|
||||
block_sum[threadIdx.x] +
|
||||
block_sum[threadIdx.x + strip] *
|
||||
__expf(block_max[threadIdx.x + strip] -
|
||||
block_max[threadIdx.x]);
|
||||
} else {
|
||||
block_sum[threadIdx.x] =
|
||||
block_sum[threadIdx.x + strip] +
|
||||
block_sum[threadIdx.x] *
|
||||
__expf(block_max[threadIdx.x] -
|
||||
block_max[threadIdx.x + strip]);
|
||||
block_max[threadIdx.x] = block_max[threadIdx.x + strip];
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
if (threadIdx.x == 0) {
|
||||
if (new_max > block_max[0]) {
|
||||
new_sum =
|
||||
new_sum + block_sum[0] * __expf(block_max[0] - new_max);
|
||||
} else {
|
||||
new_sum =
|
||||
block_sum[0] + new_sum * __expf(new_max - block_max[0]);
|
||||
new_max = block_max[0];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
inputS[threadIdx.x] = __expf(inputS[threadIdx.x] - new_max);
|
||||
__syncthreads();
|
||||
float sum_o = 0;
|
||||
for (int index = 0; index < N; index++) {
|
||||
sum_o += inputS[i * N + index] * inputV[index * d + j];
|
||||
if (phd < d) {
|
||||
for (int index = 0; index < BLOCK_DIM_x; index++) {
|
||||
if (index + phn * BLOCK_DIM_x < N) {
|
||||
sum_o += inputS[index] *
|
||||
inputV[(index + phn * BLOCK_DIM_x) * d + phd];
|
||||
}
|
||||
}
|
||||
output[i * d + phd] =
|
||||
__expf(old_max - new_max) * output[i * d + phd] + sum_o;
|
||||
old_max = new_max;
|
||||
}
|
||||
output[i * d + j] = sum_o;
|
||||
//__syncthreads();
|
||||
}
|
||||
if (phd < d)
|
||||
output[i * d + phd] = output[i * d + phd] * __fdividef(1.0F, new_sum);
|
||||
}
|
||||
namespace infini {
|
||||
void attentionKernel(const float *inputQ, const float *inputK,
|
||||
const float *inputV, int N, int d, float *output) {
|
||||
|
||||
float *inputS;
|
||||
cudaMalloc((void **)&inputS, N * N * sizeof(float));
|
||||
int nd = max_function(N, d);
|
||||
if (nd > 1023) {
|
||||
_attentionKernel<1024>
|
||||
<<<N, 1024>>>(inputQ, inputK, inputV, inputS, N, d, output);
|
||||
} else if (nd > 511) {
|
||||
_attentionKernel<512>
|
||||
<<<N, 512>>>(inputQ, inputK, inputV, inputS, N, d, output);
|
||||
} else if (nd > 255) {
|
||||
_attentionKernel<256>
|
||||
<<<N, 256>>>(inputQ, inputK, inputV, inputS, N, d, output);
|
||||
} else if (nd > 63) {
|
||||
_attentionKernel<64>
|
||||
<<<N, 64>>>(inputQ, inputK, inputV, inputS, N, d, output);
|
||||
} else if (nd > 15) {
|
||||
_attentionKernel<16>
|
||||
<<<N, 16>>>(inputQ, inputK, inputV, inputS, N, d, output);
|
||||
} else {
|
||||
_attentionKernel<8>
|
||||
<<<N, 8>>>(inputQ, inputK, inputV, inputS, N, d, output);
|
||||
}
|
||||
int num_block_x = N;
|
||||
int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, num_block_y, 1);
|
||||
int share_mem = (3 * BLOCK_DIM_x + 3) * sizeof(float);
|
||||
_attentionKernel<<<grid_dim, block_dim, share_mem>>>(inputQ, inputK, inputV,
|
||||
N, d, output);
|
||||
}
|
||||
} // namespace infini
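Both the removed cub::BlockReduce path and the new block_max/block_sum strip reduction rest on the same online-softmax identity: two partial (max, sum-of-exp) pairs merge into one without revisiting the scores, because the side with the smaller max can be rescaled by exp(smaller_max - bigger_max). A minimal host-side sketch of that identity (standalone, not part of the commit; merge_md mirrors the removed reduce_md_op):

#include <cmath>
#include <cstdio>

struct MD { float max_tmp, sum_tmp; };

// Same merge rule as reduce_md_op (old) and the strip loop (new):
// keep the larger running max and rescale the other side's sum.
static MD merge_md(MD a, MD b) {
    MD bigger  = (a.max_tmp > b.max_tmp) ? a : b;
    MD smaller = (a.max_tmp > b.max_tmp) ? b : a;
    return {bigger.max_tmp,
            bigger.sum_tmp +
                smaller.sum_tmp * std::exp(smaller.max_tmp - bigger.max_tmp)};
}

int main() {
    const float s[5] = {1.0f, 3.0f, -2.0f, 0.5f, 2.5f};

    // Two-pass reference: global max, then sum of exp(s - max).
    float m = s[0];
    for (float v : s) m = std::fmax(m, v);
    float ref = 0.0f;
    for (float v : s) ref += std::exp(v - m);

    // Streaming version: fold scores in one at a time as (max, sum = 1) pairs.
    MD run = {s[0], 1.0f};
    for (int i = 1; i < 5; i++) run = merge_md(run, {s[i], 1.0f});

    printf("two-pass: %.6f  streaming: %.6f\n", ref, run.sum_tmp);
    return 0;
}

The new kernel applies this rule twice per pass: once across threadIdx.x in the strip loop, and once against the running (new_max, new_sum) pair, with the already-accumulated output rescaled by __expf(old_max - new_max).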
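For completeness, a hypothetical driver for the new entry point (none of this is in the commit; it assumes Q, K, V, and output are row-major N x d device buffers):

// Hypothetical usage sketch; buffer names and sizes are illustrative.
#include <cuda_runtime.h>
#include <vector>

namespace infini {
void attentionKernel(const float *inputQ, const float *inputK,
                     const float *inputV, int N, int d, float *output);
}

int main() {
    const int N = 4, d = 8;
    std::vector<float> hQ(N * d, 0.1f), hK(N * d, 0.2f), hV(N * d, 0.3f);
    std::vector<float> hO(N * d);

    float *dQ, *dK, *dV, *dO;
    cudaMalloc(&dQ, N * d * sizeof(float));
    cudaMalloc(&dK, N * d * sizeof(float));
    cudaMalloc(&dV, N * d * sizeof(float));
    cudaMalloc(&dO, N * d * sizeof(float));
    cudaMemcpy(dQ, hQ.data(), N * d * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dK, hK.data(), N * d * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dV, hV.data(), N * d * sizeof(float), cudaMemcpyHostToDevice);

    infini::attentionKernel(dQ, dK, dV, N, d, dO); // launches the 2D-block kernel
    cudaMemcpy(hO.data(), dO, N * d * sizeof(float), cudaMemcpyDeviceToHost);

    cudaFree(dQ); cudaFree(dK); cudaFree(dV); cudaFree(dO);
    return 0;
}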