diff --git a/src/kernels/cuda/attention_kvcache.cu b/src/kernels/cuda/attention_kvcache.cu index 3fb0de11..bdf29121 100644 --- a/src/kernels/cuda/attention_kvcache.cu +++ b/src/kernels/cuda/attention_kvcache.cu @@ -19,7 +19,7 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache, if(blockIdx.y >= stride) return; - int lane_id = threadIdx.x % WARP_SIZE; + int lane_id_x4 = threadIdx.x % WARP_SIZE * 4; int group_id = threadIdx.x / WARP_SIZE; int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id; int idx_seq = blockIdx.y * SEQ_UNIT; @@ -36,74 +36,72 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache, float ptr_sum[1] = {0}; // readin Q - (float4 &)ptr_Q[0] = (float4 &)input_q[(lane_id * 4) + (parallel_idx * 128)]; - int common_idx = (lane_id * 4) + (parallel_idx * compMeta.stride[1]); - + (float4 &)ptr_Q[0] = (float4 &)input_q[lane_id_x4 + (parallel_idx * 128)]; + int common_idx = lane_id_x4 + (parallel_idx * compMeta.stride[1]); + int idx_kv = lane_id_x4 + parallel_idx * compMeta.stride[2]; + // Q*K #pragma unroll for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) { - if(idx_SEQ_UNIT + idx_seq < seq_length - 1){ - (float4 &)ptr_K[idx_SEQ_UNIT * 4] - = (float4 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]; + int idx_kvcache = common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2]); + int idx_SEQ_UNIT_x4 = idx_SEQ_UNIT * 4; + half temp[4]; + if(idx_SEQ_UNIT + idx_seq < seq_length - 1){ + *((int2*)(&temp[0])) = *((int2*)(&((half*)input_k_cache)[idx_kvcache])); + for(int i = 0; i < 4; i += 2) + (float2 &)ptr_K[idx_SEQ_UNIT_x4 + i] = __half22float2((half2 &)temp[i]); } else{ - (float4 &)ptr_K[idx_SEQ_UNIT * 4] - = (float4 &) input_k[((lane_id * 4) + parallel_idx * compMeta.stride[2])]; - (float4 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] = - (float4 &)ptr_K[idx_SEQ_UNIT * 4]; + (float4 &)ptr_K[idx_SEQ_UNIT_x4] = (float4 &) input_k[idx_kv]; + for(int i = 0; i < 4; i += 2){ + *((half2*)(&temp[i])) = __float22half2_rn(*((float2*)(&ptr_K[idx_SEQ_UNIT_x4 + i]))); + } + *((int2*)(&((half*)input_k_cache)[idx_kvcache])) = *((int2*)(&temp[0])); } + // * V + if(idx_SEQ_UNIT + idx_seq < seq_length - 1){ + *((int2*)(&temp[0])) = *((int2*)(&((half*)input_v_cache)[idx_kvcache])); + for(int i = 0; i < 4; i += 2) + (float2 &)ptr_V[idx_SEQ_UNIT_x4 + i] = __half22float2((half2 &)temp[i]); + } + else{ + (float4 &)ptr_V[idx_SEQ_UNIT_x4] = (float4 &) input_v[idx_kv]; + for(int i = 0; i < 4; i += 2){ + *((half2*)(&temp[i])) = __float22half2_rn(*((float2*)(&ptr_V[idx_SEQ_UNIT_x4 + i]))); + } + *((int2*)(&((half*)input_v_cache)[idx_kvcache])) = *((int2*)(&temp[0])); + } #pragma unroll for (int i = 0; i < 4; i ++){ - ptr_K[idx_SEQ_UNIT * 4 + i] = ptr_Q[i] * ptr_K[idx_SEQ_UNIT * 4 + i]; + ptr_K[idx_SEQ_UNIT_x4 + i] = ptr_Q[i] * ptr_K[idx_SEQ_UNIT_x4 + i]; #pragma unroll for (int offset = 16; offset > 0; offset /= 2) { - ptr_K[idx_SEQ_UNIT * 4 + i] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 4 + i], offset); + ptr_K[idx_SEQ_UNIT_x4 + i] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT_x4 + i], offset); } - ptr_P[idx_SEQ_UNIT] += ptr_K[idx_SEQ_UNIT * 4 + i]; + ptr_P[idx_SEQ_UNIT] += ptr_K[idx_SEQ_UNIT_x4 + i]; } - } - // div sqrt(d) - #pragma unroll - for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) { + // div sqrt(d) ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0); ptr_P[idx_SEQ_UNIT] /= sqrt(128.0); - } - // softmax - #pragma unroll - for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) { + // softmax ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT]); ptr_sum[0] += ptr_P[idx_SEQ_UNIT]; - } - - // * V - #pragma unroll - for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) { - if(idx_SEQ_UNIT + idx_seq < seq_length - 1){ - (float4 &)ptr_V[idx_SEQ_UNIT * 4] - = (float4 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]; - } - else{ - (float4 &)ptr_V[idx_SEQ_UNIT * 4] - = (float4 &) input_v[((lane_id * 4) + parallel_idx * compMeta.stride[2])]; - (float4 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] - = (float4 &)ptr_V[idx_SEQ_UNIT * 4]; - } #pragma unroll for (int i = 0; i < 4; i ++) - ptr_O[i] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4 + i)], ptr_O[i]); + ptr_O[i] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT_x4 + i)], ptr_O[i]); } #pragma unroll for (int i = 0; i < 4; i ++) ptr_O[i] /= ptr_sum[0]; - (float4 &)output_O_temp[(lane_id * 4) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = (float4 &)ptr_O[0]; - if(lane_id == 0){ + (float4 &)output_O_temp[lane_id_x4 + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = (float4 &)ptr_O[0]; + if(lane_id_x4 == 0){ output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0]; }