forked from jiuyuan/InfiniTensor
modified attention.cu: BLOCK_DIM_x must be <= 32
This commit is contained in:
parent ec391674ac
commit b640ab1689
@@ -1,7 +1,7 @@
 #include "cuda/cuda_common.h"
 
-#define BLOCK_DIM_x 2
-#define BLOCK_DIM_y 2
+#define BLOCK_DIM_x 8 // BLOCK_DIM_x must <= 32
+#define BLOCK_DIM_y 128
 #define max_function(a, b) ((a) > (b) ? (a) : (b))
 
 __global__ void _attentionKernel(const float *inputQ, const float *inputK,
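Note on the new tile sizes: with BLOCK_DIM_y raised to 128, the 1024-threads-per-block limit of current CUDA devices already caps BLOCK_DIM_x at 8 (8 * 128 = 1024); the "must <= 32" comment presumably also keeps the x-dimension strip reduction within one warp's width. A hypothetical compile-time guard, not part of this commit, could make both constraints explicit:

    // Hypothetical guards (not in this commit) mirroring the stated
    // constraints; a CUDA thread block may hold at most 1024 threads.
    static_assert(BLOCK_DIM_x <= 32, "BLOCK_DIM_x must <= 32");
    static_assert(BLOCK_DIM_x * BLOCK_DIM_y <= 1024,
                  "block exceeds 1024 threads");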
@@ -10,84 +10,117 @@ __global__ void _attentionKernel(const float *inputQ, const float *inputK,
     int i = blockIdx.x;                              // i must < N,Q[i]
     int phd = threadIdx.y + blockIdx.y * blockDim.y; // V[:,d]
     int phNumN = (N + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
-    __shared__ float old_max;
-    __shared__ float new_max;
-    __shared__ float new_sum;
-    old_max = -__FLT_MAX__;
-    new_max = -__FLT_MAX__;
-    new_sum = 0.0f;
-    __shared__ float block_sum[BLOCK_DIM_x];
-    __shared__ float block_max[BLOCK_DIM_x];
-    block_max[threadIdx.x] = -__FLT_MAX__;
-    block_sum[threadIdx.x] = 0.0f;
+    __shared__ float old_max[BLOCK_DIM_x][BLOCK_DIM_y];
+    __shared__ float new_max[BLOCK_DIM_x][BLOCK_DIM_y];
+    __shared__ float new_sum[BLOCK_DIM_x][BLOCK_DIM_y];
+    old_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
+    new_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
+    new_sum[threadIdx.x][threadIdx.y] = 0.0f;
+    __shared__ float block_sum[BLOCK_DIM_x][BLOCK_DIM_y];
+    __shared__ float block_max[BLOCK_DIM_x][BLOCK_DIM_y];
+    block_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
+    block_sum[threadIdx.x][threadIdx.y] = 0.0f;
 
-    __shared__ float inputS[BLOCK_DIM_x];
+    __shared__ float inputS[BLOCK_DIM_x][BLOCK_DIM_y];
 
-    output[i * d + phd] = 0.0f;
+    __syncthreads();
     for (int phn = 0; phn < phNumN; phn++) {
         int j = threadIdx.x + phn * BLOCK_DIM_x;
-        if (j < N) {
+        inputS[threadIdx.x][threadIdx.y] = 0.0f;
+        block_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
+        block_sum[threadIdx.x][threadIdx.y] = 0.0f;
+
+        if (j < N && phd < d) {
             float sum_s = 0;
             for (int index = 0; index < d; index++) {
                 sum_s += inputQ[i * d + index] * inputK[j * d + index];
             }
-            inputS[threadIdx.x] = sum_s;
-            block_max[threadIdx.x] = sum_s;
-            block_sum[threadIdx.x] = 1.0f;
-        } else {
-            inputS[threadIdx.x] = 0.0f;
-            block_max[threadIdx.x] = -__FLT_MAX__;
-            block_sum[threadIdx.x] = 0.0f;
+            inputS[threadIdx.x][threadIdx.y] = sum_s;
+            block_max[threadIdx.x][threadIdx.y] = sum_s;
+            block_sum[threadIdx.x][threadIdx.y] = 1.0f;
         }
 
         __syncthreads();
         for (int strip = BLOCK_DIM_x / 2; strip > 0; strip = strip / 2) {
             if (threadIdx.x < strip) {
-                if (block_max[threadIdx.x] > block_max[threadIdx.x + strip]) {
-                    block_sum[threadIdx.x] =
-                        block_sum[threadIdx.x] +
-                        block_sum[threadIdx.x + strip] *
-                            __expf(block_max[threadIdx.x + strip] -
-                                   block_max[threadIdx.x]);
+                if (block_max[threadIdx.x][threadIdx.y] >
+                    block_max[threadIdx.x + strip][threadIdx.y]) {
+                    block_sum[threadIdx.x][threadIdx.y] =
+                        block_sum[threadIdx.x][threadIdx.y] +
+                        block_sum[threadIdx.x + strip][threadIdx.y] *
+                            __expf(block_max[threadIdx.x + strip][threadIdx.y] -
+                                   block_max[threadIdx.x][threadIdx.y]);
                 } else {
-                    block_sum[threadIdx.x] =
-                        block_sum[threadIdx.x + strip] +
-                        block_sum[threadIdx.x] *
-                            __expf(block_max[threadIdx.x] -
-                                   block_max[threadIdx.x + strip]);
-                    block_max[threadIdx.x] = block_max[threadIdx.x + strip];
+                    block_sum[threadIdx.x][threadIdx.y] =
+                        block_sum[threadIdx.x + strip][threadIdx.y] +
+                        block_sum[threadIdx.x][threadIdx.y] *
+                            __expf(block_max[threadIdx.x][threadIdx.y] -
+                                   block_max[threadIdx.x + strip][threadIdx.y]);
+                    block_max[threadIdx.x][threadIdx.y] =
+                        block_max[threadIdx.x + strip][threadIdx.y];
                 }
             }
+            __syncthreads();
         }
         __syncthreads();
-        if (threadIdx.x == 0) {
-            if (new_max > block_max[0]) {
-                new_sum =
-                    new_sum + block_sum[0] * __expf(block_max[0] - new_max);
+        if (j < N && phd < d) {
+            if (new_max[threadIdx.x][threadIdx.y] > block_max[0][threadIdx.y]) {
+                new_sum[threadIdx.x][threadIdx.y] =
+                    new_sum[threadIdx.x][threadIdx.y] +
+                    block_sum[0][threadIdx.y] *
+                        __expf(block_max[0][threadIdx.y] -
+                               new_max[threadIdx.x][threadIdx.y]);
             } else {
-                new_sum =
-                    block_sum[0] + new_sum * __expf(new_max - block_max[0]);
-                new_max = block_max[0];
+                new_sum[threadIdx.x][threadIdx.y] =
+                    block_sum[0][threadIdx.y] +
+                    new_sum[threadIdx.x][threadIdx.y] *
+                        __expf(new_max[threadIdx.x][threadIdx.y] -
+                               block_max[0][threadIdx.y]);
+                new_max[threadIdx.x][threadIdx.y] = block_max[0][threadIdx.y];
             }
         }
 
         __syncthreads();
-        inputS[threadIdx.x] = __expf(inputS[threadIdx.x] - new_max);
+        if (j < N && phd < d) {
+            inputS[threadIdx.x][threadIdx.y] =
+                __expf(inputS[threadIdx.x][threadIdx.y] -
+                       new_max[threadIdx.x][threadIdx.y]);
+        } else {
+            inputS[threadIdx.x][threadIdx.y] = 0.0f;
+        }
         __syncthreads();
-        float sum_o = 0;
         if (phd < d) {
+            float sum_o = 0.0f;
             for (int index = 0; index < BLOCK_DIM_x; index++) {
                 if (index + phn * BLOCK_DIM_x < N) {
-                    sum_o += inputS[index] *
+                    sum_o += inputS[index][threadIdx.y] *
                              inputV[(index + phn * BLOCK_DIM_x) * d + phd];
                 }
             }
-            output[i * d + phd] =
-                __expf(old_max - new_max) * output[i * d + phd] + sum_o;
-            old_max = new_max;
+            if (phn == 0) {
+                output[i * d + phd] = sum_o;
+            } else {
+                output[i * d + phd] =
+                    __expf(old_max[threadIdx.x][threadIdx.y] -
+                           new_max[threadIdx.x][threadIdx.y]) *
+                        output[i * d + phd] +
+                    sum_o;
+            }
+
+            old_max[threadIdx.x][threadIdx.y] =
+                new_max[threadIdx.x][threadIdx.y];
+        } else {
+            old_max[threadIdx.x][threadIdx.y] = -__FLT_MAX__;
         }
-        //__syncthreads();
+        __syncthreads();
     }
+    __syncthreads();
     if (phd < d)
-        output[i * d + phd] = output[i * d + phd] * __fdividef(1.0F, new_sum);
+        output[i * d + phd] =
+            output[i * d + phd] *
+            __fdividef(1.0F, new_sum[threadIdx.x][threadIdx.y]);
 }
 namespace infini {
 void attentionKernel(const float *inputQ, const float *inputK,
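The rewritten kernel keeps one running (max, sum) pair per (threadIdx.x, threadIdx.y) slot and folds each tile of scores in with the usual online-softmax rescaling, as the strip reduction above shows. A host-side sketch of that merge rule, with illustrative names and expf in place of the device-only __expf:

    #include <cmath>

    struct MaxSum {
        float m; // running maximum of the scores seen so far
        float s; // running sum of exp(score - m)
    };

    // Keep the larger maximum; rescale the other side's sum onto it.
    static MaxSum merge(MaxSum a, MaxSum b) {
        if (a.m > b.m)
            return {a.m, a.s + b.s * std::exp(b.m - a.m)};
        return {b.m, b.s + a.s * std::exp(a.m - b.m)};
    }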
@@ -97,7 +130,8 @@ void attentionKernel(const float *inputQ, const float *inputK,
     int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
     dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
     dim3 grid_dim(num_block_x, num_block_y, 1);
-    int share_mem = (3 * BLOCK_DIM_x + 3) * sizeof(float);
+    int share_mem =
+        (3 * BLOCK_DIM_x + 3 * BLOCK_DIM_x) * BLOCK_DIM_y * sizeof(float);
     _attentionKernel<<<grid_dim, block_dim, share_mem>>>(inputQ, inputK, inputV,
                                                          N, d, output);
 }
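One observation on the launch configuration: the new share_mem formula reserves (3 + 3) * BLOCK_DIM_x * BLOCK_DIM_y floats, matching the six statically sized __shared__ arrays, but the kernel declares no extern __shared__ buffer, so the dynamic allocation appears unused and could likely be dropped. For reference, a hedged sketch of how the wrapper might be driven, assuming row-major [N, d] tensors and the parameter order the kernel launch suggests (inputV, N, d, output):

    #include <cuda_runtime.h>

    // Illustrative driver, not part of the repository.
    void runAttention(const float *hQ, const float *hK, const float *hV,
                      float *hOut, int N, int d) {
        size_t bytes = (size_t)N * d * sizeof(float);
        float *dQ, *dK, *dV, *dO;
        cudaMalloc(&dQ, bytes);
        cudaMalloc(&dK, bytes);
        cudaMalloc(&dV, bytes);
        cudaMalloc(&dO, bytes);
        cudaMemcpy(dQ, hQ, bytes, cudaMemcpyHostToDevice);
        cudaMemcpy(dK, hK, bytes, cudaMemcpyHostToDevice);
        cudaMemcpy(dV, hV, bytes, cudaMemcpyHostToDevice);
        // Parameter order assumed from the kernel launch in this diff.
        infini::attentionKernel(dQ, dK, dV, N, d, dO);
        cudaMemcpy(hOut, dO, bytes, cudaMemcpyDeviceToHost);
        cudaFree(dQ);
        cudaFree(dK);
        cudaFree(dV);
        cudaFree(dO);
    }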