Switched the attention kernel's thread indexing from threadIdx.y to threadIdx.x

This commit is contained in:
xgqdut2016 2023-10-08 11:33:16 +08:00
parent 56e2c87c9b
commit 79dd3364df
1 changed file with 34 additions and 34 deletions

View File

@@ -2,49 +2,49 @@
#define max_function(a, b) ((a) > (b) ? (a) : (b))
template <int BLOCK_DIM_y>
__launch_bounds__(BLOCK_DIM_y) __global__
template <int BLOCK_DIM_x>
__launch_bounds__(BLOCK_DIM_x) __global__
void _attentionKernel(const float *__restrict inputQ,
const float *__restrict inputK,
const float *__restrict inputV, int N, int d,
float *__restrict output) {
int i = blockIdx.x; // i must < N,Q[i]
int phd = threadIdx.y + blockIdx.y * blockDim.y; // V[:,d]
int i = blockIdx.y; // i must < N,Q[i]
int phd = threadIdx.x + blockIdx.x * blockDim.x; // V[:,d]
float old_max = -__FLT_MAX__;
float new_max = -__FLT_MAX__;
float new_sum = 0.0f;
__shared__ float out[BLOCK_DIM_y];
__shared__ float out[BLOCK_DIM_x];
int extra = d % BLOCK_DIM_y;
int step = (d - extra) / BLOCK_DIM_y;
__shared__ float shareQ_times_K[BLOCK_DIM_y];
int extra = d % BLOCK_DIM_x;
int step = (d - extra) / BLOCK_DIM_x;
__shared__ float shareQ_times_K[BLOCK_DIM_x];
for (int phn = 0; phn < N; phn++) {
shareQ_times_K[threadIdx.y] = 0.0f;
shareQ_times_K[threadIdx.x] = 0.0f;
float sum_s = 0.0f;
if (threadIdx.y < extra) {
for (int ind = threadIdx.y * (step + 1);
ind < (threadIdx.y + 1) * (step + 1); ind++) {
shareQ_times_K[threadIdx.y] +=
if (threadIdx.x < extra) {
for (int ind = threadIdx.x * (step + 1);
ind < (threadIdx.x + 1) * (step + 1); ind++) {
shareQ_times_K[threadIdx.x] +=
inputQ[i * d + ind] * inputK[phn * d + ind];
}
} else {
for (int ind = extra * (step + 1) + (threadIdx.y - extra) * step;
ind < extra * (step + 1) + (threadIdx.y - extra + 1) * step;
for (int ind = extra * (step + 1) + (threadIdx.x - extra) * step;
ind < extra * (step + 1) + (threadIdx.x - extra + 1) * step;
ind++) {
shareQ_times_K[threadIdx.y] +=
shareQ_times_K[threadIdx.x] +=
inputQ[i * d + ind] * inputK[phn * d + ind];
}
}
__syncthreads();
for (int strip = BLOCK_DIM_y / 8; strip > 0; strip = strip / 8) {
if (threadIdx.y < strip) {
for (int strip = BLOCK_DIM_x / 8; strip > 0; strip = strip / 8) {
if (threadIdx.x < strip) {
for (int id = 1; id < 8; id++) {
shareQ_times_K[threadIdx.y] +=
shareQ_times_K[threadIdx.y + id * strip];
shareQ_times_K[threadIdx.x] +=
shareQ_times_K[threadIdx.x + id * strip];
}
}
__syncthreads();
@@ -66,10 +66,10 @@ __launch_bounds__(BLOCK_DIM_y) __global__
//__syncthreads();
if (phn == 0) {
out[threadIdx.y] = sum_s * inputV[phn * d + phd];
out[threadIdx.x] = sum_s * inputV[phn * d + phd];
} else {
out[threadIdx.y] = __expf(old_max - new_max) * out[threadIdx.y] +
out[threadIdx.x] = __expf(old_max - new_max) * out[threadIdx.x] +
sum_s * inputV[phn * d + phd];
}
@@ -79,35 +79,35 @@ __launch_bounds__(BLOCK_DIM_y) __global__
}
//__syncthreads();
if (phd < d)
output[i * d + phd] = out[threadIdx.y] * __fdividef(1.0F, new_sum);
output[i * d + phd] = out[threadIdx.x] * __fdividef(1.0F, new_sum);
}
namespace infini {
void attentionKernel(const float *inputQ, const float *inputK,
const float *inputV, int N, int d, float *output) {
int num_block_x = N;
int num_block_y = N;
if (d > 128) {
int BLOCK_DIM_y = 1024;
int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(1, BLOCK_DIM_y, 1);
int BLOCK_DIM_x = 1024;
int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
dim3 block_dim(BLOCK_DIM_x, 1, 1);
dim3 grid_dim(num_block_x, num_block_y, 1);
_attentionKernel<1024>
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
} else if (d > 16) {
int BLOCK_DIM_y = 128;
int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(1, BLOCK_DIM_y, 1);
int BLOCK_DIM_x = 128;
int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
dim3 block_dim(BLOCK_DIM_x, 1, 1);
dim3 grid_dim(num_block_x, num_block_y, 1);
_attentionKernel<128>
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
} else {
int BLOCK_DIM_y = 16;
int num_block_y = (d + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
dim3 block_dim(1, BLOCK_DIM_y, 1);
int BLOCK_DIM_x = 16;
int num_block_x = (d + BLOCK_DIM_x - 1) / BLOCK_DIM_x;
dim3 block_dim(BLOCK_DIM_x, 1, 1);
dim3 grid_dim(num_block_x, num_block_y, 1);
_attentionKernel<16>
<<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
}
}
} // namespace infini
} // namespace infini