forked from jiuyuan/InfiniTensor
inter-block communication is fp16
This commit is contained in:
parent db053e32a4
commit d43364ac60
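This commit keeps the kernels' external inputs and outputs in fp32 but stores the per-chunk attention partials passed from kernel 1 to kernel 2 (output_O_temp) in fp16, halving the memory traffic on that intermediate buffer. A minimal standalone sketch of the pack/unpack round trip the diff relies on (kernel name and layout here are illustrative, not from this commit):

    #include <cuda_fp16.h>

    // Round-trip demo of the conversion intrinsics used below:
    // __float22half2_rn packs two fp32 values into one half2 (round to
    // nearest even); __half22float2 widens them back on the reader side.
    __global__ void pack_unpack_demo(const float* in, float* out, int n) {
        int i = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
        if (i + 1 < n) {
            half2 h = __float22half2_rn(*(const float2*)(in + i)); // producer side
            float2 f = __half22float2(h);                          // consumer side
            out[i]     = f.x;
            out[i + 1] = f.y;
        }
    }

The softmax denominators (output_sum_temp) stay fp32 throughout, so the only precision loss is one fp16 rounding per stored partial.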
@@ -2,7 +2,7 @@
 #include "cuda/cuda_attention_kvcache.h"
 #define WARP_SIZE 32
 #define BLOCKSIZE WARP_SIZE*2
-#define SEQ_UNIT 16
+#define SEQ_UNIT 8
 
 // ASSUME SEQ_LEN OF Q IS 1
 __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
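Note: SEQ_UNIT drops from 16 to 8, so kernel 1 now emits twice as many partial results per head (stride = (seq_length + SEQ_UNIT - 1) / SEQ_UNIT; e.g. seq_length = 1000 gives 125 chunks instead of 63). Since each stored partial shrinks from fp32 to fp16, output_O_temp occupies about the same number of bytes as before, while output_sum_temp now carries twice as many fp32 entries.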
@@ -12,7 +12,7 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
                                                 float* input_v,
                                                 int* position_id,
                                                 AttentionKVCacheMetadata compMeta,
-                                                float* output_O_temp,
+                                                half* output_O_temp,
                                                 float* output_sum_temp) {
     int seq_length = position_id[0] + 1;
     int stride = (seq_length + SEQ_UNIT - 1) / SEQ_UNIT;
@@ -29,7 +29,7 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
 
     half ptr_V[4]; // V
     half ptr_K[4]; // K
-    float ptr_Q[4]; // Q
+    half ptr_Q[4]; // Q
     float ptr_P[SEQ_UNIT] = {0};
 
     float ptr_O[4] = {0};
@@ -37,7 +37,10 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
     float temp[4];
 
     // readin Q
-    (float4 &)ptr_Q[0] = (float4 &)input_q[lane_id_x4 + (parallel_idx * 128)];
+    (float4 &)temp[0] = (float4 &)input_q[lane_id_x4 + (parallel_idx * 128)];
+    for(int i = 0; i < 4; i += 2){
+        *((half2*)(&ptr_Q[i])) = __float22half2_rn(*((float2*)(&temp[i])));
+    }
     int common_idx = lane_id_x4 + (parallel_idx * compMeta.stride[1]);
     int idx_kv = lane_id_x4 + parallel_idx * compMeta.stride[2];
 
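Q is still fetched as fp32 with one 16-byte float4 load per lane and only then narrowed in registers. A self-contained sketch of that load-and-narrow step (the helper name and the one-warp-per-128-element-head layout are assumptions for illustration):

    #include <cuda_fp16.h>

    // Hypothetical helper mirroring the "readin Q" block above: one float4
    // load per lane, then two half2 packs so the Q*K products run in fp16.
    __device__ void load_q_as_half(const float* input_q, half ptr_q[4],
                                   int lane_id_x4) {
        float temp[4];
        (float4 &)temp[0] = (const float4 &)input_q[lane_id_x4];
        for (int i = 0; i < 4; i += 2)
            *(half2*)(&ptr_q[i]) = __float22half2_rn(*(float2*)(&temp[i]));
    }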
@@ -67,7 +70,7 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
 
     #pragma unroll
     for (int i = 0; i < 4; i ++){
-        ptr_K[i] = __float2half(ptr_Q[i]) * ptr_K[i];
+        ptr_K[i] = ptr_Q[i] * ptr_K[i];
         #pragma unroll
         for (int offset = 16; offset > 0; offset /= 2) {
            ptr_K[i] += __shfl_down_sync(0xffffffff, ptr_K[i], offset);
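With ptr_Q now in fp16, the Q*K product above uses native __half arithmetic (the per-multiply __float2half conversion disappears), followed by the usual shuffle-based warp reduction. For reference, the same reduction pattern written generically over fp32 (a sketch, not code from this commit):

    // Generic warp-sum sketch matching the loop above: each step folds the
    // upper `offset` lanes' partials into the lower lanes; after the last
    // step, lane 0 holds the sum over all 32 lanes.
    __device__ float warp_reduce_sum(float v) {
        #pragma unroll
        for (int offset = 16; offset > 0; offset /= 2)
            v += __shfl_down_sync(0xffffffff, v, offset);
        return v;  // meaningful in lane 0 only
    }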
@@ -92,7 +95,8 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
     for (int i = 0; i < 4; i ++)
         ptr_O[i] /= ptr_sum[0];
 
-    (float4 &)output_O_temp[lane_id_x4 + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = (float4 &)ptr_O[0];
+    for(int i = 0; i < 4; i += 2)
+        (half2 &)output_O_temp[(lane_id_x4 + i) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = __float22half2_rn((float2 &)ptr_O[i]);
     if(lane_id_x4 == 0){
         output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0];
     }
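On the producer side, each lane's partial-output write shrinks from one 16-byte float4 store to two 4-byte half2 stores (at lane_id_x4 and lane_id_x4 + 2, 8 bytes total), with one __float22half2_rn rounding per pair of fp32 values on the way out; the per-chunk softmax denominator written to output_sum_temp remains fp32.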
@@ -102,9 +106,9 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
 __global__ void _attention_kvcache_kernel_128_2(int* position_id,
                                                 float* output_matmul,
                                                 AttentionKVCacheMetadata compMeta,
-                                                float* output_O_temp,
+                                                half* output_O_temp,
                                                 float* output_sum_temp) {
-    int lane_id_x4 = threadIdx.x % WARP_SIZE * 4;
+    int lane_id_x2 = threadIdx.x % WARP_SIZE * 2;
     int group_id = threadIdx.x / WARP_SIZE;
     int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
 
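Because kernel 2 now moves data in half2 (2-element) units instead of float4 (4-element) units, the per-lane offset changes from 4*lane to 2*lane, and the j*32 term steps between the two 64-element halves of the 128-wide head. A hypothetical helper making the mapping explicit:

    // Illustrative only: first element index touched by lane `lane`
    // (0..31) for j in {0, 2}; each half2 access covers base and base+1.
    // j = 0 -> bases 0, 2, ..., 62;  j = 2 -> bases 64, 66, ..., 126.
    __host__ __device__ int half2_base(int lane, int j) {
        return lane * 2 + j * 32;
    }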
@@ -117,8 +121,9 @@ __global__ void _attention_kvcache_kernel_128_2(int* position_id,
     if(size > 1){
         #pragma unroll
         for(int i = 0; i < size; i ++){
-            (float4 &)ptr_O[0]
-                = (float4 &)output_O_temp[lane_id_x4 + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size];
+            for(int j = 0; j < 4; j += 2)
+                (float2 &)ptr_O[j]
+                    = __half22float2((half2 &)output_O_temp[(lane_id_x2 + j*32) + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size]);
             ptr_sum_temp = output_sum_temp[i + parallel_idx * size];
 
             #pragma unroll
@@ -130,14 +135,15 @@ __global__ void _attention_kvcache_kernel_128_2(int* position_id,
         #pragma unroll
         for(int k = 0; k < 4; k ++)
             ptr_O_sum[k] = ptr_O_sum[k] / ptr_sum;
 
-        (float4 &)output_matmul[lane_id_x4 + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O_sum[0];
+        for(int j = 0; j < 4; j += 2)
+            (float2 &)output_matmul[lane_id_x2 + j*32 + (parallel_idx * compMeta.dimSize[3])] = (float2 &)ptr_O_sum[j];
     }
     else{
-        (float4 &)output_matmul[lane_id_x4 + (parallel_idx * compMeta.dimSize[3])]
-            = (float4 &)output_O_temp[lane_id_x4 + parallel_idx * compMeta.dimSize[3]];
+        for(int i = 0; i < 4; i += 2)
+            (float2 &)output_matmul[(lane_id_x2 + i*32) + (parallel_idx * compMeta.dimSize[3])]
+                = __half22float2((half2 &)output_O_temp[(lane_id_x2 + i*32) + parallel_idx * compMeta.dimSize[3]]);
     }
 
 }
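In the single-chunk (else) branch the intermediate can no longer be copied through verbatim: it is fp16 now, so each half2 is widened back to fp32 before landing in the fp32 output_matmul. A standalone sketch of that widening copy (helper name assumed):

    #include <cuda_fp16.h>

    // Hypothetical distillation of the new else-branch: expand a 128-wide
    // fp16 row back to fp32, two elements per lane per iteration.
    __device__ void widen_copy_128(const half* src, float* dst,
                                   int lane_id_x2) {
        for (int i = 0; i < 4; i += 2)
            (float2 &)dst[lane_id_x2 + i * 32] =
                __half22float2((const half2 &)src[lane_id_x2 + i * 32]);
    }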
@@ -156,12 +162,12 @@ void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
     _attention_kvcache_kernel_128_1
         <<<gridDim, blockDim, 0, CUDAStream::getCurrentStream()>>>
         (input_k_cache, input_v_cache, input_q, input_k, input_v, position_id,
-         compMeta, output_O_temp, output_sum_temp);
+         compMeta, (half*)output_O_temp, output_sum_temp);
 
     _attention_kvcache_kernel_128_2
         <<<compMeta.dimSize[0]*compMeta.dimSize[1], WARP_SIZE,
         0, CUDAStream::getCurrentStream()>>>
-        (position_id, output_matmul, compMeta, output_O_temp, output_sum_temp);
+        (position_id, output_matmul, compMeta, (half*)output_O_temp, output_sum_temp);
 }
 
 } // namespace infini
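Note that the host wrapper still declares output_O_temp as float* and only reinterprets it as half* at the call sites, so the kernels' fp16 view lives inside the existing fp32-typed workspace; presumably this avoids touching the allocation logic elsewhere.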