forked from jiuyuan/InfiniTensor

add shape information to the kvcache attention operator

parent 4a5b9572bb
commit d000f9750c
@@ -3,8 +3,7 @@

 namespace infini {
 /**
- * @brief Fused Attention with KVCache input operator. All the input and output
- * tensors should have the same rank except for the position_id.
+ * @brief Fused Attention with KVCache input operator.
  *
  */
 class AttentionKVCacheObj : public OperatorObj {
@@ -16,12 +15,19 @@ class AttentionKVCacheObj : public OperatorObj {
      *
      * @param graph The computation graph that this operator belongs to.
      * @param input_k_cache The k_cache input tensor.
+     * Shape: [batchsize, num_heads, k_cache_seq_length, head_dim]
      * @param input_v_cache The v_cache input tensor.
+     * Shape: [batchsize, num_heads, v_cache_seq_length, head_dim]
      * @param input_q The query input tensor.
+     * Shape: [batchsize, q_seq_length, model_dim]
      * @param input_k The key input tensor.
+     * Shape: [batchsize, q_seq_length, model_dim]
      * @param input_v The value input tensor.
-     * @param position_id The positon id of the query,
+     * Shape: [batchsize, q_seq_length, model_dim]
+     * @param position_id The positon id of the query.
+     * Shape: [batchsize, q_seq_length]
      * @param output_matmul The query output tensor.
+     * Shape: [batchsize, q_seq_length, model_dim]
      */
     AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
                         Tensor input_v_cache, Tensor input_q, Tensor input_k,
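Taken together, the new comments pin down how the operator's shapes relate: input_q, input_k, input_v, and output_matmul all use [batchsize, q_seq_length, model_dim]; the two caches use [batchsize, num_heads, *_cache_seq_length, head_dim]; and position_id carries one id per query position. The standalone C++ sketch below is not part of the commit and does not use the InfiniTensor API; the helper name and the assumption that model_dim == num_heads * head_dim are illustrative only. It simply encodes the documented shape relations as assertions and derives the output shape from the query shape.

// Standalone sketch of the shape relations documented in the comments above.
// Assumes model_dim == num_heads * head_dim; not InfiniTensor code.
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical helper: given inputs with the documented layouts, return the
// shape of output_matmul, which keeps the query layout
// [batchsize, q_seq_length, model_dim].
Shape inferAttentionKVCacheOutput(const Shape &k_cache,       // [b, num_heads, k_cache_seq_length, head_dim]
                                  const Shape &v_cache,       // [b, num_heads, v_cache_seq_length, head_dim]
                                  const Shape &q,             // [b, q_seq_length, model_dim]
                                  const Shape &k,             // [b, q_seq_length, model_dim]
                                  const Shape &v,             // [b, q_seq_length, model_dim]
                                  const Shape &position_id) { // [b, q_seq_length]
    assert(k_cache.size() == 4 && v_cache.size() == 4);
    assert(q.size() == 3 && k.size() == 3 && v.size() == 3);
    assert(position_id.size() == 2);
    // Batch sizes must agree across all inputs.
    assert(k_cache[0] == q[0] && v_cache[0] == q[0] && position_id[0] == q[0]);
    // q, k, v share the layout [batchsize, q_seq_length, model_dim].
    assert(k == q && v == q);
    // position_id provides one id per query position.
    assert(position_id[1] == q[1]);
    // model_dim factors into num_heads * head_dim of the caches (assumption).
    assert(q[2] == k_cache[1] * k_cache[3]);
    assert(q[2] == v_cache[1] * v_cache[3]);
    // The output keeps the query layout.
    return q;
}

int main() {
    // Example: batch 1, 32 heads of dim 64 (model_dim 2048), 128 cached
    // positions, one new query token.
    Shape out = inferAttentionKVCacheOutput({1, 32, 128, 64}, {1, 32, 128, 64},
                                            {1, 1, 2048}, {1, 1, 2048},
                                            {1, 1, 2048}, {1, 1});
    assert((out == Shape{1, 1, 2048}));
    return 0;
}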