[feature] add fused attention_kvcache operator support (#179)

* [feature] add fused attention_kvcache operator support

* add test to attention_kvcache op

* Add blank line at EOF

---------

Co-authored-by: Haojie Wang <haojie0429@gmail.com>
xiaonans 2023-11-14 23:44:22 +08:00 committed by GitHub
parent f22fa2766e
commit 965df4e294
11 changed files with 375 additions and 0 deletions

@@ -64,6 +64,9 @@ class GraphHandlerObj {
Tensor transpose(Tensor data, Tensor transposed, Shape perm);
Tensor reshape(Tensor data, Tensor reshaped, Shape shape);
Tensor concat(TensorVec inputs, Tensor output, int dim);
Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
Tensor input_q, Tensor input_k, Tensor input_v,
Tensor position_id, Tensor output_matmul);
TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
int num_outputs);
Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);

@@ -25,6 +25,7 @@ struct OpType {
Asinh, // Unary
Atan, // Unary
Atanh, // Unary
AttentionKVCache, // Fusion
AveragePool, // Pool
BatchNormalization, //
Bernoulli, //

@@ -0,0 +1,15 @@
#pragma once
#include <cstdio>
struct AttentionKVCacheMetadata {
int dimSize[4];
int stride[4];
};
namespace infini {
void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
float *input_q, float *input_k, float *input_v,
int *position_id, float *output_matmul,
const AttentionKVCacheMetadata &compMeta);
} // namespace infini

@@ -0,0 +1,43 @@
#pragma once
#include "core/operator.h"
namespace infini {
/**
* @brief Fused Attention with KVCache input operator. All the input and output
* tensors should have the same rank except for the position_id.
*
*/
class AttentionKVCacheObj : public OperatorObj {
int dim;
public:
/**
* @brief Construct a new AttentionKVCache object.
*
* @param graph The computation graph that this operator belongs to.
* @param input_k_cache The k_cache input tensor.
* @param input_v_cache The v_cache input tensor.
* @param input_q The query input tensor.
* @param input_k The key input tensor.
* @param input_v The value input tensor.
* @param position_id The position id of the query.
* @param output_matmul The attention output tensor.
*/
AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
Tensor input_v_cache, Tensor input_q, Tensor input_k,
Tensor input_v, Tensor position_id,
Tensor output_matmul);
OP_CLONE(AttentionKVCacheObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;
int numInputs() const override { return 6; }
int numOutputs() const override { return 1; }
int getDim() const { return dim; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
} // namespace infini
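
For reference (not part of this commit's diff): a minimal sketch of how the new operator is wired into a CUDA graph, mirroring the unit test added at the end of this commit. The shapes are illustrative and the helper function name is invented for the example; the head dimension is fixed at 64 because the CUDA kernel below asserts it.

// Usage sketch only, not part of the diff.
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "operators/attention_kvcache.h"

namespace infini {
void attentionKVCacheUsageSketch() {
    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph g = make_ref<GraphObj>(cudaRuntime);
    // Illustrative 4-D shapes {batch, head, seq, head_dim}.
    auto kCache = g->addTensor({1, 1, 1, 64}, DataType::Float32);
    auto vCache = g->addTensor({1, 1, 1, 64}, DataType::Float32);
    auto q = g->addTensor({1, 1, 1, 64}, DataType::Float32);
    auto k = g->addTensor({1, 1, 1, 64}, DataType::Float32);
    auto v = g->addTensor({1, 1, 1, 64}, DataType::Float32);
    auto positionId = g->addTensor({1, 1}, DataType::UInt32);
    // Passing nullptr as the output lets the graph allocate output_matmul.
    auto op = g->addOp<AttentionKVCacheObj>(kCache, vCache, q, k, v,
                                            positionId, nullptr);
    g->dataMalloc();
    cudaRuntime->run(g);
    auto output = op->getOutput(); // shape follows the query: {1, 1, 1, 64}
}
} // namespace infini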

@@ -46,6 +46,9 @@ class OnnxStub:
model = model_simp
except ValidationError:
pass
except RuntimeError:
pass
self.inputs: Dict[str, backend.Tensor] = {}
self.outputs: Dict[str, backend.Tensor] = {}
self.initializer: Dict[int, TensorProto] = {}
@@ -560,6 +563,16 @@ class OnnxStub:
(attr.i for attr in node.attribute if attr.name == "axis")
),
)
elif node.op_type == "AttentionKVCache":
tensors[node.output[0]] = self.handler.attentionKVCache(
tensors[node.input[0]],
tensors[node.input[1]],
tensors[node.input[2]],
tensors[node.input[3]],
tensors[node.input[4]],
tensors[node.input[5]],
tensors.get(node.output[0]),
)
elif node.op_type == "Split":
for name, tensor in zip(
node.output,

@@ -1,6 +1,7 @@
#include "core/graph_handler.h"
#include "operators/all_gather.h"
#include "operators/all_reduce.h"
#include "operators/attention_kvcache.h"
#include "operators/batch_norm.h" #include "operators/batch_norm.h"
#include "operators/broadcast.h" #include "operators/broadcast.h"
#include "operators/concat.h" #include "operators/concat.h"
@ -239,6 +240,27 @@ Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) {
} }
} }
Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache,
Tensor input_v_cache, Tensor input_q,
Tensor input_k, Tensor input_v,
Tensor position_id,
Tensor output_matmul) {
if (output_matmul) {
g->addOpWithOutputs<AttentionKVCacheObj>(
std::move(input_k_cache), std::move(input_v_cache),
std::move(input_q), std::move(input_k), std::move(input_v),
std::move(position_id), output_matmul);
return {output_matmul};
} else {
return g
->addOp<AttentionKVCacheObj>(
std::move(input_k_cache), std::move(input_v_cache),
std::move(input_q), std::move(input_k), std::move(input_v),
std::move(position_id), output_matmul)
->getOutput();
}
}
TensorVec GraphHandlerObj::split(Tensor input, std::optional<TensorVec> outputs,
int axis, int num_outputs) {
if (outputs) {

@@ -489,6 +489,7 @@ void init_graph_builder(py::module &m) {
.def("depthToSpace", &Handler::depthToSpace, policy::move)
.def("reshape", &Handler::reshape, policy::move)
.def("concat", &Handler::concat, policy::move)
.def("attentionKVCache", &Handler::attentionKVCache, policy::move)
.def("split", &Handler::split, policy::move) .def("split", &Handler::split, policy::move)
.def("gather", &Handler::gather, policy::move) .def("gather", &Handler::gather, policy::move)
.def("gatherElements", &Handler::gatherElements, policy::move) .def("gatherElements", &Handler::gatherElements, policy::move)

@@ -0,0 +1,52 @@
#include "operators/attention_kvcache.h"
#include "cuda/cuda_attention_kvcache.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include <functional>
namespace infini {
class AttentionKVCacheCompute {
void initAttentionKVCacheMetadata(AttentionKVCacheMetadata &metadata,
Tensor tensor) const {
int nDims = tensor->getRank();
auto strides = tensor->getStride();
IT_ASSERT(nDims == 4);
IT_ASSERT(strides.size() == (size_t)nDims);
for (int i = 0; i < nDims; ++i) {
metadata.dimSize[i] = tensor->getDims().at(i);
metadata.stride[i] = strides.at(i);
}
}
public:
void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
Tensor input_k, Tensor input_v, Tensor position_id,
Tensor output_matmul) const {
AttentionKVCacheMetadata metadata;
initAttentionKVCacheMetadata(metadata, input_v_cache);
attention_kvcache_kernel(input_k_cache->getRawDataPtr<float *>(),
input_v_cache->getRawDataPtr<float *>(),
input_q->getRawDataPtr<float *>(),
input_k->getRawDataPtr<float *>(),
input_v->getRawDataPtr<float *>(),
position_id->getRawDataPtr<int *>(),
output_matmul->getRawDataPtr<float *>(),
metadata);
}
};
class AttentionKVCacheCuda : private AttentionKVCacheCompute,
public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
do_compute(_op->getInputs()[0], _op->getInputs()[1],
_op->getInputs()[2], _op->getInputs()[3],
_op->getInputs()[4], _op->getInputs()[5],
_op->getOutputs()[0]);
}
};
REGISTER_KERNEL(Device::CUDA, OpType::AttentionKVCache, DataType::Float32,
AttentionKVCacheCuda, "AttentionKVCache_CUDA_Float32");
} // namespace infini

@@ -0,0 +1,128 @@
#include "cuda/cuda_common.h"
#include "cuda/cuda_attention_kvcache.h"
#define WARP_SIZE 32
#define BLOCKSIZE WARP_SIZE
#define SEQ_UNIT 64
__global__ void _attention_kvcache_kernel(float* input_k_cache,
float* input_v_cache,
float* input_q,
float* input_k,
float* input_v,
int* position_id,
float* output_matmul,
AttentionKVCacheMetadata compMeta) {
int lane_id = threadIdx.x % WARP_SIZE;
int group_id = threadIdx.x / WARP_SIZE;
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
return;
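// Per-thread registers: each lane of the warp owns 2 of the 64 head-dim
// elements; the *_last variables carry the previous chunk's softmax
// statistics so earlier partial results can be rescaled.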
float ptr_V[SEQ_UNIT*2];
float ptr_K[SEQ_UNIT*2];
float ptr_Q[2];
float ptr_P[SEQ_UNIT];
float ptr_O[2];
float ptr_max[1];
float ptr_sum[1];
float ptr_max_last[1];
float ptr_sum_last[1];
float ptr_O_last[2];
(float2 &)ptr_Q[0] = (float2 &)input_q[(lane_id * 2) + (parallel_idx * 64)];
int SEQ_LENGTH = position_id[0] + 1;
int common_idx = (lane_id * 2) + (parallel_idx * compMeta.stride[1]);
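// Walk the sequence in chunks of SEQ_UNIT, maintaining a running (online)
// softmax max and sum across chunks.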
for (int idx_seq = 0; idx_seq < SEQ_LENGTH; idx_seq += SEQ_UNIT){
ptr_max_last[0] = ptr_max[0];
ptr_sum_last[0] = ptr_sum[0];
(float2 &)ptr_O_last[0] = (float2 &)ptr_O[0];
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
(float2 &)ptr_K[idx_SEQ_UNIT * 2]
= (float2 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
}
else{
(float2 &)ptr_K[idx_SEQ_UNIT * 2]
= (float2 &) input_k[((lane_id * 2) + parallel_idx * compMeta.stride[2])];
(float2 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
(float2 &)ptr_K[idx_SEQ_UNIT * 2];
}
ptr_K[idx_SEQ_UNIT * 2] = ptr_Q[0] * ptr_K[idx_SEQ_UNIT * 2];
ptr_K[idx_SEQ_UNIT * 2 + 1] = ptr_Q[1] * ptr_K[idx_SEQ_UNIT * 2 + 1];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
ptr_K[idx_SEQ_UNIT * 2] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 2], offset);
}
ptr_P[idx_SEQ_UNIT] = ptr_K[idx_SEQ_UNIT * 2];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2){
ptr_K[((idx_SEQ_UNIT * 2) + 1)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 2) + 1)], offset);
}
ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 2) + 1)];
}
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
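// Scale the score by 1/sqrt(head_dim) with head_dim = 64.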
ptr_P[idx_SEQ_UNIT] /= 8;
ptr_max[0] = (idx_SEQ_UNIT == 0) ? ptr_P[0] : max(ptr_max[0], ptr_P[idx_SEQ_UNIT]);
}
ptr_max[0] = (idx_seq == 0) ? ptr_max[0] : max(ptr_max[0], ptr_max_last[0]);
ptr_sum[0] = 0;
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT] - ptr_max[0]);
ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
}
ptr_sum[0] = (idx_seq == 0) ? ptr_sum[0] : expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] + ptr_sum[0];
ptr_O[0] = 0;
ptr_O[1] = 0;
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
(float2 &)ptr_V[idx_SEQ_UNIT * 2]
= (float2 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
}
else{
(float2 &)ptr_V[idx_SEQ_UNIT * 2]
= (float2 &) input_v[((lane_id * 2) + parallel_idx * compMeta.stride[2])];
(float2 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
(float2 &)ptr_V[idx_SEQ_UNIT * 2];
}
ptr_P[idx_SEQ_UNIT] /= ptr_sum[0];
ptr_O[0] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2)], ptr_O[0]);
ptr_O[1] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2) + 1], ptr_O[1]);
}
ptr_O[0] = (idx_seq == 0) ? ptr_O[0] : ptr_O[0] + ptr_O_last[0] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
ptr_O[1] = (idx_seq == 0) ? ptr_O[1] : ptr_O[1] + ptr_O_last[1] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
}
(float2 &)output_matmul[(lane_id * 2) + (parallel_idx * compMeta.dimSize[3])] = (float2 &)ptr_O[0];
}
namespace infini {
void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache, float *input_q, float *input_k,
float *input_v, int *position_id, float *output_matmul,
const AttentionKVCacheMetadata &compMeta) {
IT_ASSERT(compMeta.dimSize[3] == 64);
dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), 1);
dim3 blockDim(BLOCKSIZE, 1);
_attention_kvcache_kernel<<<gridDim, blockDim>>>(
input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, output_matmul, compMeta);
}
} // namespace infini
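
A worked example of the launch arithmetic above (illustrative only, not part of the diff): since BLOCKSIZE equals WARP_SIZE, every block holds a single warp, each warp serves one (batch, head) pair, and each of the 32 lanes covers two of the 64 head-dim elements. The dimension meanings below are an assumption made for the example.

// Standalone sketch reproducing the grid/block computation of the wrapper above.
#include <cstdio>

int main() {
    const int WARP_SIZE = 32;
    const int BLOCKSIZE = WARP_SIZE;   // one warp per block
    int dimSize[4] = {1, 32, 128, 64}; // assumed {batch, heads, seq, head_dim}
    int gridX = dimSize[0] * dimSize[1] / (BLOCKSIZE / WARP_SIZE); // = 32
    std::printf("gridDim=(%d,1), blockDim=(%d,1)\n", gridX, BLOCKSIZE);
    return 0;
}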

@@ -0,0 +1,55 @@
#include "operators/attention_kvcache.h"
#include "utils/operator_utils.h"
namespace infini {
AttentionKVCacheObj::AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
Tensor input_v_cache, Tensor input_q,
Tensor input_k, Tensor input_v,
Tensor position_id,
Tensor output_matmul)
: OperatorObj(OpType::AttentionKVCache,
TensorVec{input_k_cache, input_v_cache, input_q, input_k,
input_v, position_id},
{output_matmul}) {
int rank = inputs[0]->getRank();
IT_ASSERT(rank == 4);
dim = 2;
IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>>
AttentionKVCacheObj::inferShape(const TensorVec &inputs) const {
IT_ASSERT(inputs.size() == 6);
Shape dims = inputs[0]->getDims();
ShapeElem n = dims.at(dim);
dims[dim] = n + 1;
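// The returned output shape follows the query tensor (inputs[2]).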
return {{inputs[2]->getDims()}};
}
std::string AttentionKVCacheObj::toString() const {
std::ostringstream os;
os << "AttentionKVCache[" << getGuid() << "]";
os << "(";
for (auto input : inputs)
os << vecToString(input->getDims()) << ",";
os << "dim=" << dim << ",";
os << "input=";
for (auto input : inputs)
os << input->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
vector<int> AttentionKVCacheObj::getWorkloadVector() const {
vector<int> ret = getOutputs()[0]->getDims();
ret.emplace(ret.begin(), (int)inputs.size());
ret.emplace(ret.begin(), dim);
ret.emplace(ret.begin(), type.underlying());
return ret;
}
vector<int> AttentionKVCacheObj::getOpAttrVector() const {
return {type.underlying(), dim};
}
} // namespace infini

@@ -0,0 +1,42 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/attention_kvcache.h"
#include "test.h"
namespace infini {
TEST(AttentionKVCache, Cuda) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime);
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto input_k_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
auto input_v_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
auto input_q_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
auto input_k_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
auto input_v_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32);
auto op = gCuda->addOp<AttentionKVCacheObj>(
input_k_cache_d, input_v_cache_d, input_q_d, input_k_d, input_v_d,
position_id_d, nullptr);
gCuda->dataMalloc();
input_q_d->setData(OneGenerator());
input_k_d->setData(OneGenerator());
input_v_d->setData(OneGenerator());
position_id_d->setData(IncrementalGenerator());
cudaRuntime->run(gCuda);
auto oCpu = gCpu->cloneTensor(op->getOutput());
EXPECT_TRUE(oCpu->equalData(vector<float>{
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}));
}
} // namespace infini