forked from jiuyuan/InfiniTensor
accelerate cuda attention
This commit is contained in:
parent
4bdd33522b
commit
eb3a2d123d
|
@ -142,28 +142,30 @@ __global__ void _attention_kvcache_kernel_128_2(int* position_id,
|
|||
AttentionKVCacheMetadata compMeta,
|
||||
half* output_O_temp,
|
||||
float* output_sum_temp) {
|
||||
int lane_id_x2 = threadIdx.x % WARP_SIZE * 2;
|
||||
int lane_id_x4 = threadIdx.x % WARP_SIZE * 4;
|
||||
int group_id = threadIdx.x / WARP_SIZE;
|
||||
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
|
||||
|
||||
float ptr_O[4] = {0};
|
||||
float ptr_O_sum[4] = {0};
|
||||
float ptr_sum = 0;
|
||||
float ptr_sum_temp;
|
||||
int size = (position_id[0] + SEQ_UNIT) / SEQ_UNIT;
|
||||
bool is_fp16 = sizeof(T) == 2 ? true : false;
|
||||
|
||||
if(size > 1){
|
||||
float ptr_O[4] = {0};
|
||||
float ptr_O_sum[4] = {0};
|
||||
float ptr_sum = 0;
|
||||
float ptr_sum_temp;
|
||||
half temp_half[4];
|
||||
#pragma unroll
|
||||
for(int i = 0; i < size; i ++){
|
||||
(float2 &)temp_half[0]
|
||||
= (float2 &)output_O_temp[lane_id_x4 + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size];
|
||||
for(int j = 0; j < 4; j += 2)
|
||||
(float2 &)ptr_O[j]
|
||||
= __half22float2((half2 &)output_O_temp[(lane_id_x2 + j*WARP_SIZE) + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size]);
|
||||
(float2 &)ptr_O[j] = __half22float2((half2 &)temp_half[j]);
|
||||
ptr_sum_temp = output_sum_temp[i + parallel_idx * size];
|
||||
|
||||
#pragma unroll
|
||||
for(int k = 0; k < 4; k ++)
|
||||
ptr_O_sum[k] += ptr_O[k] * ptr_sum_temp;
|
||||
ptr_O_sum[k] = fma(ptr_O[k], ptr_sum_temp, ptr_O_sum[k]);
|
||||
ptr_sum += ptr_sum_temp;
|
||||
}
|
||||
|
||||
|
@ -172,24 +174,27 @@ __global__ void _attention_kvcache_kernel_128_2(int* position_id,
|
|||
ptr_O_sum[k] = ptr_O_sum[k] / ptr_sum;
|
||||
|
||||
if(!is_fp16){
|
||||
for(int j = 0; j < 4; j += 2)
|
||||
(float2 &)output_matmul[lane_id_x2 + j*WARP_SIZE + (parallel_idx * compMeta.dimSize[3])] = (float2 &)ptr_O_sum[j];
|
||||
(float4 &)output_matmul[lane_id_x4 + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O_sum[0];
|
||||
}
|
||||
else{
|
||||
for(int j = 0; j < 4; j += 2)
|
||||
(half2 &)output_matmul[lane_id_x2 + j*WARP_SIZE + (parallel_idx * compMeta.dimSize[3])] = __float22half2_rn((float2 &)ptr_O_sum[j]);
|
||||
(half2 &)temp_half[j] = __float22half2_rn((float2 &)ptr_O_sum[j]);
|
||||
(float2 &)output_matmul[lane_id_x4 + (parallel_idx * compMeta.dimSize[3])] = (float2 &)temp_half[0];
|
||||
}
|
||||
}
|
||||
else{
|
||||
half temp_half[4];
|
||||
float temp_float[4];
|
||||
if(!is_fp16){
|
||||
(float2 &)temp_half[0]
|
||||
= (float2 &)output_O_temp[lane_id_x4 + parallel_idx * compMeta.dimSize[3]];
|
||||
for(int i = 0; i < 4; i += 2)
|
||||
(float2 &)output_matmul[(lane_id_x2 + i*WARP_SIZE) + (parallel_idx * compMeta.dimSize[3])]
|
||||
= __half22float2((half2 &)output_O_temp[(lane_id_x2 + i*WARP_SIZE) + parallel_idx * compMeta.dimSize[3]]);
|
||||
(float2 &)temp_float[i] = __half22float2((half2 &)temp_half[i]);
|
||||
(float4 &)output_matmul[lane_id_x4 + parallel_idx * compMeta.dimSize[3]] = (float4 &)temp_float[0];
|
||||
}
|
||||
else{
|
||||
for(int i = 0; i < 4; i += 2)
|
||||
(half2 &)output_matmul[(lane_id_x2 + i*WARP_SIZE) + (parallel_idx * compMeta.dimSize[3])]
|
||||
= (half2 &)output_O_temp[(lane_id_x2 + i*WARP_SIZE) + parallel_idx * compMeta.dimSize[3]];
|
||||
(float2 &)output_matmul[lane_id_x4 + parallel_idx * compMeta.dimSize[3]]
|
||||
= (float2 &)output_O_temp[lane_id_x4 + parallel_idx * compMeta.dimSize[3]];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue