forked from jiuyuan/InfiniTensor
modified reduce,8ms
This commit is contained in:
parent
819484eda2
commit
56e2c87c9b
|
@ -11,88 +11,75 @@ __launch_bounds__(BLOCK_DIM_y) __global__
|
|||
int i = blockIdx.x; // i must < N,Q[i]
|
||||
int phd = threadIdx.y + blockIdx.y * blockDim.y; // V[:,d]
|
||||
|
||||
__shared__ float old_max[BLOCK_DIM_y];
|
||||
__shared__ float new_max[BLOCK_DIM_y];
|
||||
__shared__ float new_sum[BLOCK_DIM_y];
|
||||
old_max[threadIdx.y] = -__FLT_MAX__;
|
||||
new_max[threadIdx.y] = -__FLT_MAX__;
|
||||
new_sum[threadIdx.y] = 0.0f;
|
||||
float old_max = -__FLT_MAX__;
|
||||
float new_max = -__FLT_MAX__;
|
||||
float new_sum = 0.0f;
|
||||
|
||||
__shared__ float shareV[BLOCK_DIM_y];
|
||||
__shared__ float out[BLOCK_DIM_y];
|
||||
|
||||
int phNumD = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
int extra = d % BLOCK_DIM_y;
|
||||
int step = (d - extra) / BLOCK_DIM_y;
|
||||
__shared__ float shareQ_times_K[BLOCK_DIM_y];
|
||||
|
||||
for (int phn = 0; phn < N; phn++) {
|
||||
shareV[threadIdx.y] = 0.0f;
|
||||
|
||||
shareQ_times_K[threadIdx.y] = 0.0f;
|
||||
float sum_s = 0.0f;
|
||||
for (int ind = 0; ind < phNumD; ind++) {
|
||||
if (threadIdx.y + ind * BLOCK_DIM_y < d) {
|
||||
shareQ_times_K[threadIdx.y] =
|
||||
inputQ[i * d + threadIdx.y + ind * BLOCK_DIM_y] *
|
||||
inputK[phn * d + threadIdx.y + ind * BLOCK_DIM_y];
|
||||
|
||||
} else {
|
||||
shareQ_times_K[threadIdx.y] = 0.0f;
|
||||
if (threadIdx.y < extra) {
|
||||
for (int ind = threadIdx.y * (step + 1);
|
||||
ind < (threadIdx.y + 1) * (step + 1); ind++) {
|
||||
shareQ_times_K[threadIdx.y] +=
|
||||
inputQ[i * d + ind] * inputK[phn * d + ind];
|
||||
}
|
||||
__syncthreads();
|
||||
for (int strip = BLOCK_DIM_y / 2; strip > 0; strip = strip / 2) {
|
||||
if (threadIdx.y < strip) {
|
||||
shareQ_times_K[threadIdx.y] +=
|
||||
shareQ_times_K[threadIdx.y + strip];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
sum_s += shareQ_times_K[0];
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
shareQ_times_K[threadIdx.y] = sum_s;
|
||||
|
||||
if (phd < d) {
|
||||
shareV[threadIdx.y] = inputV[phn * d + phd];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (new_max[threadIdx.y] > sum_s) {
|
||||
new_sum[threadIdx.y] =
|
||||
new_sum[threadIdx.y] + __expf(sum_s - new_max[threadIdx.y]);
|
||||
} else {
|
||||
new_sum[threadIdx.y] =
|
||||
1.0f +
|
||||
new_sum[threadIdx.y] * __expf(new_max[threadIdx.y] - sum_s);
|
||||
new_max[threadIdx.y] = sum_s;
|
||||
for (int ind = extra * (step + 1) + (threadIdx.y - extra) * step;
|
||||
ind < extra * (step + 1) + (threadIdx.y - extra + 1) * step;
|
||||
ind++) {
|
||||
shareQ_times_K[threadIdx.y] +=
|
||||
inputQ[i * d + ind] * inputK[phn * d + ind];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
for (int strip = BLOCK_DIM_y / 8; strip > 0; strip = strip / 8) {
|
||||
if (threadIdx.y < strip) {
|
||||
for (int id = 1; id < 8; id++) {
|
||||
shareQ_times_K[threadIdx.y] +=
|
||||
shareQ_times_K[threadIdx.y + id * strip];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
sum_s = shareQ_times_K[0] + shareQ_times_K[1];
|
||||
//__syncthreads();
|
||||
|
||||
shareQ_times_K[threadIdx.y] =
|
||||
__expf(shareQ_times_K[threadIdx.y] - new_max[threadIdx.y]);
|
||||
if (new_max > sum_s) {
|
||||
new_sum = new_sum + __expf(sum_s - new_max);
|
||||
} else {
|
||||
new_sum = 1.0f + new_sum * __expf(new_max - sum_s);
|
||||
new_max = sum_s;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
//__syncthreads();
|
||||
|
||||
sum_s = __expf(sum_s - new_max);
|
||||
|
||||
//__syncthreads();
|
||||
|
||||
if (phn == 0) {
|
||||
out[threadIdx.y] =
|
||||
shareQ_times_K[threadIdx.y] * shareV[threadIdx.y];
|
||||
out[threadIdx.y] = sum_s * inputV[phn * d + phd];
|
||||
|
||||
} else {
|
||||
out[threadIdx.y] =
|
||||
__expf(old_max[threadIdx.y] - new_max[threadIdx.y]) *
|
||||
out[threadIdx.y] +
|
||||
shareQ_times_K[threadIdx.y] * shareV[threadIdx.y];
|
||||
out[threadIdx.y] = __expf(old_max - new_max) * out[threadIdx.y] +
|
||||
sum_s * inputV[phn * d + phd];
|
||||
}
|
||||
|
||||
old_max[threadIdx.y] = new_max[threadIdx.y];
|
||||
old_max = new_max;
|
||||
|
||||
__syncthreads();
|
||||
//__syncthreads();
|
||||
}
|
||||
__syncthreads();
|
||||
//__syncthreads();
|
||||
if (phd < d)
|
||||
output[i * d + phd] =
|
||||
out[threadIdx.y] * __fdividef(1.0F, new_sum[threadIdx.y]);
|
||||
output[i * d + phd] = out[threadIdx.y] * __fdividef(1.0F, new_sum);
|
||||
}
|
||||
namespace infini {
|
||||
void attentionKernel(const float *inputQ, const float *inputK,
|
||||
|
@ -100,48 +87,20 @@ void attentionKernel(const float *inputQ, const float *inputK,
|
|||
|
||||
int num_block_x = N;
|
||||
|
||||
if (d > 1023) {
|
||||
if (d > 128) {
|
||||
int BLOCK_DIM_y = 1024;
|
||||
int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(1, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, num_block_y, 1);
|
||||
_attentionKernel<1024>
|
||||
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
|
||||
} else if (d > 511) {
|
||||
int BLOCK_DIM_y = 512;
|
||||
int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(1, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, num_block_y, 1);
|
||||
_attentionKernel<512>
|
||||
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
|
||||
} else if (d > 255) {
|
||||
int BLOCK_DIM_y = 256;
|
||||
int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(1, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, num_block_y, 1);
|
||||
_attentionKernel<256>
|
||||
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
|
||||
} else if (d > 127) {
|
||||
} else if (d > 16) {
|
||||
int BLOCK_DIM_y = 128;
|
||||
int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(1, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, num_block_y, 1);
|
||||
_attentionKernel<128>
|
||||
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
|
||||
} else if (d > 63) {
|
||||
int BLOCK_DIM_y = 64;
|
||||
int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(1, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, num_block_y, 1);
|
||||
_attentionKernel<64>
|
||||
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
|
||||
} else if (d > 31) {
|
||||
int BLOCK_DIM_y = 32;
|
||||
int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
dim3 block_dim(1, BLOCK_DIM_y, 1);
|
||||
dim3 grid_dim(num_block_x, num_block_y, 1);
|
||||
_attentionKernel<32>
|
||||
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
|
||||
} else {
|
||||
int BLOCK_DIM_y = 16;
|
||||
int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
|
||||
|
|
Loading…
Reference in New Issue