This commit is contained in:
xiaonans 2024-03-21 10:17:06 +08:00
parent fc3d38f80e
commit 0740d26f43
1 changed files with 39 additions and 27 deletions

View File

@ -20,7 +20,7 @@ __global__ void _attention_kvcache_kernel_128_1(T* input_k_cache,
if(blockIdx.y >= stride) if(blockIdx.y >= stride)
return; return;
int lane_id_x4 = threadIdx.x % WARP_SIZE * 4; int lane_id_x2 = threadIdx.x % WARP_SIZE * 2;
int group_id = threadIdx.x / WARP_SIZE; int group_id = threadIdx.x / WARP_SIZE;
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id; int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
int idx_seq = blockIdx.y * SEQ_UNIT; int idx_seq = blockIdx.y * SEQ_UNIT;
@ -38,64 +38,75 @@ __global__ void _attention_kvcache_kernel_128_1(T* input_k_cache,
float temp[4]; float temp[4];
bool is_fp16 = sizeof(T) == 2 ? true : false; bool is_fp16 = sizeof(T) == 2 ? true : false;
int idx_qkv = lane_id_x2 + parallel_idx * compMeta.stride[2];
// readin Q // readin Q
if(!is_fp16){ if(!is_fp16){
(float4 &)temp[0] = (float4 &)input_q[lane_id_x4 + (parallel_idx * 128)];
for(int i = 0; i < 4; i += 2){ for(int i = 0; i < 4; i += 2){
(float2 &)temp[i] = (float2 &)input_q[idx_qkv + i*WARP_SIZE];
*((half2*)(&ptr_Q[i])) = __float22half2_rn(*((float2*)(&temp[i]))); *((half2*)(&ptr_Q[i])) = __float22half2_rn(*((float2*)(&temp[i])));
} }
} }
else{ else{
for(int i = 0; i < 4; i += 2){ for(int i = 0; i < 4; i += 2){
(half2 &)ptr_Q[i] = (half2 &)input_q[lane_id_x4 + i + (parallel_idx * 128)]; (half2 &)ptr_Q[i] = (half2 &)input_q[idx_qkv + i*WARP_SIZE];
} }
} }
int common_idx = lane_id_x4 + (parallel_idx * compMeta.stride[1]); int common_idx = lane_id_x2 + (parallel_idx * compMeta.stride[1]);
int idx_kv = lane_id_x4 + parallel_idx * compMeta.stride[2];
// Q*K // Q*K
#pragma unroll #pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) { for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
int idx_kvcache = common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2]); int idx_kvcache = common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2]);
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){ if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
*((int2*)(&ptr_K[0])) = *((int2*)(&((half*)input_k_cache)[idx_kvcache])); for(int i = 0; i < 4; i += 2){
*((half2*)(&ptr_K[i])) = *((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE]));
}
} }
else{ else{
if(!is_fp16){ if(!is_fp16){
(float4 &)temp[0] = (float4 &) input_k[idx_kv]; for(int i = 0; i < 4; i += 2){
for(int i = 0; i < 4; i += 2) (float2 &)temp[i] = (float2 &) input_k[idx_qkv + i*WARP_SIZE];
*((half2*)(&ptr_K[i])) = __float22half2_rn(*((float2*)(&temp[i]))); *((half2*)(&ptr_K[i])) = __float22half2_rn(*((float2*)(&temp[i])));
}
} }
else{ else{
for(int i = 0; i < 4; i += 2){ for(int i = 0; i < 4; i += 2){
(half2 &)ptr_K[i] = (half2 &)input_k[lane_id_x4 + i + (parallel_idx * 128)]; (half2 &)ptr_K[i] = (half2 &)input_k[idx_qkv + i*WARP_SIZE];
} }
} }
*((int2*)(&((half*)input_k_cache)[idx_kvcache])) = *((int2*)(&ptr_K[0])); for(int i = 0; i < 4; i += 2){
*((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&ptr_K[i]));
}
} }
// * V // * V
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){ if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
*((int2*)(&ptr_V[0])) = *((int2*)(&((half*)input_v_cache)[idx_kvcache])); for(int i = 0; i < 4; i += 2){
*((half2*)(&ptr_V[i])) = *((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE]));
}
} }
else{ else{
if(!is_fp16){ if(!is_fp16){
(float4 &)temp[0] = (float4 &) input_v[idx_kv]; for(int i = 0; i < 4; i += 2){
for(int i = 0; i < 4; i += 2) (float2 &)temp[i] = (float2 &) input_v[idx_qkv + i*WARP_SIZE];
*((half2*)(&ptr_V[i])) = __float22half2_rn(*((float2*)(&temp[i]))); *((half2*)(&ptr_V[i])) = __float22half2_rn(*((float2*)(&temp[i])));
}
} }
else{ else{
for(int i = 0; i < 4; i += 2){ for(int i = 0; i < 4; i += 2){
(half2 &)ptr_V[i] = (half2 &)input_v[lane_id_x4 + i + (parallel_idx * 128)]; (half2 &)ptr_V[i] = (half2 &)input_v[idx_qkv + i*WARP_SIZE];
} }
} }
*((int2*)(&((half*)input_v_cache)[idx_kvcache])) = *((int2*)(&ptr_V[0])); for(int i = 0; i < 4; i += 2){
*((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&ptr_V[i]));
}
} }
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i ++){ for (int i = 0; i < 4; i ++){
ptr_K[i] = ptr_Q[i] * ptr_K[i]; ptr_K[i] = ptr_Q[i] * ptr_K[i];
#pragma unroll #pragma unroll
for (int offset = 16; offset > 0; offset /= 2) { for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) {
ptr_K[i] += __shfl_down_sync(0xffffffff, ptr_K[i], offset); ptr_K[i] += __shfl_down_sync(0xffffffff, ptr_K[i], offset);
} }
ptr_P[idx_SEQ_UNIT] += __half2float(ptr_K[i]); ptr_P[idx_SEQ_UNIT] += __half2float(ptr_K[i]);
@ -119,8 +130,8 @@ __global__ void _attention_kvcache_kernel_128_1(T* input_k_cache,
ptr_O[i] /= ptr_sum[0]; ptr_O[i] /= ptr_sum[0];
for(int i = 0; i < 4; i += 2) for(int i = 0; i < 4; i += 2)
(half2 &)output_O_temp[(lane_id_x4 + i) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = __float22half2_rn((float2 &)ptr_O[i]); (half2 &)output_O_temp[(lane_id_x2 + i*WARP_SIZE) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = __float22half2_rn((float2 &)ptr_O[i]);
if(lane_id_x4 == 0){ if(lane_id_x2 == 0){
output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0]; output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0];
} }
@ -148,7 +159,7 @@ __global__ void _attention_kvcache_kernel_128_2(int* position_id,
for(int i = 0; i < size; i ++){ for(int i = 0; i < size; i ++){
for(int j = 0; j < 4; j += 2) for(int j = 0; j < 4; j += 2)
(float2 &)ptr_O[j] (float2 &)ptr_O[j]
= __half22float2((half2 &)output_O_temp[(lane_id_x2 + j*32) + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size]); = __half22float2((half2 &)output_O_temp[(lane_id_x2 + j*WARP_SIZE) + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size]);
ptr_sum_temp = output_sum_temp[i + parallel_idx * size]; ptr_sum_temp = output_sum_temp[i + parallel_idx * size];
#pragma unroll #pragma unroll
@ -163,23 +174,23 @@ __global__ void _attention_kvcache_kernel_128_2(int* position_id,
if(!is_fp16){ if(!is_fp16){
for(int j = 0; j < 4; j += 2) for(int j = 0; j < 4; j += 2)
(float2 &)output_matmul[lane_id_x2 + j*32 + (parallel_idx * compMeta.dimSize[3])] = (float2 &)ptr_O_sum[j]; (float2 &)output_matmul[lane_id_x2 + j*WARP_SIZE + (parallel_idx * compMeta.dimSize[3])] = (float2 &)ptr_O_sum[j];
} }
else{ else{
for(int j = 0; j < 4; j += 2) for(int j = 0; j < 4; j += 2)
(half2 &)output_matmul[lane_id_x2 + j*32 + (parallel_idx * compMeta.dimSize[3])] = __float22half2_rn((float2 &)ptr_O_sum[j]); (half2 &)output_matmul[lane_id_x2 + j*WARP_SIZE + (parallel_idx * compMeta.dimSize[3])] = __float22half2_rn((float2 &)ptr_O_sum[j]);
} }
} }
else{ else{
if(!is_fp16){ if(!is_fp16){
for(int i = 0; i < 4; i += 2) for(int i = 0; i < 4; i += 2)
(float2 &)output_matmul[(lane_id_x2 + i*32) + (parallel_idx * compMeta.dimSize[3])] (float2 &)output_matmul[(lane_id_x2 + i*WARP_SIZE) + (parallel_idx * compMeta.dimSize[3])]
= __half22float2((half2 &)output_O_temp[(lane_id_x2 + i*32) + parallel_idx * compMeta.dimSize[3]]); = __half22float2((half2 &)output_O_temp[(lane_id_x2 + i*WARP_SIZE) + parallel_idx * compMeta.dimSize[3]]);
} }
else{ else{
for(int i = 0; i < 4; i += 2) for(int i = 0; i < 4; i += 2)
(half2 &)output_matmul[(lane_id_x2 + i*32) + (parallel_idx * compMeta.dimSize[3])] (half2 &)output_matmul[(lane_id_x2 + i*WARP_SIZE) + (parallel_idx * compMeta.dimSize[3])]
= (half2 &)output_O_temp[(lane_id_x2 + i*32) + parallel_idx * compMeta.dimSize[3]]; = (half2 &)output_O_temp[(lane_id_x2 + i*WARP_SIZE) + parallel_idx * compMeta.dimSize[3]];
} }
} }
} }
@ -219,6 +230,7 @@ void attention_kvcache_kernel(int dType, void *input_k_cache, void *input_v_cach
0, CUDAStream::getCurrentStream()>>> 0, CUDAStream::getCurrentStream()>>>
(position_id, (half*)output_matmul, compMeta, (half*)output_O_temp, output_sum_temp); (position_id, (half*)output_matmul, compMeta, (half*)output_O_temp, output_sum_temp);
} }
} }
} // namespace infini } // namespace infini