[feature] add fused attention_kvcache operator support

This commit is contained in:
xiaonans 2023-11-10 10:51:44 +08:00
parent d3e7543291
commit 2436ccb868
10 changed files with 332 additions and 0 deletions

View File

@ -64,6 +64,9 @@ class GraphHandlerObj {
Tensor transpose(Tensor data, Tensor transposed, Shape perm);
Tensor reshape(Tensor data, Tensor reshaped, Shape shape);
Tensor concat(TensorVec inputs, Tensor output, int dim);
// Handler entry point for the fused AttentionKVCache operator. Takes the K and
// V caches, the current Q/K/V tensors and the position id; the definition
// branches on whether output_matmul is set.
Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
Tensor input_q, Tensor input_k, Tensor input_v,
Tensor position_id, Tensor output_matmul);
TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
int num_outputs);
Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);

View File

@ -25,6 +25,7 @@ struct OpType {
Asinh, // Unary
Atan, // Unary
Atanh, // Unary
AttentionKVCache, // Fusion
AveragePool, // Pool
BatchNormalization, //
Bernoulli, //

View File

@ -0,0 +1,15 @@
#pragma once
#include <cstdio>
// Plain shape descriptor passed to the attention KV-cache CUDA kernel.
// Both arrays have a fixed capacity of four dimensions.
struct AttentionKVCacheMetadata {
// Extent of each dimension.
int dimSize[4];
// Stride of each dimension; the kernel uses these as element offsets when
// indexing the float buffers.
int stride[4];
namespace infini {
// Host-side launcher for the fused attention-with-KV-cache CUDA kernel.
// The pointers are forwarded unchanged to the device kernel, so they
// presumably must refer to device memory — TODO confirm against the caller.
// compMeta carries the dimension sizes and strides the kernel indexes with.
void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
float *input_q, float *input_k, float *input_v,
int *position_id, float *output_matmul,
const AttentionKVCacheMetadata &compMeta);
} // namespace infini

View File

@ -0,0 +1,43 @@
#pragma once
#include "core/operator.h"
namespace infini {
* @brief Fused Attention with KVCache input operator. All the input and output
* tensors should have the same rank except for the position_id.
class AttentionKVCacheObj : public OperatorObj {
int dim;
* @brief Construct a new AttentionKVCache object.
* @param graph The computation graph that this operator belongs to.
* @param input_k_cache The k_cache input tensor.
* @param input_v_cache The v_cache input tensor.
* @param input_q The query input tensor.
* @param input_k The key input tensor.
* @param input_v The value input tensor.
* @param position_id The position id of the query.
* @param output_matmul The query output tensor.
// Constructor: the six input tensors are registered in the order
// k_cache, v_cache, q, k, v, position_id, followed by the single output.
AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
Tensor input_v_cache, Tensor input_q, Tensor input_k,
Tensor input_v, Tensor position_id,
Tensor output_matmul);
// Infers the output shape from the input tensors.
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
// Human-readable description of the operator (guid, input dims, dim, guids).
std::string toString() const override;
// Six inputs: k_cache, v_cache, q, k, v, position_id.
int numInputs() const override { return 6; }
// One output: output_matmul.
int numOutputs() const override { return 1; }
// Returns the axis stored in `dim` (the constructor sets it to 2).
int getDim() const { return dim; }
// Integer vectors describing the operator's workload and attributes.
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
} // namespace infini

View File

@ -46,6 +46,9 @@ class OnnxStub:
model = model_simp
except ValidationError:
except RuntimeError:
self.inputs: Dict[str, backend.Tensor] = {}
self.outputs: Dict[str, backend.Tensor] = {}
self.initializer: Dict[int, TensorProto] = {}
@ -545,6 +548,16 @@ class OnnxStub:
(attr.i for attr in node.attribute if attr.name == "axis")
elif node.op_type == "AttentionKVCache":
tensors[node.output[0]] = self.handler.attentionKVCache(
elif node.op_type == "Split":
for name, tensor in zip(

View File

@ -1,6 +1,7 @@
#include "core/graph_handler.h"
#include "operators/all_gather.h"
#include "operators/all_reduce.h"
#include "operators/attention_kvcache.h"
#include "operators/batch_norm.h"
#include "operators/broadcast.h"
#include "operators/concat.h"
@ -239,6 +240,27 @@ Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) {
Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache,
Tensor input_v_cache, Tensor input_q,
Tensor input_k, Tensor input_v,
Tensor position_id,
Tensor output_matmul) {
if (output_matmul) {
std::move(input_k_cache), std::move(input_v_cache),
std::move(input_q), std::move(input_k), std::move(input_v),
std::move(position_id), output_matmul);
return {output_matmul};
} else {
return g
std::move(input_k_cache), std::move(input_v_cache),
std::move(input_q), std::move(input_k), std::move(input_v),
std::move(position_id), output_matmul)
TensorVec GraphHandlerObj::split(Tensor input, std::optional<TensorVec> outputs,
int axis, int num_outputs) {
if (outputs) {

View File

@ -479,6 +479,7 @@ void init_graph_builder(py::module &m) {
.def("transpose", &Handler::transpose, policy::move)
.def("reshape", &Handler::reshape, policy::move)
.def("concat", &Handler::concat, policy::move)
.def("attentionKVCache", &Handler::attentionKVCache, policy::move)
.def("split", &Handler::split, policy::move)
.def("gather", &Handler::gather, policy::move)
.def("gatherElements", &Handler::gatherElements, policy::move)

View File

@ -0,0 +1,51 @@
#include "operators/attention_kvcache.h"
#include "cuda/cuda_attention_kvcache.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include <functional>
namespace infini {
class AttentionKVCacheCompute {
/**
 * @brief Copy the dimension sizes and strides of a tensor into the
 * fixed-size metadata struct consumed by the CUDA kernel.
 *
 * @param metadata Destination descriptor; dimSize/stride hold 4 entries.
 * @param tensor Tensor whose shape is described. It must have rank 4: the
 * destination arrays hold exactly four entries and the kernel reads
 * dimSize[0..3] and stride[1..2].
 */
void initAttentionKVCacheMetadata(AttentionKVCacheMetadata &metadata,
                                  Tensor tensor) const {
    int nDims = tensor->getRank();
    // Guard the fixed-size arrays: a larger rank would write out of bounds,
    // a smaller one would leave entries the kernel reads uninitialised.
    IT_ASSERT(nDims == 4);
    auto strides = tensor->getStride();
    IT_ASSERT(strides.size() == (size_t)nDims);
    // Fetch the dims once instead of on every loop iteration.
    auto dims = tensor->getDims();
    IT_ASSERT(dims.size() == (size_t)nDims);
    for (int i = 0; i < nDims; ++i) {
        metadata.dimSize[i] = dims.at(i);
        metadata.stride[i] = strides.at(i);
    }
}
// Builds the kernel metadata from input_v_cache's shape and strides, then
// hands the raw data pointers of all tensors to attention_kvcache_kernel.
// position_id is read as int data; every other tensor as float data.
void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
Tensor input_k, Tensor input_v, Tensor position_id,
Tensor output_matmul) const {
AttentionKVCacheMetadata metadata;
// Only the V cache is inspected; the kernel applies the same strides to the
// K cache, so both caches presumably share one layout — TODO confirm.
initAttentionKVCacheMetadata(metadata, input_v_cache);
attention_kvcache_kernel(input_k_cache->getRawDataPtr<float *>(),
input_v_cache->getRawDataPtr<float *>(),
input_q->getRawDataPtr<float *>(),
input_k->getRawDataPtr<float *>(),
input_v->getRawDataPtr<float *>(),
position_id->getRawDataPtr<int *>(),
output_matmul->getRawDataPtr<float *>(),
class AttentionKVCacheCuda : private AttentionKVCacheCompute,
public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
do_compute(_op->getInputs()[0], _op->getInputs()[1],
_op->getInputs()[2], _op->getInputs()[3],
_op->getInputs()[4], _op->getInputs()[5],
// Registers AttentionKVCacheCuda as the kernel for OpType::AttentionKVCache
// on the CUDA device with Float32 data.
REGISTER_KERNEL(Device::CUDA, OpType::AttentionKVCache, DataType::Float32,
AttentionKVCacheCuda, "AttentionKVCache_CUDA_Float32");
} // namespace infini

View File

@ -0,0 +1,128 @@
#include "cuda/cuda_common.h"
#include "cuda/cuda_attention_kvcache.h"
#define WARP_SIZE 32
#define SEQ_UNIT 64
// Fused attention over a KV cache for a single query position.
// One warp (WARP_SIZE lanes) handles one row, indexed by parallel_idx, and
// each lane works on two adjacent floats, so a row spans 64 floats.
// The sequence is processed in tiles of SEQ_UNIT positions while a running
// maximum, a running sum and a running output are carried between tiles.
__global__ void _attention_kvcache_kernel(float* input_k_cache,
float* input_v_cache,
float* input_q,
float* input_k,
float* input_v,
int* position_id,
float* output_matmul,
AttentionKVCacheMetadata compMeta) {
// lane_id: lane within the warp; group_id: warp within the block;
// parallel_idx: global warp index, i.e. the row this warp works on.
int lane_id = threadIdx.x % WARP_SIZE;
int group_id = threadIdx.x / WARP_SIZE;
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
// Warps past the dimSize[0] * dimSize[1] rows have no work.
// NOTE(review): the statement this condition guards (presumably an early
// `return;`) is missing here — confirm against the full file.
if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
// Per-thread scratch: K/V pairs and scores for one tile, plus the current
// and previous-tile values of the running max, sum and output.
float ptr_V[SEQ_UNIT*2];
float ptr_K[SEQ_UNIT*2];
float ptr_Q[2];
float ptr_P[SEQ_UNIT];
float ptr_O[2];
float ptr_max[1];
float ptr_sum[1];
float ptr_max_last[1];
float ptr_sum_last[1];
float ptr_O_last[2];
// Each lane loads its two query elements; query rows are 64 floats apart.
(float2 &)ptr_Q[0] = (float2 &)input_q[(lane_id * 2) + (parallel_idx * 64)];
// Number of positions to visit: position_id[0] plus one.
int SEQ_LENGTH = position_id[0] + 1;
// Offset of this lane's pair inside row parallel_idx of the caches.
int common_idx = (lane_id * 2) + (parallel_idx * compMeta.stride[1]);
// Walk the sequence one tile of SEQ_UNIT positions at a time.
for (int idx_seq = 0; idx_seq < SEQ_LENGTH; idx_seq += SEQ_UNIT){
// Keep the previous tile's max, sum and output for the rescaling below.
// They are only consumed when idx_seq != 0.
ptr_max_last[0] = ptr_max[0];
ptr_sum_last[0] = ptr_sum[0];
(float2 &)ptr_O_last[0] = (float2 &)ptr_O[0];
// Per position: fetch the K pair, multiply by the Q pair element-wise, then
// sum across the warp to obtain the Q.K score.
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
// Positions before the last one are read from the K cache.
if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
(float2 &)ptr_K[idx_SEQ_UNIT * 2]
= (float2 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
// NOTE(review): the load from input_k and the store into the cache below
// presumably form the else branch (the last position); the `else` itself is
// not visible here — confirm against the full file.
(float2 &)ptr_K[idx_SEQ_UNIT * 2]
= (float2 &) input_k[((lane_id * 2) + parallel_idx * compMeta.stride[2])];
(float2 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
(float2 &)ptr_K[idx_SEQ_UNIT * 2];
ptr_K[idx_SEQ_UNIT * 2] = ptr_Q[0] * ptr_K[idx_SEQ_UNIT * 2];
ptr_K[idx_SEQ_UNIT * 2 + 1] = ptr_Q[1] * ptr_K[idx_SEQ_UNIT * 2 + 1];
// Shuffle-down reduction of the first product; lane 0 ends with the sum.
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
ptr_K[idx_SEQ_UNIT * 2] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 2], offset);
ptr_P[idx_SEQ_UNIT] = ptr_K[idx_SEQ_UNIT * 2];
// Same reduction for the second product, added onto the score.
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2){
ptr_K[((idx_SEQ_UNIT * 2) + 1)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 2) + 1)], offset);
ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 2) + 1)];
// Broadcast lane 0's score to every lane, scale it and track the tile maximum.
// The divisor 8 is presumably sqrt(64) for the 64-wide rows — TODO confirm.
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
ptr_P[idx_SEQ_UNIT] /= 8;
ptr_max[0] = (idx_SEQ_UNIT == 0) ? ptr_P[0] : max(ptr_max[0], ptr_P[idx_SEQ_UNIT]);
// Merge with the maximum carried over from earlier tiles.
ptr_max[0] = (idx_seq == 0) ? ptr_max[0] : max(ptr_max[0], ptr_max_last[0]);
ptr_sum[0] = 0;
// Exponentiate the scores relative to the running max and sum them.
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT] - ptr_max[0]);
ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
// Rescale the earlier tiles' sum to the new max and add it in.
ptr_sum[0] = (idx_seq == 0) ? ptr_sum[0] : expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] + ptr_sum[0];
ptr_O[0] = 0;
ptr_O[1] = 0;
// Accumulate the normalised-score-weighted V pairs. V is fetched the same way
// as K above: cache for earlier positions, input_v plus a cache store otherwise.
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
(float2 &)ptr_V[idx_SEQ_UNIT * 2]
= (float2 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
// NOTE(review): as with K, the `else` that presumably introduces the next
// three statements is not visible here — confirm against the full file.
(float2 &)ptr_V[idx_SEQ_UNIT * 2]
= (float2 &) input_v[((lane_id * 2) + parallel_idx * compMeta.stride[2])];
(float2 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
(float2 &)ptr_V[idx_SEQ_UNIT * 2];
ptr_P[idx_SEQ_UNIT] /= ptr_sum[0];
ptr_O[0] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2)], ptr_O[0]);
ptr_O[1] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2) + 1], ptr_O[1]);
// Fold in the earlier tiles' output, rescaled from the old max/sum to the new.
ptr_O[0] = (idx_seq == 0) ? ptr_O[0] : ptr_O[0] + ptr_O_last[0] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
ptr_O[1] = (idx_seq == 0) ? ptr_O[1] : ptr_O[1] + ptr_O_last[1] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
// Each lane stores its two output elements; output rows are dimSize[3] apart.
(float2 &)output_matmul[(lane_id * 2) + (parallel_idx * compMeta.dimSize[3])] = (float2 &)ptr_O[0];
namespace infini {
void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache, float *input_q, float *input_k,
float *input_v, int *position_id, float *output_matmul,
const AttentionKVCacheMetadata &compMeta) {
IT_ASSERT(compMeta.dimSize[3] == 64);
dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), 1);
dim3 blockDim(BLOCKSIZE, 1);
_attention_kvcache_kernel<<<gridDim, blockDim>>>(
input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, output_matmul, compMeta);
} // namespace infini

View File

@ -0,0 +1,55 @@
#include "operators/attention_kvcache.h"
#include "utils/operator_utils.h"
namespace infini {
// Builds the operator with six inputs, in the order k_cache, v_cache, q, k,
// v, position_id, and a single output tensor.
AttentionKVCacheObj::AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
Tensor input_v_cache, Tensor input_q,
Tensor input_k, Tensor input_v,
Tensor position_id,
Tensor output_matmul)
: OperatorObj(OpType::AttentionKVCache,
TensorVec{input_k_cache, input_v_cache, input_q, input_k,
input_v, position_id},
{output_matmul}) {
// Only the rank of inputs[0] (the K cache) is checked here.
int rank = inputs[0]->getRank();
IT_ASSERT(rank == 4);
// The axis is fixed to 2; presumably the sequence axis of the cache — TODO
// confirm.
dim = 2;
// Output shape inference: the single output takes the dims of inputs[2],
// which is the query tensor in the constructor's input order.
AttentionKVCacheObj::inferShape(const TensorVec &inputs) const {
IT_ASSERT(inputs.size() == 6);
// NOTE(review): the K-cache dims are copied and grown by one along `dim`,
// but `dims` does not take part in the returned value below.
Shape dims = inputs[0]->getDims();
ShapeElem n = dims.at(dim);
dims[dim] = n + 1;
return {{inputs[2]->getDims()}};
std::string AttentionKVCacheObj::toString() const {
std::ostringstream os;
os << "AttentionKVCache[" << getGuid() << "]";
os << "(";
for (auto input : inputs)
os << vecToString(input->getDims()) << ",";
os << "dim=" << dim << ",";
os << "input=";
for (auto input : inputs)
os << input->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
// Workload signature: operator type, dim, number of inputs, then the dims of
// the first output.
vector<int> AttentionKVCacheObj::getWorkloadVector() const {
    const vector<int> outputDims = getOutputs()[0]->getDims();
    vector<int> workload;
    workload.reserve(outputDims.size() + 3);
    workload.push_back(type.underlying());
    workload.push_back(dim);
    workload.push_back((int)inputs.size());
    workload.insert(workload.end(), outputDims.begin(), outputDims.end());
    return workload;
}
// Attribute signature: operator type followed by dim.
vector<int> AttentionKVCacheObj::getOpAttrVector() const {
    vector<int> attrs{type.underlying(), dim};
    return attrs;
}
} // namespace infini