forked from jiuyuan/InfiniTensor
kv register is fp16
This commit is contained in:
parent
1e797d4ffe
commit
db053e32a4
|
@ -27,13 +27,14 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
|
||||||
if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
|
if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
|
||||||
return;
|
return;
|
||||||
|
|
||||||
float ptr_V[SEQ_UNIT*4]; // V
|
half ptr_V[4]; // V
|
||||||
float ptr_K[SEQ_UNIT*4]; // K
|
half ptr_K[4]; // K
|
||||||
float ptr_Q[4]; // Q
|
float ptr_Q[4]; // Q
|
||||||
float ptr_P[SEQ_UNIT] = {0};
|
float ptr_P[SEQ_UNIT] = {0};
|
||||||
|
|
||||||
float ptr_O[4] = {0};
|
float ptr_O[4] = {0};
|
||||||
float ptr_sum[1] = {0};
|
float ptr_sum[1] = {0};
|
||||||
|
float temp[4];
|
||||||
|
|
||||||
// readin Q
|
// readin Q
|
||||||
(float4 &)ptr_Q[0] = (float4 &)input_q[lane_id_x4 + (parallel_idx * 128)];
|
(float4 &)ptr_Q[0] = (float4 &)input_q[lane_id_x4 + (parallel_idx * 128)];
|
||||||
|
@ -44,43 +45,34 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
|
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
|
||||||
int idx_kvcache = common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2]);
|
int idx_kvcache = common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2]);
|
||||||
int idx_SEQ_UNIT_x4 = idx_SEQ_UNIT * 4;
|
|
||||||
half temp[4];
|
|
||||||
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
|
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
|
||||||
*((int2*)(&temp[0])) = *((int2*)(&((half*)input_k_cache)[idx_kvcache]));
|
*((int2*)(&ptr_K[0])) = *((int2*)(&((half*)input_k_cache)[idx_kvcache]));
|
||||||
for(int i = 0; i < 4; i += 2)
|
|
||||||
(float2 &)ptr_K[idx_SEQ_UNIT_x4 + i] = __half22float2((half2 &)temp[i]);
|
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
(float4 &)ptr_K[idx_SEQ_UNIT_x4] = (float4 &) input_k[idx_kv];
|
(float4 &)temp[0] = (float4 &) input_k[idx_kv];
|
||||||
for(int i = 0; i < 4; i += 2){
|
for(int i = 0; i < 4; i += 2)
|
||||||
*((half2*)(&temp[i])) = __float22half2_rn(*((float2*)(&ptr_K[idx_SEQ_UNIT_x4 + i])));
|
*((half2*)(&ptr_K[i])) = __float22half2_rn(*((float2*)(&temp[i])));
|
||||||
|
*((int2*)(&((half*)input_k_cache)[idx_kvcache])) = *((int2*)(&ptr_K[0]));
|
||||||
}
|
}
|
||||||
*((int2*)(&((half*)input_k_cache)[idx_kvcache])) = *((int2*)(&temp[0]));
|
|
||||||
}
|
|
||||||
|
|
||||||
// * V
|
// * V
|
||||||
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
|
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
|
||||||
*((int2*)(&temp[0])) = *((int2*)(&((half*)input_v_cache)[idx_kvcache]));
|
*((int2*)(&ptr_V[0])) = *((int2*)(&((half*)input_v_cache)[idx_kvcache]));
|
||||||
for(int i = 0; i < 4; i += 2)
|
|
||||||
(float2 &)ptr_V[idx_SEQ_UNIT_x4 + i] = __half22float2((half2 &)temp[i]);
|
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
(float4 &)ptr_V[idx_SEQ_UNIT_x4] = (float4 &) input_v[idx_kv];
|
(float4 &)temp[0] = (float4 &) input_v[idx_kv];
|
||||||
for(int i = 0; i < 4; i += 2){
|
for(int i = 0; i < 4; i += 2)
|
||||||
*((half2*)(&temp[i])) = __float22half2_rn(*((float2*)(&ptr_V[idx_SEQ_UNIT_x4 + i])));
|
*((half2*)(&ptr_V[i])) = __float22half2_rn(*((float2*)(&temp[i])));
|
||||||
}
|
*((int2*)(&((half*)input_v_cache)[idx_kvcache])) = *((int2*)(&ptr_V[0]));
|
||||||
*((int2*)(&((half*)input_v_cache)[idx_kvcache])) = *((int2*)(&temp[0]));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 4; i ++){
|
for (int i = 0; i < 4; i ++){
|
||||||
ptr_K[idx_SEQ_UNIT_x4 + i] = ptr_Q[i] * ptr_K[idx_SEQ_UNIT_x4 + i];
|
ptr_K[i] = __float2half(ptr_Q[i]) * ptr_K[i];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int offset = 16; offset > 0; offset /= 2) {
|
for (int offset = 16; offset > 0; offset /= 2) {
|
||||||
ptr_K[idx_SEQ_UNIT_x4 + i] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT_x4 + i], offset);
|
ptr_K[i] += __shfl_down_sync(0xffffffff, ptr_K[i], offset);
|
||||||
}
|
}
|
||||||
ptr_P[idx_SEQ_UNIT] += ptr_K[idx_SEQ_UNIT_x4 + i];
|
ptr_P[idx_SEQ_UNIT] += __half2float(ptr_K[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// div sqrt(d)
|
// div sqrt(d)
|
||||||
|
@ -93,7 +85,7 @@ __global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 4; i ++)
|
for (int i = 0; i < 4; i ++)
|
||||||
ptr_O[i] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT_x4 + i)], ptr_O[i]);
|
ptr_O[i] = fmaf(ptr_P[idx_SEQ_UNIT], __half2float(ptr_V[i]), ptr_O[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
|
|
Loading…
Reference in New Issue