forked from jiuyuan/InfiniTensor
modified logic
This commit is contained in:
parent 3629881dfa
commit 54f4265296
@@ -1,123 +1,187 @@
 #include "cuda/cuda_common.h"
-#include <cub/block/block_reduce.cuh>
-#define max_function(a, b) ((a) > (b) ? (a) : (b))

-template <int BLOCK_DIM_x>
-__launch_bounds__(BLOCK_DIM_x) __global__
-    void _attentionKernel(const float *__restrict inputQ,
-                          const float *__restrict inputK,
-                          const float *__restrict inputV, int N, int d,
-                          float *__restrict output) {
+template <int BLOCK_DIM_x, int BLOCK_DIM_y>
+__global__ void _attentionKernel(const float *__restrict inputQ,
+                                 const float *__restrict inputK,
+                                 const float *__restrict inputV, int N, int d,
+                                 float *__restrict output) {
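The rewritten kernel uses a two-dimensional thread block: blockIdx.y picks the query row i, blockIdx.x together with threadIdx.x picks the output column phd, and threadIdx.y indexes one key inside a tile of BLOCK_DIM_y consecutive keys, so the main loop below walks over phNumN = ceil(N / BLOCK_DIM_y) key tiles instead of one key per iteration.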
     int i = blockIdx.y;                              // i must < N,Q[i]
     int phd = threadIdx.x + blockIdx.x * blockDim.x; // V[:,d]
 
-    float old_max = -__FLT_MAX__;
-    float new_max = -__FLT_MAX__;
-    float new_sum = 0.0f;
-
-    __shared__ float out[BLOCK_DIM_x];
+    int phNumN = (N + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
+    __shared__ float inputS[BLOCK_DIM_x][BLOCK_DIM_y];
+    float newMax;
+    float oldMax;
+    float newSum;
+    newMax = -__FLT_MAX__;
+    oldMax = -__FLT_MAX__;
+    newSum = 0.0f;
+
+    float out;
+    out = 0.0f;
+    //---------
+    __shared__ float block_sum[BLOCK_DIM_x][BLOCK_DIM_y];
+    __shared__ float sum_partial[BLOCK_DIM_x][BLOCK_DIM_y];
     int extra = d % BLOCK_DIM_x;
     int step = (d - extra) / BLOCK_DIM_x;
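Both versions split the d-long dot product Q[i]·K[j] across the BLOCK_DIM_x threads of a row: the first extra = d % BLOCK_DIM_x threads each take step + 1 consecutive elements and the remaining threads take step, so every element is covered exactly once. For example, d = 10 with BLOCK_DIM_x = 4 gives extra = 2, step = 2, and the four threads cover slices of length 3, 3, 2, 2.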
-    out[threadIdx.x] = 0.0f;
-    __shared__ float sum_s;
-    for (int phn = 0; phn < N; phn++) {
-        float sum_partial = 0.0f;
+    for (int phn = 0; phn < phNumN; phn++) {
+        int j = threadIdx.y + phn * BLOCK_DIM_y;
+        float sum_r = 0.0f;
+        __syncthreads();
         if (threadIdx.x < extra) {
             for (int ind = threadIdx.x * (step + 1);
                  ind < (threadIdx.x + 1) * (step + 1); ind++) {
-                sum_partial += inputQ[i * d + ind] * inputK[phn * d + ind];
+                sum_r += inputQ[i * d + ind] * inputK[j * d + ind];
             }
         } else {
             for (int ind = extra * (step + 1) + (threadIdx.x - extra) * step;
                  ind < extra * (step + 1) + (threadIdx.x - extra + 1) * step;
                  ind++) {
-                sum_partial += inputQ[i * d + ind] * inputK[phn * d + ind];
+                sum_r += inputQ[i * d + ind] * inputK[j * d + ind];
             }
         }
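At this point each thread holds a partial dot product for its slice. The old code reduced the partials across the block with cub::BlockReduce and broadcast the result through the shared scalar sum_s; the new code writes them into sum_partial[BLOCK_DIM_x][BLOCK_DIM_y] (one column per key of the tile, zeroed for keys past N) and reduces along threadIdx.x with the explicit strided tree below.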
-        typedef cub::BlockReduce<float, BLOCK_DIM_x> BlockReduce;
-        __shared__ typename BlockReduce::TempStorage temp_storage;
-        float block_sum =
-            BlockReduce(temp_storage).Reduce(sum_partial, cub::Sum());
-        if (threadIdx.x == 0)
-            sum_s = block_sum;
-        __syncthreads();
-        if (new_max > sum_s) {
-            new_sum = new_sum + __expf(sum_s - new_max);
+        if (j < N) {
+            sum_partial[threadIdx.x][threadIdx.y] = sum_r;
         } else {
-            new_sum = 1.0f + new_sum * __expf(new_max - sum_s);
-            new_max = sum_s;
+            sum_partial[threadIdx.x][threadIdx.y] = 0.0f;
+        }
+        __syncthreads();
+        for (int strip = BLOCK_DIM_x / 2; strip > 0; strip /= 2) {
+            if (threadIdx.x < strip) {
+                sum_partial[threadIdx.x][threadIdx.y] +=
+                    sum_partial[threadIdx.x + strip][threadIdx.y];
+            }
+            __syncthreads();
+        }
+        float sum_s = sum_partial[0][threadIdx.y];
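sum_partial[0][threadIdx.y] now holds the full score Q[i]·K[j] for the tile's key j (masked to -__FLT_MAX__ below when j >= N). The next block folds the tile's BLOCK_DIM_y scores into a single (max, sum) pair in sum_partial[0][0] and block_sum[threadIdx.x][0], then merges that pair into the running (newMax, newSum) state.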
+        if (j < N) {
+            block_sum[threadIdx.x][threadIdx.y] = 1.0f;
+        } else {
+            sum_partial[0][threadIdx.y] = -__FLT_MAX__;
+            block_sum[threadIdx.x][threadIdx.y] = 0.0f;
+        }
+        __syncthreads();
+        for (int strip = BLOCK_DIM_y / 2; strip > 0; strip /= 2) {
+            if (threadIdx.y < strip) {
+                if (sum_partial[0][threadIdx.y] >
+                    sum_partial[0][threadIdx.y + strip]) {
+                    block_sum[threadIdx.x][threadIdx.y] =
+                        block_sum[threadIdx.x][threadIdx.y] +
+                        block_sum[threadIdx.x][threadIdx.y + strip] *
+                            __expf(sum_partial[0][threadIdx.y + strip] -
+                                   sum_partial[0][threadIdx.y]);
+                } else {
+                    block_sum[threadIdx.x][threadIdx.y] =
+                        block_sum[threadIdx.x][threadIdx.y + strip] +
+                        block_sum[threadIdx.x][threadIdx.y] *
+                            __expf(sum_partial[0][threadIdx.y] -
+                                   sum_partial[0][threadIdx.y + strip]);
+                    sum_partial[0][threadIdx.y] =
+                        sum_partial[0][threadIdx.y + strip];
+                }
+            }
+            __syncthreads();
+        }
+        if (newMax > sum_partial[0][0]) {
+            newSum = newSum + block_sum[threadIdx.x][0] *
+                                  __expf(sum_partial[0][0] - newMax);
+        } else {
+            newSum = block_sum[threadIdx.x][0] +
+                     newSum * __expf(newMax - sum_partial[0][0]);
+            newMax = sum_partial[0][0];
         }
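Both that pairwise fold and the running update use the standard rule for merging two partial (max, sum) pairs of a numerically stable softmax. A minimal sketch of the rule, written as ordinary host code (illustrative only, not part of the commit):

#include <math.h>

struct MaxSum {
    float max; // maximum of the scores seen so far
    float sum; // sum of exp(score - max) over the scores seen so far
};

// Merge two partial (max, sum) pairs into one, as the kernel does above.
static MaxSum merge(MaxSum a, MaxSum b) {
    if (a.max > b.max)
        return {a.max, a.sum + b.sum * expf(b.max - a.max)};
    return {b.max, b.sum + a.sum * expf(a.max - b.max)};
}

The merge is associative up to rounding, which is what lets the kernel combine scores pairwise inside a tile and then fold each tile's result into the running state.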
-        sum_s = __expf(sum_s - new_max);
-        out[threadIdx.x] = __expf(old_max - new_max) * out[threadIdx.x] +
-                           sum_s * inputV[phn * d + phd];
-        old_max = new_max;
+        if (j < N && phd < d) {
+            inputS[threadIdx.x][threadIdx.y] =
+                __expf(sum_s - newMax) *
+                inputV[(threadIdx.y + phn * BLOCK_DIM_y) * d + phd];
+        } else {
+            inputS[threadIdx.x][threadIdx.y] = 0.0f;
+        }
+        __syncthreads();
+        for (int strip = BLOCK_DIM_y / 2; strip > 0; strip /= 2) {
+            if (threadIdx.y < strip) {
+                inputS[threadIdx.x][threadIdx.y] +=
+                    inputS[threadIdx.x][threadIdx.y + strip];
+            }
+            __syncthreads();
+        }
+        if (j < N && phd < d) {
+            out = __expf(oldMax - newMax) * out + inputS[threadIdx.x][0];
+        }
+        oldMax = newMax;
     }
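For the value accumulation, each thread weights its key's inputV[j * d + phd] by exp(score - newMax), the tile's BLOCK_DIM_y weighted values are summed over threadIdx.y with another strided reduction into inputS[threadIdx.x][0], and the running output out is rescaled by exp(oldMax - newMax) before the tile's contribution is added, the same rescaling the old code applied to out[threadIdx.x] one key at a time.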
-    if (phd < d)
-        output[i * d + phd] = out[threadIdx.x] * __fdividef(1.0F, new_sum);
+    if (threadIdx.y + (phNumN - 1) * BLOCK_DIM_y < N && phd < d) {
+        output[i * d + phd] = out * __fdividef(1.0F, newSum);
+    }
 }
 namespace infini {
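In the host wrapper below, the d-based ladder that picks BLOCK_DIM_x is unchanged; each branch now also sets a BLOCK_DIM_y so that BLOCK_DIM_x * BLOCK_DIM_y = 1024 threads per block (1024x1, 512x2, 256x4, 128x8, 64x16, 32x32, 16x64), passes it as the second template argument, and launches a two-dimensional block. Note that the d > 32 branch previously instantiated _attentionKernel<512> while launching 64 threads in x; it now instantiates _attentionKernel<64, 16>, matching the launched block shape.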
 void attentionKernel(const float *inputQ, const float *inputK,
                      const float *inputV, int N, int d, float *output) {
     int num_block_y = N;
     if (d > 512) {
         int BLOCK_DIM_x = 1024;
+        int BLOCK_DIM_y = 1;
         int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
-        dim3 block_dim(BLOCK_DIM_x, 1, 1);
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, num_block_y, 1);
-        _attentionKernel<1024>
+        _attentionKernel<1024, 1>
             <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
     } else if (d > 256) {
         int BLOCK_DIM_x = 512;
+        int BLOCK_DIM_y = 2;
         int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
-        dim3 block_dim(BLOCK_DIM_x, 1, 1);
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, num_block_y, 1);
-        _attentionKernel<512>
+        _attentionKernel<512, 2>
             <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
     } else if (d > 128) {
         int BLOCK_DIM_x = 256;
+        int BLOCK_DIM_y = 4;
         int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
-        dim3 block_dim(BLOCK_DIM_x, 1, 1);
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, num_block_y, 1);
-        _attentionKernel<256>
+        _attentionKernel<256, 4>
             <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
     } else if (d > 64) {
         int BLOCK_DIM_x = 128;
+        int BLOCK_DIM_y = 8;
         int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
-        dim3 block_dim(BLOCK_DIM_x, 1, 1);
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, num_block_y, 1);
-        _attentionKernel<128>
+        _attentionKernel<128, 8>
             <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
     } else if (d > 32) {
         int BLOCK_DIM_x = 64;
+        int BLOCK_DIM_y = 16;
         int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
-        dim3 block_dim(BLOCK_DIM_x, 1, 1);
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, num_block_y, 1);
-        _attentionKernel<512>
+        _attentionKernel<64, 16>
             <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
     } else if (d > 16) {
         int BLOCK_DIM_x = 32;
+        int BLOCK_DIM_y = 32;
         int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
-        dim3 block_dim(BLOCK_DIM_x, 1, 1);
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, num_block_y, 1);
-        _attentionKernel<32>
+        _attentionKernel<32, 32>
             <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
     } else {
         int BLOCK_DIM_x = 16;
+        int BLOCK_DIM_y = 64;
         int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
-        dim3 block_dim(BLOCK_DIM_x, 1, 1);
+        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
         dim3 grid_dim(num_block_x, num_block_y, 1);
-        _attentionKernel<16>
+        _attentionKernel<16, 64>
             <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
     }
 }
 } // namespace infini
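For reference, both the old and the new kernel compute output = softmax(Q Kᵀ) V for a single head, with no 1/sqrt(d) scaling of the scores. A straightforward CPU version of the same computation is sketched below; it is not part of the commit, and the function name, the use of std::vector, and the row-major N x d layouts are assumptions made for illustration. Comparing its result against the kernel's output (within floating-point tolerance) is a simple way to exercise either code path.

#include <cmath>
#include <vector>

// Naive single-head attention:
//   out[i, :] = sum_j softmax_j(Q[i, :] . K[j, :]) * V[j, :]
// Q, K, V and out are row-major N x d arrays; out must already have size N * d.
// Like the kernel, the scores are not scaled by 1/sqrt(d).
void attentionReference(const std::vector<float> &Q,
                        const std::vector<float> &K,
                        const std::vector<float> &V, int N, int d,
                        std::vector<float> &out) {
    std::vector<float> scores(N);
    for (int i = 0; i < N; ++i) {
        // Scores for row i and their maximum (for a stable softmax).
        float maxScore = -INFINITY;
        for (int j = 0; j < N; ++j) {
            float s = 0.0f;
            for (int k = 0; k < d; ++k)
                s += Q[i * d + k] * K[j * d + k];
            scores[j] = s;
            maxScore = std::fmax(maxScore, s);
        }
        // Softmax numerators and denominator.
        float denom = 0.0f;
        for (int j = 0; j < N; ++j) {
            scores[j] = std::exp(scores[j] - maxScore);
            denom += scores[j];
        }
        // Weighted sum of V rows.
        for (int k = 0; k < d; ++k) {
            float acc = 0.0f;
            for (int j = 0; j < N; ++j)
                acc += scores[j] * V[j * d + k];
            out[i * d + k] = acc / denom;
        }
    }
}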