forked from jiuyuan/InfiniTensor
fix bugs when blocksize==64
This commit is contained in:
parent
83be7fa373
commit
80412ae162
|
@ -1,7 +1,7 @@
|
|||
#include "cuda/cuda_common.h"
|
||||
#include "cuda/cuda_attention_kvcache.h"
|
||||
#define WARP_SIZE 32
|
||||
#define BLOCKSIZE WARP_SIZE
|
||||
#define BLOCKSIZE WARP_SIZE*2
|
||||
#define SEQ_UNIT 16
|
||||
|
||||
// ASSUME SEQ_LEN OF Q IS 1
|
||||
|
@ -114,7 +114,7 @@ __global__ void _attention_kvcache_kernel_128_2(int* position_id,
|
|||
AttentionKVCacheMetadata compMeta,
|
||||
float* output_O_temp,
|
||||
float* output_sum_temp) {
|
||||
int lane_id = threadIdx.x % WARP_SIZE;
|
||||
int lane_id_x4 = threadIdx.x % WARP_SIZE * 4;
|
||||
int group_id = threadIdx.x / WARP_SIZE;
|
||||
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
|
||||
|
||||
|
@ -124,23 +124,29 @@ __global__ void _attention_kvcache_kernel_128_2(int* position_id,
|
|||
float ptr_sum_temp;
|
||||
int size = (position_id[0] + SEQ_UNIT) / SEQ_UNIT;
|
||||
|
||||
#pragma unroll
|
||||
for(int i = 0; i < size; i ++){
|
||||
(float4 &)ptr_O[0]
|
||||
= (float4 &)output_O_temp[(lane_id * 4) + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size];
|
||||
ptr_sum_temp = output_sum_temp[i + parallel_idx * size];
|
||||
if(size > 1){
|
||||
#pragma unroll
|
||||
for(int i = 0; i < size; i ++){
|
||||
(float4 &)ptr_O[0]
|
||||
= (float4 &)output_O_temp[lane_id_x4 + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size];
|
||||
ptr_sum_temp = output_sum_temp[i + parallel_idx * size];
|
||||
|
||||
#pragma unroll
|
||||
for(int k = 0; k < 4; k ++)
|
||||
ptr_O_sum[k] += ptr_O[k] * ptr_sum_temp;
|
||||
ptr_sum += ptr_sum_temp;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(int k = 0; k < 4; k ++)
|
||||
ptr_O_sum[k] += ptr_O[k] * ptr_sum_temp;
|
||||
ptr_sum += ptr_sum_temp;
|
||||
ptr_O_sum[k] = ptr_O_sum[k] / ptr_sum;
|
||||
|
||||
(float4 &)output_matmul[lane_id_x4 + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O_sum[0];
|
||||
}
|
||||
else{
|
||||
(float4 &)output_matmul[lane_id_x4 + (parallel_idx * compMeta.dimSize[3])]
|
||||
= (float4 &)output_O_temp[lane_id_x4 + parallel_idx * compMeta.dimSize[3]];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(int k = 0; k < 4; k ++)
|
||||
ptr_O_sum[k] = ptr_O_sum[k] / ptr_sum;
|
||||
|
||||
(float4 &)output_matmul[(lane_id * 4) + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O_sum[0];
|
||||
|
||||
}
|
||||
|
||||
|
@ -163,7 +169,7 @@ void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
|
|||
compMeta, output_O_temp, output_sum_temp);
|
||||
|
||||
_attention_kvcache_kernel_128_2
|
||||
<<<compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), WARP_SIZE,
|
||||
<<<compMeta.dimSize[0]*compMeta.dimSize[1], WARP_SIZE,
|
||||
0, CUDAStream::getCurrentStream()>>>
|
||||
(position_id, output_matmul, compMeta, output_O_temp, output_sum_temp);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue