forked from jiuyuan/InfiniTensor
rope and attention ops support multiple batches/sequences.
This commit is contained in:
parent
eb3a2d123d
commit
c01e64db50
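The hunks below switch `position_id` to `int64_t` and re-index both CUDA kernels so every (sequence, head) pair gets its own blocks. For orientation, a minimal sketch of the post-commit interface, assembled from the hunks that follow (comments added here; the 128 head size comes from the kernel's own assumption):

// Sketch only, assembled from the diff below.
struct AttentionKVCacheMetadata {
    int head_dim;      // fixed at 128 by this kernel
    int num_heads;
    int num_seqs;      // product of position_id's dimensions
    int max_kv_seqlen;
};

void attention_kvcache_kernel(int dType, void *input_k_cache, void *input_v_cache,
                              void *input_q, void *input_k, void *input_v,
                              int64_t *position_id, void *output_matmul,
                              const AttentionKVCacheMetadata &compMeta,
                              float *output_O_temp, float *output_sum_temp);

void rope_kernel(int dType, int64_t *pos, void *input, void *output,
                 int dim_model, int dim_head, int batchsize, int pos_stride);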
@@ -3,14 +3,16 @@
 #include <cstdio>

 struct AttentionKVCacheMetadata {
-    int dimSize[4];
-    int stride[4];
+    int head_dim;
+    int num_heads;
+    int num_seqs;
+    int max_kv_seqlen;
 };

 namespace infini {
 void attention_kvcache_kernel(int dType, void *input_k_cache,
                               void *input_v_cache, void *input_q, void *input_k,
-                              void *input_v, int *position_id,
+                              void *input_v, int64_t *position_id,
                               void *output_matmul,
                               const AttentionKVCacheMetadata &compMeta,
                               float *output_O_temp, float *output_sum_temp);
@@ -5,8 +5,7 @@

 namespace infini {

-void rope_kernel(int dType, int *pos, void *input, void *output, int size,
-                 int dim_model, int dim_head, int hidden_stride,
-                 int pos_stride);
+void rope_kernel(int dType, int64_t *pos, void *input, void *output,
+                 int dim_model, int dim_head, int batchsize, int pos_stride);

 }; // namespace infini
@@ -20,7 +20,7 @@ class RoPEObj : public OperatorObj {
    int numInputs() const override { return 2; }
    int numOutputs() const override { return 1; }
    DataType getDType() const { return getInputs(1)->getDType(); }

    vector<DataType> inferDataType(const TensorVec &inputs) const override {
        return {inputs[1]->getDType()};
    };
@@ -208,8 +208,9 @@ class OnnxStub:
                         op[1],
                     )
                 elif node.op_type == "MatMul":
-                    if to_array(data[node.input[1]]).dtype == np.float32 \
-                        and type(runtime) == backend.CudaRuntime \
+                    if node.input[1] in data.keys() \
+                        and to_array(data[node.input[1]]).dtype == np.float32 \
+                        and 'cuda_runtime' in dir(backend) \
                         and tensors[node.input[0]].shape()[0] == 1 \
                         and tensors[node.input[0]].shape()[1] == 1 \
                         and len(tensors[node.input[1]].shape()) == 2 \
@@ -7,15 +7,19 @@ namespace infini {

 class AttentionKVCacheCompute {
     void initAttentionKVCacheMetadata(AttentionKVCacheMetadata &metadata,
-                                      Tensor tensor) const {
-        int nDims = tensor->getRank();
-        auto strides = tensor->getStride();
+                                      Tensor input_v_cache,
+                                      Tensor position_id) const {
+        int nDims = input_v_cache->getRank();
+        auto strides = input_v_cache->getStride();
         IT_ASSERT(nDims == 4);
         IT_ASSERT(strides.size() == (size_t)nDims);
-        for (int i = 0; i < nDims; ++i) {
-            metadata.dimSize[i] = tensor->getDims().at(i);
-            metadata.stride[i] = strides.at(i);
+        int dim_position_id = position_id->getRank();
+        metadata.num_seqs = 1;
+        for (int i = 0; i < dim_position_id; i++) {
+            metadata.num_seqs *= position_id->getDims().at(i);
         }
+        metadata.head_dim = input_v_cache->getDims().at(3);
+        metadata.num_heads = input_v_cache->getDims().at(1);
+        metadata.max_kv_seqlen = input_v_cache->getDims().at(2);
     }

   public:
@@ -24,17 +28,16 @@ class AttentionKVCacheCompute {
                     Tensor position_id, Tensor output_matmul,
                     CudaPtr p_workspace) const {
         AttentionKVCacheMetadata metadata;
-        initAttentionKVCacheMetadata(metadata, input_v_cache);
+        initAttentionKVCacheMetadata(metadata, input_v_cache, position_id);

-        attention_kvcache_kernel(dType, input_k_cache->getRawDataPtr<void *>(),
-                                 input_v_cache->getRawDataPtr<void *>(),
-                                 input_q->getRawDataPtr<void *>(),
-                                 input_k->getRawDataPtr<void *>(),
-                                 input_v->getRawDataPtr<void *>(),
-                                 position_id->getRawDataPtr<int *>(),
-                                 output_matmul->getRawDataPtr<void *>(),
-                                 metadata, (float *)p_workspace,
-                                 (float *)(p_workspace + (1ll << 30)));
+        attention_kvcache_kernel(
+            dType, input_k_cache->getRawDataPtr<void *>(),
+            input_v_cache->getRawDataPtr<void *>(),
+            input_q->getRawDataPtr<void *>(), input_k->getRawDataPtr<void *>(),
+            input_v->getRawDataPtr<void *>(),
+            position_id->getRawDataPtr<int64_t *>(),
+            output_matmul->getRawDataPtr<void *>(), metadata,
+            (float *)p_workspace, (float *)(p_workspace + (1ll << 30)));
     }
 };

@@ -44,15 +47,15 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute,
                  const RuntimeObj *_context) const override {
         auto op = as<AttentionKVCacheObj>(_op);
         int dType = op->getDType().getIndex();
-        IT_ASSERT(dType == 1 || dType == 10);
+        int position_idx_dtype = op->getInputs()[5]->getDTypeIndex();
+        IT_ASSERT(dType == 1 || dType == 10 || position_idx_dtype == 7);

         size_t workspaceSize = 2ll << 30;
         auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
         CudaPtr idxWsData = context->getWorkspace(workspaceSize);
         do_compute(dType, op->getInputs()[0], op->getInputs()[1],
-                   op->getInputs()[2], op->getInputs()[3],
-                   op->getInputs()[4], op->getInputs()[5],
-                   op->getOutputs()[0], idxWsData);
+                   op->getInputs()[2], op->getInputs()[3], op->getInputs()[4],
+                   op->getInputs()[5], op->getOutputs()[0], idxWsData);
     }
 };

@@ -1,8 +1,9 @@
 #include "cuda/cuda_common.h"
 #include "cuda/cuda_attention_kvcache.h"
 #define WARP_SIZE 32
-#define BLOCKSIZE WARP_SIZE*2
-#define SEQ_UNIT 8
+#define SEQ_UNIT 16
+#define BLOCKSIZE_2 WARP_SIZE*4
+#define MAX_PARTITION 1024

 // ASSUME SEQ_LEN OF Q IS 1
 template <class T>
@@ -11,190 +12,186 @@ __global__ void _attention_kvcache_kernel_128_1(T* input_k_cache,
                                                 T* input_q,
                                                 T* input_k,
                                                 T* input_v,
                                                 int* position_id,
                                                 int64_t* position_id,
                                                 AttentionKVCacheMetadata compMeta,
                                                 half* output_O_temp,
                                                 float* output_sum_temp) {
    int seq_length = position_id[0] + 1;
    int seq_length = position_id[blockIdx.y] + 1;
    int stride = (seq_length + SEQ_UNIT - 1) / SEQ_UNIT;
    if(blockIdx.y >= stride)
    if(blockIdx.z >= stride)
        return;

    int lane_id_x2 = threadIdx.x % WARP_SIZE * 2;
    int group_id = threadIdx.x / WARP_SIZE;
    int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
    int idx_seq = blockIdx.y * SEQ_UNIT;
    int parallel_idx = blockIdx.x + blockIdx.y * gridDim.x;

    if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
        return;
    int idx_seq = blockIdx.z * SEQ_UNIT;

    half ptr_V[4]; // V
    half ptr_K[4]; // K
    half ptr_Q[4]; // Q
    float ptr_P[SEQ_UNIT] = {0};
    half reg_V[4];
    half reg_K[4];
    half reg_Q[4];
    float reg_P;

    float ptr_O[4] = {0};
    float ptr_sum[1] = {0};
    float reg_O[4] = {0};
    float reg_sum = 0;
    float temp[4];
    bool is_fp16 = sizeof(T) == 2 ? true : false;

    int idx_qkv = lane_id_x2 + parallel_idx * compMeta.stride[2];
    int idx_qkv = lane_id_x2 + parallel_idx * compMeta.head_dim;

    // readin Q
    if(!is_fp16){
        #pragma unroll
        for(int i = 0; i < 4; i += 2){
            (float2 &)temp[i] = (float2 &)input_q[idx_qkv + i*WARP_SIZE];
            *((half2*)(&ptr_Q[i])) = __float22half2_rn(*((float2*)(&temp[i])));
            *((half2*)(&reg_Q[i])) = __float22half2_rn(*((float2*)(&temp[i])));
        }
    }
    else{
        #pragma unroll
        for(int i = 0; i < 4; i += 2){
            (half2 &)ptr_Q[i] = (half2 &)input_q[idx_qkv + i*WARP_SIZE];
            (half2 &)reg_Q[i] = (half2 &)input_q[idx_qkv + i*WARP_SIZE];
        }
    }
    int common_idx = lane_id_x2 + (parallel_idx * compMeta.stride[1]);
    int common_idx = lane_id_x2 + (parallel_idx * compMeta.max_kv_seqlen * compMeta.head_dim);

    // Q*K
    #pragma unroll
    for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
        int idx_kvcache = common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2]);
        reg_P = 0;
        int idx_kvcache = common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.head_dim);
        // readin K & V
        if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
            #pragma unroll
            for(int i = 0; i < 4; i += 2){
                *((half2*)(&ptr_K[i])) = *((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE]));
                *((half2*)(&reg_K[i])) = *((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE]));
                *((half2*)(&reg_V[i])) = *((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE]));
            }
        }
        else{
            if(!is_fp16){
                #pragma unroll
                for(int i = 0; i < 4; i += 2){
                    (float2 &)temp[i] = (float2 &) input_k[idx_qkv + i*WARP_SIZE];
                    *((half2*)(&ptr_K[i])) = __float22half2_rn(*((float2*)(&temp[i])));
                }
            }
            else{
                for(int i = 0; i < 4; i += 2){
                    (half2 &)ptr_K[i] = (half2 &)input_k[idx_qkv + i*WARP_SIZE];
                }
            }
            for(int i = 0; i < 4; i += 2){
                *((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&ptr_K[i]));
            }
        }
        // * V
        if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
            for(int i = 0; i < 4; i += 2){
                *((half2*)(&ptr_V[i])) = *((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE]));
            }
        }
        else{
            if(!is_fp16){
                for(int i = 0; i < 4; i += 2){
                    *((half2*)(&reg_K[i])) = __float22half2_rn(*((float2*)(&temp[i])));
                    *((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&reg_K[i]));
                    (float2 &)temp[i] = (float2 &) input_v[idx_qkv + i*WARP_SIZE];
                    *((half2*)(&ptr_V[i])) = __float22half2_rn(*((float2*)(&temp[i])));
                    *((half2*)(&reg_V[i])) = __float22half2_rn(*((float2*)(&temp[i])));
                    *((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&reg_V[i]));
                }
            }
            else{
                #pragma unroll
                for(int i = 0; i < 4; i += 2){
                    (half2 &)ptr_V[i] = (half2 &)input_v[idx_qkv + i*WARP_SIZE];
                    (half2 &)reg_K[i] = (half2 &)input_k[idx_qkv + i*WARP_SIZE];
                    *((half2*)(&((half*)input_k_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&reg_K[i]));
                    (half2 &)reg_V[i] = (half2 &)input_v[idx_qkv + i*WARP_SIZE];
                    *((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&reg_V[i]));
                }
            }
            for(int i = 0; i < 4; i += 2){
                *((half2*)(&((half*)input_v_cache)[idx_kvcache + i*WARP_SIZE])) = *((half2*)(&ptr_V[i]));
            }
        }


        // Q*K
        #pragma unroll
        for (int i = 0; i < 4; i ++){
            ptr_K[i] = ptr_Q[i] * ptr_K[i];
        for (int i = 0; i < 4; i += 2){
            (half2 &)reg_K[i] = (half2 &)reg_Q[i] * (half2 &)reg_K[i];
            #pragma unroll
            for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) {
                ptr_K[i] += __shfl_xor_sync(0xffffffff, ptr_K[i], offset);
                (half2 &)reg_K[i] += __shfl_xor_sync(0xffffffff, (half2 &)reg_K[i], offset);
            }
            ptr_P[idx_SEQ_UNIT] += __half2float(ptr_K[i]);
            (float2 &) temp[i] = __half22float2((half2 &)reg_K[i]);
            reg_P += (temp[i] + temp[i+1]);
            (float2 &) temp[i] = __half22float2((half2 &)reg_V[i]);
        }

        // div sqrt(d)
        ptr_P[idx_SEQ_UNIT] /= sqrt(128.0);
        reg_P /= sqrt(128.0);

        // softmax
        ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT]);
        ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
        reg_P = expf(reg_P);
        reg_sum += reg_P;

        #pragma unroll
        for (int i = 0; i < 4; i ++)
            ptr_O[i] = fmaf(ptr_P[idx_SEQ_UNIT], __half2float(ptr_V[i]), ptr_O[i]);
            reg_O[i] = fmaf(reg_P, temp[i], reg_O[i]);
    }

    #pragma unroll
    for (int i = 0; i < 4; i ++)
        ptr_O[i] /= ptr_sum[0];
        reg_O[i] /= reg_sum;

    #pragma unroll
    for(int i = 0; i < 4; i += 2)
        (half2 &)output_O_temp[(lane_id_x2 + i*WARP_SIZE) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = __float22half2_rn((float2 &)ptr_O[i]);
        (half2 &)output_O_temp[(lane_id_x2 + i*WARP_SIZE) + (blockIdx.z * compMeta.head_dim) + (parallel_idx * compMeta.head_dim * stride)] = __float22half2_rn((float2 &)reg_O[i]);
    if(lane_id_x2 == 0){
        output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0];
        output_sum_temp[blockIdx.z + parallel_idx * stride] = reg_sum;
    }

}


template <class T>
__global__ void _attention_kvcache_kernel_128_2(int* position_id,
__global__ void _attention_kvcache_kernel_128_2(int64_t* position_id,
                                                T* output_matmul,
                                                AttentionKVCacheMetadata compMeta,
                                                half* output_O_temp,
                                                float* output_sum_temp) {
    int lane_id_x4 = threadIdx.x % WARP_SIZE * 4;
    int group_id = threadIdx.x / WARP_SIZE;
    int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
    int lane_id = threadIdx.x % WARP_SIZE;
    int parallel_idx = blockIdx.x;
    int offset = parallel_idx * compMeta.head_dim;


    int size = (position_id[0] + SEQ_UNIT) / SEQ_UNIT;
    bool is_fp16 = sizeof(T) == 2 ? true : false;

    if(size > 1){
        float ptr_O[4] = {0};
        float ptr_O_sum[4] = {0};
        float ptr_sum = 0;
        float ptr_sum_temp;
        half temp_half[4];
        #pragma unroll
        for(int i = 0; i < size; i ++){
            (float2 &)temp_half[0]
                = (float2 &)output_O_temp[lane_id_x4 + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size];
            for(int j = 0; j < 4; j += 2)
                (float2 &)ptr_O[j] = __half22float2((half2 &)temp_half[j]);
            ptr_sum_temp = output_sum_temp[i + parallel_idx * size];

    if(size == 1){
        if(!is_fp16){
            #pragma unroll
            for(int k = 0; k < 4; k ++)
                ptr_O_sum[k] = fma(ptr_O[k], ptr_sum_temp, ptr_O_sum[k]);
            ptr_sum += ptr_sum_temp;
            for(int i = threadIdx.x; i < compMeta.head_dim; i += blockDim.x)
                output_matmul[i + offset]
                    = __half2float(output_O_temp[i + offset]);
        }
        else{
            #pragma unroll
            for(int i = threadIdx.x; i < compMeta.head_dim; i += blockDim.x)
                output_matmul[i + offset]
                    = output_O_temp[i + offset];
        }
        return;
    }

    __shared__ float shm_sum_temp[MAX_PARTITION];
    __shared__ float shm_sum[WARP_SIZE];
    float temp_sum = 0;

    #pragma unroll
    for(int i = threadIdx.x; i < size; i += blockDim.x){
        shm_sum_temp[i] = output_sum_temp[i + parallel_idx * size];
        temp_sum += shm_sum_temp[i];
    }

    #pragma unroll
    for(int offset = WARP_SIZE/2; offset > 0; offset /= 2)
        temp_sum += __shfl_down_sync(0xffffffff, temp_sum, offset);
    if(lane_id == 0)
        shm_sum[threadIdx.x/WARP_SIZE] = temp_sum;
    __syncthreads();
    temp_sum = lane_id < (size + WARP_SIZE - 1) / WARP_SIZE ? shm_sum[lane_id] : 0;

    #pragma unroll
    for(int offset = WARP_SIZE/2; offset > 0; offset /= 2)
        temp_sum += __shfl_xor_sync(0xffffffff, temp_sum, offset);
    temp_sum = __fdividef(1.0f, temp_sum + 1e-6f);

    #pragma unroll
    for(int i = threadIdx.x; i < compMeta.head_dim; i += blockDim.x){
        float acc = 0.0f;
        for(int j = 0; j < size; j ++){
            acc = fma(__half2float(output_O_temp[i + (j * compMeta.head_dim) + offset * size]) * shm_sum_temp[j], temp_sum, acc);
        }

            #pragma unroll
            for(int k = 0; k < 4; k ++)
                ptr_O_sum[k] = ptr_O_sum[k] / ptr_sum;

        if(!is_fp16){
            (float4 &)output_matmul[lane_id_x4 + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O_sum[0];
            output_matmul[i + offset] = acc;
        }
        else{
            for(int j = 0; j < 4; j += 2)
                (half2 &)temp_half[j] = __float22half2_rn((float2 &)ptr_O_sum[j]);
            (float2 &)output_matmul[lane_id_x4 + (parallel_idx * compMeta.dimSize[3])] = (float2 &)temp_half[0];
        }
    }
    else{
        half temp_half[4];
        float temp_float[4];
        if(!is_fp16){
            (float2 &)temp_half[0]
                = (float2 &)output_O_temp[lane_id_x4 + parallel_idx * compMeta.dimSize[3]];
            for(int i = 0; i < 4; i += 2)
                (float2 &)temp_float[i] = __half22float2((half2 &)temp_half[i]);
            (float4 &)output_matmul[lane_id_x4 + parallel_idx * compMeta.dimSize[3]] = (float4 &)temp_float[0];
        }
        else{
            (float2 &)output_matmul[lane_id_x4 + parallel_idx * compMeta.dimSize[3]]
                = (float2 &)output_O_temp[lane_id_x4 + parallel_idx * compMeta.dimSize[3]];
            output_matmul[i + offset] = __float2half(acc);
        }
    }
}
@@ -203,14 +200,14 @@ __global__ void _attention_kvcache_kernel_128_2(int* position_id,
 namespace infini {
 void attention_kvcache_kernel(int dType, void *input_k_cache, void *input_v_cache,
                               void *input_q, void *input_k,
-                              void *input_v, int *position_id, void *output_matmul,
+                              void *input_v, int64_t *position_id, void *output_matmul,
                               const AttentionKVCacheMetadata &compMeta,
                               float *output_O_temp, float *output_sum_temp) {
-    IT_ASSERT(compMeta.dimSize[3] == 128 && (dType == 1 || dType == 10));
+    IT_ASSERT(dType == 1 || dType == 10);

-    int gridsize_y = (compMeta.dimSize[2] - 1 + SEQ_UNIT) / SEQ_UNIT;
-    dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), gridsize_y);
-    dim3 blockDim(BLOCKSIZE, 1);
+    int gridsize_y = (compMeta.max_kv_seqlen - 1 + SEQ_UNIT) / SEQ_UNIT;
+    dim3 gridDim(compMeta.num_heads, compMeta.num_seqs, gridsize_y);
+    dim3 blockDim(WARP_SIZE, 1);

     if(dType == 1){
         _attention_kvcache_kernel_128_1<float>
@@ -219,7 +216,7 @@ void attention_kvcache_kernel(int dType, void *input_k_cache, void *input_v_cach
             position_id, compMeta, (half*)output_O_temp, output_sum_temp);

         _attention_kvcache_kernel_128_2<float>
-            <<<compMeta.dimSize[0]*compMeta.dimSize[1], WARP_SIZE,
+            <<<compMeta.num_seqs*compMeta.num_heads, BLOCKSIZE_2,
             0, CUDAStream::getCurrentStream()>>>
             (position_id, (float*)output_matmul, compMeta, (half*)output_O_temp, output_sum_temp);
     }
@@ -230,7 +227,7 @@ void attention_kvcache_kernel(int dType, void *input_k_cache, void *input_v_cach
             position_id, compMeta, (half*)output_O_temp, output_sum_temp);

         _attention_kvcache_kernel_128_2<half>
-            <<<compMeta.dimSize[0]*compMeta.dimSize[1], WARP_SIZE,
+            <<<compMeta.num_seqs*compMeta.num_heads, BLOCKSIZE_2,
             0, CUDAStream::getCurrentStream()>>>
             (position_id, (half*)output_matmul, compMeta, (half*)output_O_temp, output_sum_temp);
     }

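The two launches above split decode attention into two passes: `_attention_kvcache_kernel_128_1` writes a partial output and a partial softmax sum for every SEQ_UNIT-sized chunk of the KV cache, and `_attention_kvcache_kernel_128_2` rescales and merges those partials per (sequence, head). A hedged sketch of the launch geometry and workspace split, taken from the launcher above (the buffer-layout comments are an inferred reading of the indexing, not stated in the source):

// Phase 1: one block per (head, sequence, KV chunk); blockDim = WARP_SIZE.
int num_chunks = (compMeta.max_kv_seqlen - 1 + SEQ_UNIT) / SEQ_UNIT;
dim3 grid1(compMeta.num_heads, compMeta.num_seqs, num_chunks);
// Phase 2: one block per (sequence, head); BLOCKSIZE_2 threads reduce the chunks.
dim3 grid2(compMeta.num_seqs * compMeta.num_heads);
// Workspace: the first 1 GiB is used for output_O_temp (half) and the next for
// output_sum_temp (float), which is why do_compute passes p_workspace + (1ll << 30).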
@@ -18,17 +18,19 @@ class RoPECuda : public CudaKernelWithoutConfig {
         const auto &inputShape = input->getDims();
         int nDims = input->getDims().size();

-        int size = input->size();
         IT_ASSERT(nDims == 3 && pos->getDims().size() == 2);
-        IT_ASSERT(inputShape[1] == pos->getDims()[1]);
+        IT_ASSERT(inputShape[0] == pos->getDims()[0] &&
+                  inputShape[1] == pos->getDims()[1]);
+        int position_idx_dtype = op->getInputs()[0]->getDTypeIndex();
+        IT_ASSERT(position_idx_dtype == 7);
         int dim_model = inputShape[2];
-        int dim_head = 128;
-        int hidden_stride = dim_model * inputShape[1];
+        int dim_head = 128; // TODO: get dim_head from the framework
         int pos_stride = inputShape[1];
+        int batchsize = inputShape[0];

         const int dType = op->getDType().getIndex();
-        rope_kernel(dType, pos->getRawDataPtr<int *>(), inputData, outputData,
-                    size, dim_model, dim_head, hidden_stride, pos_stride);
+        rope_kernel(dType, pos->getRawDataPtr<int64_t *>(), inputData,
+                    outputData, dim_model, dim_head, batchsize, pos_stride);
     }
 };

@@ -4,13 +4,15 @@
 #include "utils/small_array.h"

 template <class T>
-__global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model,
-                             int dim_head, int hidden_stride, int pos_stride) {
+__global__ void _rope_kernel(int64_t* pos, void *in, void *out, int dim_model,
+                             int dim_head, int batchsize, int pos_stride) {
     int batch_id = blockIdx.x;
     int target_pos = pos[batch_id * pos_stride + blockIdx.y];

     int ith = blockIdx.z * blockDim.x + threadIdx.x;
     int col = ith % dim_head;
-    int offset = batch_id * hidden_stride + blockIdx.y * dim_model;
+    int batch_stride = pos_stride * dim_model;
+    int offset = batch_id * batch_stride + blockIdx.y * dim_model;

     if (ith >= dim_model)
         return;
@@ -34,7 +36,7 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
 #define CASE(T)                                                                \
     _rope_kernel<DT_CUDA<T>::t>                                                \
         <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>           \
-        (pos, input, output, size, dim_model, dim_head, hidden_stride, pos_stride);
+        (pos, input, output, dim_model, dim_head, batchsize, pos_stride);

 #define SWITCH_DTYPE(DTYPE)                                                    \
     switch (DTYPE) {                                                           \
@@ -79,10 +81,10 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
 }

 namespace infini {
-void rope_kernel(int dType, int * pos, void *input, void *output, int size,
-                 int dim_model, int dim_head, int hidden_stride, int pos_stride) {
+void rope_kernel(int dType, int64_t * pos, void *input, void *output,
+                 int dim_model, int dim_head, int batchsize, int pos_stride) {
     dim3 blocksize = dim3(32,1,1);
-    dim3 gridsize = dim3(1, 1, dim_model/32);
+    dim3 gridsize = dim3(batchsize, pos_stride, dim_model/32);
     SWITCH_DTYPE(dType)
 }

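For reference, a hypothetical call into the new rope_kernel entry point; only the signature and the grid/block shapes come from the diff above, while the concrete sizes and buffer names here are placeholders for illustration:

// Hypothetical sizes; pos holds one int64_t position per (batch, token).
int batchsize = 2, seq_len = 16, dim_model = 4096, dim_head = 128;
int pos_stride = seq_len;
// input/output are assumed to hold batchsize * seq_len * dim_model elements of dtype dType.
infini::rope_kernel(dType, pos, input, output, dim_model, dim_head,
                    batchsize, pos_stride);
// Launch inside rope_kernel: dim3 gridsize(batchsize, pos_stride, dim_model/32),
//                            dim3 blocksize(32, 1, 1).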