forked from jiuyuan/InfiniTensor
add shape information to the kvcache attention operator
parent 4a5b9572bb
commit d000f9750c
@@ -3,8 +3,7 @@
 namespace infini {
 /**
- * @brief Fused Attention with KVCache input operator. All the input and output
- * tensors should have the same rank except for the position_id.
+ * @brief Fused Attention with KVCache input operator.
  *
  */
 class AttentionKVCacheObj : public OperatorObj {
@@ -16,12 +15,19 @@ class AttentionKVCacheObj : public OperatorObj {
      *
      * @param graph The computation graph that this operator belongs to.
      * @param input_k_cache The k_cache input tensor.
+     * Shape: [batchsize, num_heads, k_cache_seq_length, head_dim]
      * @param input_v_cache The v_cache input tensor.
+     * Shape: [batchsize, num_heads, v_cache_seq_length, head_dim]
      * @param input_q The query input tensor.
+     * Shape: [batchsize, q_seq_length, model_dim]
      * @param input_k The key input tensor.
+     * Shape: [batchsize, q_seq_length, model_dim]
      * @param input_v The value input tensor.
-     * @param position_id The positon id of the query,
+     * Shape: [batchsize, q_seq_length, model_dim]
+     * @param position_id The position id of the query.
+     * Shape: [batchsize, q_seq_length]
      * @param output_matmul The query output tensor.
+     * Shape: [batchsize, q_seq_length, model_dim]
      */
     AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
                         Tensor input_v_cache, Tensor input_q, Tensor input_k,
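The shape contract added above can be restated as a small host-side check. The following is a minimal sketch, assuming a hypothetical Shape alias and helper name (neither is InfiniTensor API), and assuming model_dim == num_heads * head_dim, which the comments imply but do not state:

#include <cassert>
#include <cstddef>
#include <vector>

using Shape = std::vector<std::size_t>;

// Hypothetical helper mirroring the @param comments above: given the
// documented shapes of input_k_cache ([batchsize, num_heads,
// k_cache_seq_length, head_dim]) and input_q ([batchsize, q_seq_length,
// model_dim]), output_matmul has the same shape as input_q.
Shape attentionKVCacheOutputShape(const Shape &k_cache, const Shape &q) {
    assert(k_cache.size() == 4 && q.size() == 3);
    assert(k_cache[0] == q[0]); // batch sizes must match
    // Assumption: model_dim == num_heads * head_dim for the fused op.
    assert(k_cache[1] * k_cache[3] == q[2]);
    return q; // output_matmul: [batchsize, q_seq_length, model_dim]
}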
@@ -8,8 +8,8 @@
 // ASSUME SEQ_LEN OF Q IS 1
 template <class T>
 __global__ void _attention_kvcache_kernel_128_1(T* input_k_cache,
-                                                T* input_v_cache,
-                                                T* input_q,
+                                                T* input_v_cache,
+                                                T* input_q,
                                                 T* input_k,
                                                 T* input_v,
                                                 int64_t* position_id,
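To make the kernel's job concrete: with q_seq_length == 1 (the ASSUME comment) and head_dim == 128 (presumably the "128_1" suffix), each (batch, head) pair reduces to one softmax-weighted sum over the cached positions. Below is a self-contained toy illustrating that math; it is NOT the kernel from this diff, which is optimized and takes more parameters than the truncated signature shows. Names and launch shape here are illustrative only.

#include <cstddef>
#include <math.h>

// Toy single-query attention over a KV cache: one block per (batch, head)
// pair, one thread per block for clarity (a real kernel parallelizes the
// 128-wide dot products and reductions across the block).
__global__ void toy_attention_q1(const float *k_cache, // [B, H, S, 128]
                                 const float *v_cache, // [B, H, S, 128]
                                 const float *q,       // [B, H, 128]
                                 float *out,           // [B, H, 128]
                                 int seq_len) {
    const int bh = blockIdx.x; // flattened (batch, head) index
    const float *K = k_cache + (size_t)bh * seq_len * 128;
    const float *V = v_cache + (size_t)bh * seq_len * 128;
    const float *Q = q + (size_t)bh * 128;
    float *O = out + (size_t)bh * 128;
    if (threadIdx.x != 0)
        return;

    // out = softmax(Q . K^T / sqrt(128)) . V, streamed over cached positions.
    // (No running-max subtraction here; a production kernel stabilizes the
    // softmax to avoid overflow in expf.)
    float denom = 0.f;
    for (int d = 0; d < 128; ++d)
        O[d] = 0.f;
    for (int s = 0; s < seq_len; ++s) {
        float dot = 0.f;
        for (int d = 0; d < 128; ++d)
            dot += Q[d] * K[s * 128 + d];
        float w = expf(dot / sqrtf(128.f)); // unnormalized attention weight
        denom += w;
        for (int d = 0; d < 128; ++d)
            O[d] += w * V[s * 128 + d];
    }
    for (int d = 0; d < 128; ++d)
        O[d] /= denom;
}

A launch like toy_attention_q1<<<batch * num_heads, 1>>>(k, v, q, out, seq_len) exercises it. The kernel in the diff additionally takes input_k, input_v, and position_id, presumably so it can fuse the cache update for the new token with the attention step itself.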