This commit is contained in:
xiaonans 2023-11-28 16:29:48 +08:00
parent 2fb1c8cf32
commit 815d0ebf44
1 changed files with 52 additions and 463 deletions

View File

@ -4,7 +4,7 @@
#define BLOCKSIZE WARP_SIZE #define BLOCKSIZE WARP_SIZE
#define SEQ_UNIT 32 #define SEQ_UNIT 32
__global__ void _attention_kvcache_kernel(float* input_k_cache, __global__ void _attention_kvcache_kernel_64(float* input_k_cache,
float* input_v_cache, float* input_v_cache,
float* input_q, float* input_q,
float* input_k, float* input_k,
@ -113,415 +113,42 @@ __global__ void _attention_kvcache_kernel(float* input_k_cache,
(float2 &)output_matmul[(lane_id * 2) + (parallel_idx * compMeta.dimSize[3])] = (float2 &)ptr_O[0]; (float2 &)output_matmul[(lane_id * 2) + (parallel_idx * compMeta.dimSize[3])] = (float2 &)ptr_O[0];
} }
__global__ void _attention_kvcache_kernel_128(float* input_k_cache, __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
float* input_v_cache, float* input_v_cache,
float* input_q, float* input_q,
float* input_k, float* input_k,
float* input_v, float* input_v,
int* position_id, int* position_id,
float* output_matmul,
AttentionKVCacheMetadata compMeta) {
int lane_id = threadIdx.x % WARP_SIZE;
int group_id = threadIdx.x / WARP_SIZE;
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
return;
float ptr_V[SEQ_UNIT*4];
float ptr_K[SEQ_UNIT*4];
float ptr_Q[4];
float ptr_P[SEQ_UNIT];
float ptr_O[4];
float ptr_max[1];
float ptr_sum[1];
float ptr_max_last[1];
float ptr_sum_last[1];
float ptr_O_last[4];
(float4 &)ptr_Q[0] = (float4 &)input_q[(lane_id * 4) + (parallel_idx * 128)];
int SEQ_LENGTH = position_id[0] + 1;
int common_idx = (lane_id * 4) + (parallel_idx * compMeta.stride[1]);
for (int idx_seq = 0; idx_seq < SEQ_LENGTH; idx_seq += SEQ_UNIT){
ptr_max_last[0] = ptr_max[0];
ptr_sum_last[0] = ptr_sum[0];
(float4 &)ptr_O_last[0] = (float4 &)ptr_O[0];
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
(float4 &)ptr_K[idx_SEQ_UNIT * 4]
= (float4 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
}
else{
(float4 &)ptr_K[idx_SEQ_UNIT * 4]
= (float4 &) input_k[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
(float4 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
(float4 &)ptr_K[idx_SEQ_UNIT * 4];
}
ptr_K[idx_SEQ_UNIT * 4] = ptr_Q[0] * ptr_K[idx_SEQ_UNIT * 4];
ptr_K[idx_SEQ_UNIT * 4 + 1] = ptr_Q[1] * ptr_K[idx_SEQ_UNIT * 4 + 1];
ptr_K[idx_SEQ_UNIT * 4 + 2] = ptr_Q[2] * ptr_K[idx_SEQ_UNIT * 4 + 2];
ptr_K[idx_SEQ_UNIT * 4 + 3] = ptr_Q[3] * ptr_K[idx_SEQ_UNIT * 4 + 3];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
ptr_K[idx_SEQ_UNIT * 4] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 4], offset);
}
ptr_P[idx_SEQ_UNIT] = ptr_K[idx_SEQ_UNIT * 4];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2){
ptr_K[((idx_SEQ_UNIT * 4) + 1)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 4) + 1)], offset);
}
ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 4) + 1)];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2){
ptr_K[((idx_SEQ_UNIT * 4) + 2)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 4) + 2)], offset);
}
ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 4) + 2)];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2){
ptr_K[((idx_SEQ_UNIT * 4) + 3)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 4) + 3)], offset);
}
ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 4) + 3)];
}
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
ptr_P[idx_SEQ_UNIT] /= sqrt(128.0);
ptr_max[0] = (idx_SEQ_UNIT == 0) ? ptr_P[0] : max(ptr_max[0], ptr_P[idx_SEQ_UNIT]);
}
ptr_max[0] = (idx_seq == 0) ? ptr_max[0] : max(ptr_max[0], ptr_max_last[0]);
ptr_sum[0] = 0;
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT] - ptr_max[0]);
ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
}
ptr_sum[0] = (idx_seq == 0) ? ptr_sum[0] : expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] + ptr_sum[0];
ptr_O[0] = 0;
ptr_O[1] = 0;
ptr_O[2] = 0;
ptr_O[3] = 0;
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
(float4 &)ptr_V[idx_SEQ_UNIT * 4]
= (float4 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
}
else{
(float4 &)ptr_V[idx_SEQ_UNIT * 4]
= (float4 &) input_v[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
(float4 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]
= (float4 &)ptr_V[idx_SEQ_UNIT * 4];
}
ptr_P[idx_SEQ_UNIT] /= ptr_sum[0];
ptr_O[0] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4)], ptr_O[0]);
ptr_O[1] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4) + 1], ptr_O[1]);
ptr_O[2] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4) + 2], ptr_O[2]);
ptr_O[3] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4) + 3], ptr_O[3]);
}
ptr_O[0] = (idx_seq == 0) ? ptr_O[0] : ptr_O[0] + ptr_O_last[0] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
ptr_O[1] = (idx_seq == 0) ? ptr_O[1] : ptr_O[1] + ptr_O_last[1] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
ptr_O[2] = (idx_seq == 0) ? ptr_O[2] : ptr_O[2] + ptr_O_last[2] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
ptr_O[3] = (idx_seq == 0) ? ptr_O[3] : ptr_O[3] + ptr_O_last[3] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
}
(float4 &)output_matmul[(lane_id * 4) + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O[0];
}
__global__ void _attention_kvcache_kernel_128_sum_only(float* input_k_cache,
float* input_v_cache,
float* input_q,
float* input_k,
float* input_v,
int* position_id,
float* output_matmul,
AttentionKVCacheMetadata compMeta) {
int lane_id = threadIdx.x % WARP_SIZE;
int group_id = threadIdx.x / WARP_SIZE;
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
return;
float ptr_V[SEQ_UNIT*4];
float ptr_K[SEQ_UNIT*4];
float ptr_Q[4];
float ptr_P[SEQ_UNIT];
float ptr_O[4];
float ptr_sum[1];
float ptr_sum_last[1];
float ptr_O_last[4];
(float4 &)ptr_Q[0] = (float4 &)input_q[(lane_id * 4) + (parallel_idx * 128)];
int SEQ_LENGTH = position_id[0] + 1;
int common_idx = (lane_id * 4) + (parallel_idx * compMeta.stride[1]);
for (int idx_seq = 0; idx_seq < SEQ_LENGTH; idx_seq += SEQ_UNIT){
ptr_sum_last[0] = ptr_sum[0];
(float4 &)ptr_O_last[0] = (float4 &)ptr_O[0];
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
(float4 &)ptr_K[idx_SEQ_UNIT * 4]
= (float4 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
}
else{
(float4 &)ptr_K[idx_SEQ_UNIT * 4]
= (float4 &) input_k[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
(float4 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
(float4 &)ptr_K[idx_SEQ_UNIT * 4];
}
ptr_K[idx_SEQ_UNIT * 4] = ptr_Q[0] * ptr_K[idx_SEQ_UNIT * 4];
ptr_K[idx_SEQ_UNIT * 4 + 1] = ptr_Q[1] * ptr_K[idx_SEQ_UNIT * 4 + 1];
ptr_K[idx_SEQ_UNIT * 4 + 2] = ptr_Q[2] * ptr_K[idx_SEQ_UNIT * 4 + 2];
ptr_K[idx_SEQ_UNIT * 4 + 3] = ptr_Q[3] * ptr_K[idx_SEQ_UNIT * 4 + 3];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
ptr_K[idx_SEQ_UNIT * 4] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 4], offset);
}
ptr_P[idx_SEQ_UNIT] = ptr_K[idx_SEQ_UNIT * 4];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2){
ptr_K[((idx_SEQ_UNIT * 4) + 1)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 4) + 1)], offset);
}
ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 4) + 1)];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2){
ptr_K[((idx_SEQ_UNIT * 4) + 2)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 4) + 2)], offset);
}
ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 4) + 2)];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2){
ptr_K[((idx_SEQ_UNIT * 4) + 3)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 4) + 3)], offset);
}
ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 4) + 3)];
}
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
ptr_P[idx_SEQ_UNIT] /= sqrt(128.0);
}
ptr_sum[0] = 0;
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT]);
ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
}
ptr_sum[0] = (idx_seq == 0) ? ptr_sum[0] : ptr_sum_last[0] + ptr_sum[0];
ptr_O[0] = 0;
ptr_O[1] = 0;
ptr_O[2] = 0;
ptr_O[3] = 0;
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
(float4 &)ptr_V[idx_SEQ_UNIT * 4]
= (float4 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
}
else{
(float4 &)ptr_V[idx_SEQ_UNIT * 4]
= (float4 &) input_v[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
(float4 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]
= (float4 &)ptr_V[idx_SEQ_UNIT * 4];
}
ptr_P[idx_SEQ_UNIT] /= ptr_sum[0];
ptr_O[0] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4)], ptr_O[0]);
ptr_O[1] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4) + 1], ptr_O[1]);
ptr_O[2] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4) + 2], ptr_O[2]);
ptr_O[3] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4) + 3], ptr_O[3]);
}
ptr_O[0] = (idx_seq == 0) ? ptr_O[0] : ptr_O[0] + ptr_O_last[0] * ptr_sum_last[0] / ptr_sum[0];
ptr_O[1] = (idx_seq == 0) ? ptr_O[1] : ptr_O[1] + ptr_O_last[1] * ptr_sum_last[0] / ptr_sum[0];
ptr_O[2] = (idx_seq == 0) ? ptr_O[2] : ptr_O[2] + ptr_O_last[2] * ptr_sum_last[0] / ptr_sum[0];
ptr_O[3] = (idx_seq == 0) ? ptr_O[3] : ptr_O[3] + ptr_O_last[3] * ptr_sum_last[0] / ptr_sum[0];
}
(float4 &)output_matmul[(lane_id * 4) + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O[0];
}
__global__ void _attention_kvcache_kernel_128_sum_only_cp(float* input_k_cache,
float* input_v_cache,
float* input_q,
float* input_k,
float* input_v,
int position_id,
float* output_matmul,
AttentionKVCacheMetadata compMeta) {
int lane_id = threadIdx.x % WARP_SIZE;
int group_id = threadIdx.x / WARP_SIZE;
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
return;
float ptr_V[SEQ_UNIT*4];
float ptr_K[SEQ_UNIT*4];
float ptr_Q[4];
float ptr_P[SEQ_UNIT];
float ptr_O[4];
float ptr_sum[1] = {0};
float ptr_O_last[4];
(float4 &)ptr_Q[0] = (float4 &)input_q[(lane_id * 4) + (parallel_idx * 128)];
int SEQ_LENGTH = position_id + 1;
int common_idx = (lane_id * 4) + (parallel_idx * compMeta.stride[1]);
for (int idx_seq = 0; idx_seq < SEQ_LENGTH; idx_seq += SEQ_UNIT){
(float4 &)ptr_O_last[0] = (float4 &)ptr_O[0];
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
(float4 &)ptr_K[idx_SEQ_UNIT * 4]
= (float4 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
}
else{
(float4 &)ptr_K[idx_SEQ_UNIT * 4]
= (float4 &) input_k[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
(float4 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
(float4 &)ptr_K[idx_SEQ_UNIT * 4];
}
ptr_K[idx_SEQ_UNIT * 4] = ptr_Q[0] * ptr_K[idx_SEQ_UNIT * 4];
ptr_K[idx_SEQ_UNIT * 4 + 1] = ptr_Q[1] * ptr_K[idx_SEQ_UNIT * 4 + 1];
ptr_K[idx_SEQ_UNIT * 4 + 2] = ptr_Q[2] * ptr_K[idx_SEQ_UNIT * 4 + 2];
ptr_K[idx_SEQ_UNIT * 4 + 3] = ptr_Q[3] * ptr_K[idx_SEQ_UNIT * 4 + 3];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
ptr_K[idx_SEQ_UNIT * 4] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 4], offset);
}
ptr_P[idx_SEQ_UNIT] = ptr_K[idx_SEQ_UNIT * 4];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2){
ptr_K[((idx_SEQ_UNIT * 4) + 1)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 4) + 1)], offset);
}
ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 4) + 1)];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2){
ptr_K[((idx_SEQ_UNIT * 4) + 2)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 4) + 2)], offset);
}
ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 4) + 2)];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2){
ptr_K[((idx_SEQ_UNIT * 4) + 3)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 4) + 3)], offset);
}
ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 4) + 3)];
}
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
ptr_P[idx_SEQ_UNIT] /= sqrt(128.0);
}
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT]);
ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
}
ptr_O[0] = 0;
ptr_O[1] = 0;
ptr_O[2] = 0;
ptr_O[3] = 0;
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
(float4 &)ptr_V[idx_SEQ_UNIT * 4]
= (float4 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
}
else{
(float4 &)ptr_V[idx_SEQ_UNIT * 4]
= (float4 &) input_v[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
(float4 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]
= (float4 &)ptr_V[idx_SEQ_UNIT * 4];
}
ptr_O[0] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4)], ptr_O[0]);
ptr_O[1] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4) + 1], ptr_O[1]);
ptr_O[2] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4) + 2], ptr_O[2]);
ptr_O[3] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4) + 3], ptr_O[3]);
}
ptr_O[0] = (idx_seq == 0) ? ptr_O[0] : ptr_O[0] + ptr_O_last[0];
ptr_O[1] = (idx_seq == 0) ? ptr_O[1] : ptr_O[1] + ptr_O_last[1];
ptr_O[2] = (idx_seq == 0) ? ptr_O[2] : ptr_O[2] + ptr_O_last[2];
ptr_O[3] = (idx_seq == 0) ? ptr_O[3] : ptr_O[3] + ptr_O_last[3];
}
ptr_O[0] = ptr_O[0] / ptr_sum[0];
ptr_O[1] = ptr_O[1] / ptr_sum[0];
ptr_O[2] = ptr_O[2] / ptr_sum[0];
ptr_O[3] = ptr_O[3] / ptr_sum[0];
(float4 &)output_matmul[(lane_id * 4) + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O[0];
}
__global__ void _attention_kvcache_kernel_128_sum_only_1(float* input_k_cache,
float* input_v_cache,
float* input_q,
float* input_k,
float* input_v,
int position_id,
float* output_matmul,
AttentionKVCacheMetadata compMeta, AttentionKVCacheMetadata compMeta,
float* output_O_temp, float* output_O_temp,
float* output_sum_temp) { float* output_sum_temp) {
int seq_length = position_id[0] + 1;
int stride = (seq_length + SEQ_UNIT - 1) / SEQ_UNIT;
if(blockIdx.y >= stride)
return;
int lane_id = threadIdx.x % WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE;
int group_id = threadIdx.x / WARP_SIZE; int group_id = threadIdx.x / WARP_SIZE;
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id; int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
int SEQ_LENGTH = position_id + 1;
int idx_seq = blockIdx.y * SEQ_UNIT; int idx_seq = blockIdx.y * SEQ_UNIT;
if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1] && idx_seq >= SEQ_LENGTH) if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
return; return;
float ptr_V[SEQ_UNIT*4]; float ptr_V[SEQ_UNIT*4];
float ptr_K[SEQ_UNIT*4]; float ptr_K[SEQ_UNIT*4];
float ptr_Q[4]; float ptr_Q[4];
float ptr_P[SEQ_UNIT]; float ptr_P[SEQ_UNIT] = {0};
float ptr_O[4]; float ptr_O[4] = {0};
float ptr_sum[1] = {0}; float ptr_sum[1] = {0};
(float4 &)ptr_Q[0] = (float4 &)input_q[(lane_id * 4) + (parallel_idx * 128)]; (float4 &)ptr_Q[0] = (float4 &)input_q[(lane_id * 4) + (parallel_idx * 128)];
int common_idx = (lane_id * 4) + (parallel_idx * compMeta.stride[1]); int common_idx = (lane_id * 4) + (parallel_idx * compMeta.stride[1]);
#pragma unroll #pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) { for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){ if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
(float4 &)ptr_K[idx_SEQ_UNIT * 4] (float4 &)ptr_K[idx_SEQ_UNIT * 4]
= (float4 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]; = (float4 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
} }
@ -531,55 +158,33 @@ __global__ void _attention_kvcache_kernel_128_sum_only_1(float* input_k_cache,
(float4 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] = (float4 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
(float4 &)ptr_K[idx_SEQ_UNIT * 4]; (float4 &)ptr_K[idx_SEQ_UNIT * 4];
} }
ptr_K[idx_SEQ_UNIT * 4] = ptr_Q[0] * ptr_K[idx_SEQ_UNIT * 4];
ptr_K[idx_SEQ_UNIT * 4 + 1] = ptr_Q[1] * ptr_K[idx_SEQ_UNIT * 4 + 1];
ptr_K[idx_SEQ_UNIT * 4 + 2] = ptr_Q[2] * ptr_K[idx_SEQ_UNIT * 4 + 2];
ptr_K[idx_SEQ_UNIT * 4 + 3] = ptr_Q[3] * ptr_K[idx_SEQ_UNIT * 4 + 3];
#pragma unroll #pragma unroll
for (int offset = 16; offset > 0; offset /= 2) { for (int i = 0; i < 4; i ++){
ptr_K[idx_SEQ_UNIT * 4] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 4], offset); ptr_K[idx_SEQ_UNIT * 4 + i] = ptr_Q[i] * ptr_K[idx_SEQ_UNIT * 4 + i];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
ptr_K[idx_SEQ_UNIT * 4 + i] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 4 + i], offset);
}
ptr_P[idx_SEQ_UNIT] += ptr_K[idx_SEQ_UNIT * 4 + i];
} }
ptr_P[idx_SEQ_UNIT] = ptr_K[idx_SEQ_UNIT * 4];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2){
ptr_K[((idx_SEQ_UNIT * 4) + 1)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 4) + 1)], offset);
}
ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 4) + 1)];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2){
ptr_K[((idx_SEQ_UNIT * 4) + 2)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 4) + 2)], offset);
}
ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 4) + 2)];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2){
ptr_K[((idx_SEQ_UNIT * 4) + 3)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 4) + 3)], offset);
}
ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 4) + 3)];
} }
#pragma unroll #pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) { for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0); ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
ptr_P[idx_SEQ_UNIT] /= sqrt(128.0); ptr_P[idx_SEQ_UNIT] /= sqrt(128.0);
} }
#pragma unroll #pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) { for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT]); ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT]);
ptr_sum[0] += ptr_P[idx_SEQ_UNIT]; ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
} }
ptr_O[0] = 0;
ptr_O[1] = 0;
ptr_O[2] = 0;
ptr_O[3] = 0;
#pragma unroll #pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) { for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){ if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
(float4 &)ptr_V[idx_SEQ_UNIT * 4] (float4 &)ptr_V[idx_SEQ_UNIT * 4]
= (float4 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]; = (float4 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
} }
@ -590,31 +195,23 @@ __global__ void _attention_kvcache_kernel_128_sum_only_1(float* input_k_cache,
= (float4 &)ptr_V[idx_SEQ_UNIT * 4]; = (float4 &)ptr_V[idx_SEQ_UNIT * 4];
} }
ptr_O[0] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4)], ptr_O[0]); #pragma unroll
ptr_O[1] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4) + 1], ptr_O[1]); for (int i = 0; i < 4; i ++)
ptr_O[2] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4) + 2], ptr_O[2]); ptr_O[i] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4 + i)], ptr_O[i]);
ptr_O[3] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4) + 3], ptr_O[3]);
} }
if(gridDim.y == 1){ #pragma unroll
ptr_O[0] /= ptr_sum[0]; for (int i = 0; i < 4; i ++)
ptr_O[1] /= ptr_sum[0]; ptr_O[i] /= ptr_sum[0];
ptr_O[2] /= ptr_sum[0];
ptr_O[3] /= ptr_sum[0];
}
else if(threadIdx.x == 0){
output_sum_temp[blockIdx.y + parallel_idx * gridDim.y] = ptr_sum[0];
}
(float4 &)output_O_temp[(lane_id * 4) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * gridDim.y)] = (float4 &)ptr_O[0]; (float4 &)output_O_temp[(lane_id * 4) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = (float4 &)ptr_O[0];
if(threadIdx.x == 0){
output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0];
}
} }
__global__ void _attention_kvcache_kernel_128_sum_only_2(float* input_k_cache, __global__ void _attention_kvcache_kernel_128_2(int* position_id,
float* input_v_cache,
float* input_q,
float* input_k,
float* input_v,
int size,
float* output_matmul, float* output_matmul,
AttentionKVCacheMetadata compMeta, AttentionKVCacheMetadata compMeta,
float* output_O_temp, float* output_O_temp,
@ -626,22 +223,24 @@ __global__ void _attention_kvcache_kernel_128_sum_only_2(float* input_k_cache,
float ptr_O[4] = {0}; float ptr_O[4] = {0};
float ptr_O_sum[4] = {0}; float ptr_O_sum[4] = {0};
float ptr_sum = 0; float ptr_sum = 0;
float ptr_sum_temp;
int size = (position_id[0] + SEQ_UNIT) / SEQ_UNIT;
#pragma unroll #pragma unroll
for(int i = 0; i < size; i ++){ for(int i = 0; i < size; i ++){
(float4 &)ptr_O[0] (float4 &)ptr_O[0]
= (float4 &)output_O_temp[(lane_id * 4) + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size]; = (float4 &)output_O_temp[(lane_id * 4) + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size];
ptr_O_sum[0] += ptr_O[0]; ptr_sum_temp = output_sum_temp[i + parallel_idx * size];
ptr_O_sum[1] += ptr_O[1];
ptr_O_sum[2] += ptr_O[2]; #pragma unroll
ptr_O_sum[3] += ptr_O[3]; for(int k = 0; k < 4; k ++)
ptr_sum += output_sum_temp[i + parallel_idx * size]; ptr_O_sum[k] += ptr_O[k] * ptr_sum_temp;
ptr_sum += ptr_sum_temp;
} }
ptr_O_sum[0] = ptr_O_sum[0] / ptr_sum; #pragma unroll
ptr_O_sum[1] = ptr_O_sum[1] / ptr_sum; for(int k = 0; k < 4; k ++)
ptr_O_sum[2] = ptr_O_sum[2] / ptr_sum; ptr_O_sum[k] = ptr_O_sum[k] / ptr_sum;
ptr_O_sum[3] = ptr_O_sum[3] / ptr_sum;
(float4 &)output_matmul[(lane_id * 4) + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O_sum[0]; (float4 &)output_matmul[(lane_id * 4) + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O_sum[0];
@ -654,29 +253,19 @@ void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
const AttentionKVCacheMetadata &compMeta, const AttentionKVCacheMetadata &compMeta,
float *output_O_temp, float *output_sum_temp) { float *output_O_temp, float *output_sum_temp) {
IT_ASSERT(compMeta.dimSize[3] == 64 || compMeta.dimSize[3] == 128); IT_ASSERT(compMeta.dimSize[3] == 64 || compMeta.dimSize[3] == 128);
int position_id_h;
cudaMemcpy(&position_id_h, position_id, sizeof(int), cudaMemcpyDeviceToHost);
int gridsize_y = (compMeta.dimSize[2] - 1 + SEQ_UNIT) / SEQ_UNIT;
int gridsize_y = (position_id_h + SEQ_UNIT) / SEQ_UNIT;
dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), gridsize_y); dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), gridsize_y);
dim3 blockDim(BLOCKSIZE, 1); dim3 blockDim(BLOCKSIZE, 1);
bool needReduce = gridsize_y > 1 ? true : false;
if(compMeta.dimSize[3] == 64) if(compMeta.dimSize[3] == 64)
_attention_kvcache_kernel<<<gridDim, blockDim>>>( _attention_kvcache_kernel_64<<<gridDim.x, blockDim>>>(
input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, output_matmul, compMeta); input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, output_matmul, compMeta);
else{ else{
if(!needReduce){ _attention_kvcache_kernel_128_1<<<gridDim, blockDim>>>(
_attention_kvcache_kernel_128_sum_only_1<<<gridDim, blockDim>>>( input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, compMeta, output_O_temp, output_sum_temp);
input_k_cache, input_v_cache, input_q, input_k, input_v, position_id_h, nullptr, compMeta, output_matmul, nullptr); _attention_kvcache_kernel_128_2<<<compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), WARP_SIZE>>>(
} position_id, output_matmul, compMeta, output_O_temp, output_sum_temp);
else{
_attention_kvcache_kernel_128_sum_only_1<<<gridDim, blockDim>>>(
input_k_cache, input_v_cache, input_q, input_k, input_v, position_id_h, nullptr, compMeta, output_O_temp, output_sum_temp);
_attention_kvcache_kernel_128_sum_only_2<<<compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), WARP_SIZE>>>(
input_k_cache, input_v_cache, input_q, input_k, input_v, gridsize_y, output_matmul, compMeta, output_O_temp, output_sum_temp);
}
} }
} }