add shape information to the kvcache attention operator

This commit is contained in:
xiaonans 2024-04-11 14:52:39 +08:00
parent 4a5b9572bb
commit d000f9750c
2 changed files with 11 additions and 5 deletions

View File

@ -3,8 +3,7 @@
namespace infini {
/**
* @brief Fused Attention with KVCache input operator. All the input and output
* tensors should have the same rank except for the position_id.
* @brief Fused Attention with KVCache input operator.
*
*/
class AttentionKVCacheObj : public OperatorObj {
@ -16,12 +15,19 @@ class AttentionKVCacheObj : public OperatorObj {
*
* @param graph The computation graph that this operator belongs to.
* @param input_k_cache The k_cache input tensor.
* Shape: [batchsize, num_heads, k_cache_seq_length, head_dim]
* @param input_v_cache The v_cache input tensor.
* Shape: [batchsize, num_heads, v_cache_seq_length, head_dim]
* @param input_q The query input tensor.
* Shape: [batchsize, q_seq_length, model_dim]
* @param input_k The key input tensor.
* Shape: [batchsize, q_seq_length, model_dim]
* @param input_v The value input tensor.
* @param position_id The positon id of the query,
* Shape: [batchsize, q_seq_length, model_dim]
* @param position_id The positon id of the query.
* Shape: [batchsize, q_seq_length]
* @param output_matmul The query output tensor.
* Shape: [batchsize, q_seq_length, model_dim]
*/
AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
Tensor input_v_cache, Tensor input_q, Tensor input_k,

View File

@ -8,8 +8,8 @@
// ASSUME SEQ_LEN OF Q IS 1
template <class T>
__global__ void _attention_kvcache_kernel_128_1(T* input_k_cache,
T* input_v_cache,
T* input_q,
T* input_v_cache,
T* input_q,
T* input_k,
T* input_v,
int64_t* position_id,