diff --git a/include/operators/attention_kvcache.h b/include/operators/attention_kvcache.h
index 351952c9..dfe21c1e 100644
--- a/include/operators/attention_kvcache.h
+++ b/include/operators/attention_kvcache.h
@@ -3,8 +3,7 @@
 
 namespace infini {
 /**
- * @brief Fused Attention with KVCache input operator. All the input and output
- * tensors should have the same rank except for the position_id.
+ * @brief Fused Attention with KVCache input operator.
  *
  */
 class AttentionKVCacheObj : public OperatorObj {
@@ -16,12 +15,19 @@ class AttentionKVCacheObj : public OperatorObj {
      *
      * @param graph The computation graph that this operator belongs to.
      * @param input_k_cache The k_cache input tensor.
+     * Shape: [batchsize, num_heads, k_cache_seq_length, head_dim]
      * @param input_v_cache The v_cache input tensor.
+     * Shape: [batchsize, num_heads, v_cache_seq_length, head_dim]
      * @param input_q The query input tensor.
+     * Shape: [batchsize, q_seq_length, model_dim]
      * @param input_k The key input tensor.
+     * Shape: [batchsize, q_seq_length, model_dim]
      * @param input_v The value input tensor.
-     * @param position_id The positon id of the query,
+     * Shape: [batchsize, q_seq_length, model_dim]
+     * @param position_id The position id of the query.
+     * Shape: [batchsize, q_seq_length]
      * @param output_matmul The query output tensor.
+     * Shape: [batchsize, q_seq_length, model_dim]
      */
     AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
                         Tensor input_v_cache, Tensor input_q, Tensor input_k,
diff --git a/src/kernels/cuda/attention_kvcache.cu b/src/kernels/cuda/attention_kvcache.cu
index f66a90ab..3499f9f0 100644
--- a/src/kernels/cuda/attention_kvcache.cu
+++ b/src/kernels/cuda/attention_kvcache.cu
@@ -8,8 +8,8 @@
 // ASSUME SEQ_LEN OF Q IS 1
 template <class T>
 __global__ void _attention_kvcache_kernel_128_1(T* input_k_cache,
-                                              T* input_v_cache,
-                                              T* input_q,
+                                                T* input_v_cache,
+                                                T* input_q,
                                                 T* input_k,
                                                 T* input_v,
                                                 int64_t* position_id,