diff --git a/src/kernels/cuda/attention_kvcache.cu b/src/kernels/cuda/attention_kvcache.cu index bdf29121..9526668c 100644 --- a/src/kernels/cuda/attention_kvcache.cu +++ b/src/kernels/cuda/attention_kvcache.cu @@ -27,13 +27,14 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache, if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1]) return; - float ptr_V[SEQ_UNIT*4]; // V - float ptr_K[SEQ_UNIT*4]; // K + half ptr_V[4]; // V + half ptr_K[4]; // K float ptr_Q[4]; // Q float ptr_P[SEQ_UNIT] = {0}; float ptr_O[4] = {0}; float ptr_sum[1] = {0}; + float temp[4]; // readin Q (float4 &)ptr_Q[0] = (float4 &)input_q[lane_id_x4 + (parallel_idx * 128)]; @@ -44,43 +45,34 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache, #pragma unroll for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) { int idx_kvcache = common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2]); - int idx_SEQ_UNIT_x4 = idx_SEQ_UNIT * 4; - half temp[4]; if(idx_SEQ_UNIT + idx_seq < seq_length - 1){ - *((int2*)(&temp[0])) = *((int2*)(&((half*)input_k_cache)[idx_kvcache])); - for(int i = 0; i < 4; i += 2) - (float2 &)ptr_K[idx_SEQ_UNIT_x4 + i] = __half22float2((half2 &)temp[i]); + *((int2*)(&ptr_K[0])) = *((int2*)(&((half*)input_k_cache)[idx_kvcache])); } else{ - (float4 &)ptr_K[idx_SEQ_UNIT_x4] = (float4 &) input_k[idx_kv]; - for(int i = 0; i < 4; i += 2){ - *((half2*)(&temp[i])) = __float22half2_rn(*((float2*)(&ptr_K[idx_SEQ_UNIT_x4 + i]))); - } - *((int2*)(&((half*)input_k_cache)[idx_kvcache])) = *((int2*)(&temp[0])); + (float4 &)temp[0] = (float4 &) input_k[idx_kv]; + for(int i = 0; i < 4; i += 2) + *((half2*)(&ptr_K[i])) = __float22half2_rn(*((float2*)(&temp[i]))); + *((int2*)(&((half*)input_k_cache)[idx_kvcache])) = *((int2*)(&ptr_K[0])); } - // * V if(idx_SEQ_UNIT + idx_seq < seq_length - 1){ - *((int2*)(&temp[0])) = *((int2*)(&((half*)input_v_cache)[idx_kvcache])); - for(int i = 0; i < 4; i += 2) - (float2 &)ptr_V[idx_SEQ_UNIT_x4 + i] = __half22float2((half2 &)temp[i]); + *((int2*)(&ptr_V[0])) = *((int2*)(&((half*)input_v_cache)[idx_kvcache])); } else{ - (float4 &)ptr_V[idx_SEQ_UNIT_x4] = (float4 &) input_v[idx_kv]; - for(int i = 0; i < 4; i += 2){ - *((half2*)(&temp[i])) = __float22half2_rn(*((float2*)(&ptr_V[idx_SEQ_UNIT_x4 + i]))); - } - *((int2*)(&((half*)input_v_cache)[idx_kvcache])) = *((int2*)(&temp[0])); + (float4 &)temp[0] = (float4 &) input_v[idx_kv]; + for(int i = 0; i < 4; i += 2) + *((half2*)(&ptr_V[i])) = __float22half2_rn(*((float2*)(&temp[i]))); + *((int2*)(&((half*)input_v_cache)[idx_kvcache])) = *((int2*)(&ptr_V[0])); } #pragma unroll for (int i = 0; i < 4; i ++){ - ptr_K[idx_SEQ_UNIT_x4 + i] = ptr_Q[i] * ptr_K[idx_SEQ_UNIT_x4 + i]; + ptr_K[i] = __float2half(ptr_Q[i]) * ptr_K[i]; #pragma unroll for (int offset = 16; offset > 0; offset /= 2) { - ptr_K[idx_SEQ_UNIT_x4 + i] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT_x4 + i], offset); + ptr_K[i] += __shfl_down_sync(0xffffffff, ptr_K[i], offset); } - ptr_P[idx_SEQ_UNIT] += ptr_K[idx_SEQ_UNIT_x4 + i]; + ptr_P[idx_SEQ_UNIT] += __half2float(ptr_K[i]); } // div sqrt(d) @@ -93,7 +85,7 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache, #pragma unroll for (int i = 0; i < 4; i ++) - ptr_O[i] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT_x4 + i)], ptr_O[i]); + ptr_O[i] = fmaf(ptr_P[idx_SEQ_UNIT], __half2float(ptr_V[i]), ptr_O[i]); } #pragma unroll