inter-block communication is fp16

This commit is contained in:
xiaonans 2024-03-19 11:21:14 +08:00
parent db053e32a4
commit d43364ac60
1 changed files with 23 additions and 17 deletions

View File

@ -2,7 +2,7 @@
#include "cuda/cuda_attention_kvcache.h"
#define WARP_SIZE 32
#define BLOCKSIZE WARP_SIZE*2
#define SEQ_UNIT 16
#define SEQ_UNIT 8
// ASSUME SEQ_LEN OF Q IS 1
__global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
@ -12,7 +12,7 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
float* input_v,
int* position_id,
AttentionKVCacheMetadata compMeta,
float* output_O_temp,
half* output_O_temp,
float* output_sum_temp) {
int seq_length = position_id[0] + 1;
int stride = (seq_length + SEQ_UNIT - 1) / SEQ_UNIT;
@ -29,7 +29,7 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
half ptr_V[4]; // V
half ptr_K[4]; // K
float ptr_Q[4]; // Q
half ptr_Q[4]; // Q
float ptr_P[SEQ_UNIT] = {0};
float ptr_O[4] = {0};
@ -37,7 +37,10 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
float temp[4];
// readin Q
(float4 &)ptr_Q[0] = (float4 &)input_q[lane_id_x4 + (parallel_idx * 128)];
(float4 &)temp[0] = (float4 &)input_q[lane_id_x4 + (parallel_idx * 128)];
for(int i = 0; i < 4; i += 2){
*((half2*)(&ptr_Q[i])) = __float22half2_rn(*((float2*)(&temp[i])));
}
int common_idx = lane_id_x4 + (parallel_idx * compMeta.stride[1]);
int idx_kv = lane_id_x4 + parallel_idx * compMeta.stride[2];
@ -67,7 +70,7 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
#pragma unroll
for (int i = 0; i < 4; i ++){
ptr_K[i] = __float2half(ptr_Q[i]) * ptr_K[i];
ptr_K[i] = ptr_Q[i] * ptr_K[i];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
ptr_K[i] += __shfl_down_sync(0xffffffff, ptr_K[i], offset);
@ -92,7 +95,8 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
for (int i = 0; i < 4; i ++)
ptr_O[i] /= ptr_sum[0];
(float4 &)output_O_temp[lane_id_x4 + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = (float4 &)ptr_O[0];
for(int i = 0; i < 4; i += 2)
(half2 &)output_O_temp[(lane_id_x4 + i) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = __float22half2_rn((float2 &)ptr_O[i]);
if(lane_id_x4 == 0){
output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0];
}
@ -102,9 +106,9 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
__global__ void _attention_kvcache_kernel_128_2(int* position_id,
float* output_matmul,
AttentionKVCacheMetadata compMeta,
float* output_O_temp,
half* output_O_temp,
float* output_sum_temp) {
int lane_id_x4 = threadIdx.x % WARP_SIZE * 4;
int lane_id_x2 = threadIdx.x % WARP_SIZE * 2;
int group_id = threadIdx.x / WARP_SIZE;
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
@ -117,8 +121,9 @@ __global__ void _attention_kvcache_kernel_128_2(int* position_id,
if(size > 1){
#pragma unroll
for(int i = 0; i < size; i ++){
(float4 &)ptr_O[0]
= (float4 &)output_O_temp[lane_id_x4 + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size];
for(int j = 0; j < 4; j += 2)
(float2 &)ptr_O[j]
= __half22float2((half2 &)output_O_temp[(lane_id_x2 + j*32) + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size]);
ptr_sum_temp = output_sum_temp[i + parallel_idx * size];
#pragma unroll
@ -130,14 +135,15 @@ __global__ void _attention_kvcache_kernel_128_2(int* position_id,
#pragma unroll
for(int k = 0; k < 4; k ++)
ptr_O_sum[k] = ptr_O_sum[k] / ptr_sum;
(float4 &)output_matmul[lane_id_x4 + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O_sum[0];
for(int j = 0; j < 4; j += 2)
(float2 &)output_matmul[lane_id_x2 + j*32 + (parallel_idx * compMeta.dimSize[3])] = (float2 &)ptr_O_sum[j];
}
else{
(float4 &)output_matmul[lane_id_x4 + (parallel_idx * compMeta.dimSize[3])]
= (float4 &)output_O_temp[lane_id_x4 + parallel_idx * compMeta.dimSize[3]];
for(int i = 0; i < 4; i += 2)
(float2 &)output_matmul[(lane_id_x2 + i*32) + (parallel_idx * compMeta.dimSize[3])]
= __half22float2((half2 &)output_O_temp[(lane_id_x2 + i*32) + parallel_idx * compMeta.dimSize[3]]);
}
}
@ -156,12 +162,12 @@ void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
_attention_kvcache_kernel_128_1
<<<gridDim, blockDim, 0, CUDAStream::getCurrentStream()>>>
(input_k_cache, input_v_cache, input_q, input_k, input_v, position_id,
compMeta, output_O_temp, output_sum_temp);
compMeta, (half*)output_O_temp, output_sum_temp);
_attention_kvcache_kernel_128_2
<<<compMeta.dimSize[0]*compMeta.dimSize[1], WARP_SIZE,
0, CUDAStream::getCurrentStream()>>>
(position_id, output_matmul, compMeta, output_O_temp, output_sum_temp);
(position_id, output_matmul, compMeta, (half*)output_O_temp, output_sum_temp);
}
} // namespace infini