cache is fp16

This commit is contained in:
xiaonans 2024-03-18 15:51:19 +08:00
parent 80412ae162
commit 1e797d4ffe
1 changed files with 38 additions and 40 deletions

View File

@ -19,7 +19,7 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
if(blockIdx.y >= stride) if(blockIdx.y >= stride)
return; return;
int lane_id = threadIdx.x % WARP_SIZE; int lane_id_x4 = threadIdx.x % WARP_SIZE * 4;
int group_id = threadIdx.x / WARP_SIZE; int group_id = threadIdx.x / WARP_SIZE;
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id; int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
int idx_seq = blockIdx.y * SEQ_UNIT; int idx_seq = blockIdx.y * SEQ_UNIT;
@ -36,74 +36,72 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
float ptr_sum[1] = {0}; float ptr_sum[1] = {0};
// readin Q // readin Q
(float4 &)ptr_Q[0] = (float4 &)input_q[(lane_id * 4) + (parallel_idx * 128)]; (float4 &)ptr_Q[0] = (float4 &)input_q[lane_id_x4 + (parallel_idx * 128)];
int common_idx = (lane_id * 4) + (parallel_idx * compMeta.stride[1]); int common_idx = lane_id_x4 + (parallel_idx * compMeta.stride[1]);
int idx_kv = lane_id_x4 + parallel_idx * compMeta.stride[2];
// Q*K // Q*K
#pragma unroll #pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) { for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
int idx_kvcache = common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2]);
int idx_SEQ_UNIT_x4 = idx_SEQ_UNIT * 4;
half temp[4];
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){ if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
(float4 &)ptr_K[idx_SEQ_UNIT * 4] *((int2*)(&temp[0])) = *((int2*)(&((half*)input_k_cache)[idx_kvcache]));
= (float4 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]; for(int i = 0; i < 4; i += 2)
(float2 &)ptr_K[idx_SEQ_UNIT_x4 + i] = __half22float2((half2 &)temp[i]);
} }
else{ else{
(float4 &)ptr_K[idx_SEQ_UNIT * 4] (float4 &)ptr_K[idx_SEQ_UNIT_x4] = (float4 &) input_k[idx_kv];
= (float4 &) input_k[((lane_id * 4) + parallel_idx * compMeta.stride[2])]; for(int i = 0; i < 4; i += 2){
(float4 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] = *((half2*)(&temp[i])) = __float22half2_rn(*((float2*)(&ptr_K[idx_SEQ_UNIT_x4 + i])));
(float4 &)ptr_K[idx_SEQ_UNIT * 4]; }
*((int2*)(&((half*)input_k_cache)[idx_kvcache])) = *((int2*)(&temp[0]));
} }
// * V
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
*((int2*)(&temp[0])) = *((int2*)(&((half*)input_v_cache)[idx_kvcache]));
for(int i = 0; i < 4; i += 2)
(float2 &)ptr_V[idx_SEQ_UNIT_x4 + i] = __half22float2((half2 &)temp[i]);
}
else{
(float4 &)ptr_V[idx_SEQ_UNIT_x4] = (float4 &) input_v[idx_kv];
for(int i = 0; i < 4; i += 2){
*((half2*)(&temp[i])) = __float22half2_rn(*((float2*)(&ptr_V[idx_SEQ_UNIT_x4 + i])));
}
*((int2*)(&((half*)input_v_cache)[idx_kvcache])) = *((int2*)(&temp[0]));
}
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i ++){ for (int i = 0; i < 4; i ++){
ptr_K[idx_SEQ_UNIT * 4 + i] = ptr_Q[i] * ptr_K[idx_SEQ_UNIT * 4 + i]; ptr_K[idx_SEQ_UNIT_x4 + i] = ptr_Q[i] * ptr_K[idx_SEQ_UNIT_x4 + i];
#pragma unroll #pragma unroll
for (int offset = 16; offset > 0; offset /= 2) { for (int offset = 16; offset > 0; offset /= 2) {
ptr_K[idx_SEQ_UNIT * 4 + i] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 4 + i], offset); ptr_K[idx_SEQ_UNIT_x4 + i] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT_x4 + i], offset);
} }
ptr_P[idx_SEQ_UNIT] += ptr_K[idx_SEQ_UNIT * 4 + i]; ptr_P[idx_SEQ_UNIT] += ptr_K[idx_SEQ_UNIT_x4 + i];
} }
}
// div sqrt(d) // div sqrt(d)
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0); ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
ptr_P[idx_SEQ_UNIT] /= sqrt(128.0); ptr_P[idx_SEQ_UNIT] /= sqrt(128.0);
}
// softmax // softmax
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT]); ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT]);
ptr_sum[0] += ptr_P[idx_SEQ_UNIT]; ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
}
// * V
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
(float4 &)ptr_V[idx_SEQ_UNIT * 4]
= (float4 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
}
else{
(float4 &)ptr_V[idx_SEQ_UNIT * 4]
= (float4 &) input_v[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
(float4 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]
= (float4 &)ptr_V[idx_SEQ_UNIT * 4];
}
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i ++) for (int i = 0; i < 4; i ++)
ptr_O[i] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4 + i)], ptr_O[i]); ptr_O[i] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT_x4 + i)], ptr_O[i]);
} }
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i ++) for (int i = 0; i < 4; i ++)
ptr_O[i] /= ptr_sum[0]; ptr_O[i] /= ptr_sum[0];
(float4 &)output_O_temp[(lane_id * 4) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = (float4 &)ptr_O[0]; (float4 &)output_O_temp[lane_id_x4 + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = (float4 &)ptr_O[0];
if(lane_id == 0){ if(lane_id_x4 == 0){
output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0]; output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0];
} }