From 2436ccb8684b11e5592111d1789b04c9f30f8843 Mon Sep 17 00:00:00 2001
From: xiaonans
Date: Fri, 10 Nov 2023 10:51:44 +0800
Subject: [PATCH] [feature] add fused attention_kvcache operator support

---
 include/core/graph_handler.h              |   3 +
 include/core/op_type.h                    |   1 +
 include/cuda/cuda_attention_kvcache.h     |  15 +++
 include/operators/attention_kvcache.h     |  43 ++++++++
 pyinfinitensor/src/pyinfinitensor/onnx.py |  13 +++
 src/core/graph_handler.cc                 |  22 ++++
 src/ffi/ffi_infinitensor.cc               |   1 +
 src/kernels/cuda/attention_kvcache.cc     |  51 +++++++++
 src/kernels/cuda/attention_kvcache.cu     | 128 ++++++++++++++++++++++
 src/operators/attention_kvcache.cc        |  55 ++++++++++
 10 files changed, 332 insertions(+)
 create mode 100644 include/cuda/cuda_attention_kvcache.h
 create mode 100644 include/operators/attention_kvcache.h
 create mode 100644 src/kernels/cuda/attention_kvcache.cc
 create mode 100644 src/kernels/cuda/attention_kvcache.cu
 create mode 100644 src/operators/attention_kvcache.cc

diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index 87e909f8..0b7727e3 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -64,6 +64,9 @@ class GraphHandlerObj {
     Tensor transpose(Tensor data, Tensor transposed, Shape perm);
     Tensor reshape(Tensor data, Tensor reshaped, Shape shape);
     Tensor concat(TensorVec inputs, Tensor output, int dim);
+    Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
+                            Tensor input_q, Tensor input_k, Tensor input_v,
+                            Tensor position_id, Tensor output_matmul);
     TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
                     int num_outputs);
     Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);
diff --git a/include/core/op_type.h b/include/core/op_type.h
index ad2e6acb..91a0b99a 100644
--- a/include/core/op_type.h
+++ b/include/core/op_type.h
@@ -25,6 +25,7 @@ struct OpType {
     Asinh,              // Unary
     Atan,               // Unary
     Atanh,              // Unary
+    AttentionKVCache,   // Fusion
     AveragePool,        // Pool
     BatchNormalization, //
     Bernoulli,          //
diff --git a/include/cuda/cuda_attention_kvcache.h b/include/cuda/cuda_attention_kvcache.h
new file mode 100644
index 00000000..08774b47
--- /dev/null
+++ b/include/cuda/cuda_attention_kvcache.h
@@ -0,0 +1,15 @@
+#pragma once
+#include
+
+struct AttentionKVCacheMetadata {
+    int dimSize[4];
+    int stride[4];
+};
+
+namespace infini {
+void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
+                              float *input_q, float *input_k, float *input_v,
+                              int *position_id, float *output_matmul,
+                              const AttentionKVCacheMetadata &compMeta);
+
+} // namespace infini
\ No newline at end of file
diff --git a/include/operators/attention_kvcache.h b/include/operators/attention_kvcache.h
new file mode 100644
index 00000000..98457c2c
--- /dev/null
+++ b/include/operators/attention_kvcache.h
@@ -0,0 +1,43 @@
+#pragma once
+#include "core/operator.h"
+
+namespace infini {
+/**
+ * @brief Fused Attention with KVCache input operator. All the input and output
+ * tensors should have the same rank except for the position_id.
+ *
+ */
+class AttentionKVCacheObj : public OperatorObj {
+    int dim;
+
+  public:
+    /**
+     * @brief Construct a new AttentionKVCache object.
+     *
+     * @param graph The computation graph that this operator belongs to.
+     * @param input_k_cache The k_cache input tensor.
+     * @param input_v_cache The v_cache input tensor.
+     * @param input_q The query input tensor.
+     * @param input_k The key input tensor.
+     * @param input_v The value input tensor.
+     * @param position_id The position id of the query.
+     * @param output_matmul The query output tensor.
+     */
+    AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
+                        Tensor input_v_cache, Tensor input_q, Tensor input_k,
+                        Tensor input_v, Tensor position_id,
+                        Tensor output_matmul);
+    OP_CLONE(AttentionKVCacheObj);
+
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
+
+    std::string toString() const override;
+    int numInputs() const override { return 6; }
+    int numOutputs() const override { return 1; }
+    int getDim() const { return dim; }
+
+  private:
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+};
+} // namespace infini
\ No newline at end of file
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index 6d0da9f8..52ad7c8c 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -46,6 +46,9 @@ class OnnxStub:
                 model = model_simp
         except ValidationError:
             pass
+        except RuntimeError:
+            pass
+
         self.inputs: Dict[str, backend.Tensor] = {}
         self.outputs: Dict[str, backend.Tensor] = {}
         self.initializer: Dict[int, TensorProto] = {}
@@ -545,6 +548,16 @@ class OnnxStub:
                         (attr.i for attr in node.attribute if attr.name == "axis")
                     ),
                 )
+            elif node.op_type == "AttentionKVCache":
+                tensors[node.output[0]] = self.handler.attentionKVCache(
+                    tensors[node.input[0]],
+                    tensors[node.input[1]],
+                    tensors[node.input[2]],
+                    tensors[node.input[3]],
+                    tensors[node.input[4]],
+                    tensors[node.input[5]],
+                    tensors.get(node.output[0]),
+                )
             elif node.op_type == "Split":
                 for name, tensor in zip(
                     node.output,
diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc
index 77fbcf2d..6e8cc2f0 100644
--- a/src/core/graph_handler.cc
+++ b/src/core/graph_handler.cc
@@ -1,6 +1,7 @@
 #include "core/graph_handler.h"
 #include "operators/all_gather.h"
 #include "operators/all_reduce.h"
+#include "operators/attention_kvcache.h"
 #include "operators/batch_norm.h"
 #include "operators/broadcast.h"
 #include "operators/concat.h"
@@ -239,6 +240,27 @@ Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) {
     }
 }
 
+Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache,
+                                         Tensor input_v_cache, Tensor input_q,
+                                         Tensor input_k, Tensor input_v,
+                                         Tensor position_id,
+                                         Tensor output_matmul) {
+    if (output_matmul) {
+        g->addOpWithOutputs<AttentionKVCacheObj>(
+            std::move(input_k_cache), std::move(input_v_cache),
+            std::move(input_q), std::move(input_k), std::move(input_v),
+            std::move(position_id), output_matmul);
+        return {output_matmul};
+    } else {
+        return g
+            ->addOp<AttentionKVCacheObj>(
+                std::move(input_k_cache), std::move(input_v_cache),
+                std::move(input_q), std::move(input_k), std::move(input_v),
+                std::move(position_id), output_matmul)
+            ->getOutput();
+    }
+}
+
 TensorVec GraphHandlerObj::split(Tensor input, std::optional<TensorVec> outputs,
                                  int axis, int num_outputs) {
     if (outputs) {
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index e1a726c3..77fd4c0f 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -479,6 +479,7 @@ void init_graph_builder(py::module &m) {
         .def("transpose", &Handler::transpose, policy::move)
         .def("reshape", &Handler::reshape, policy::move)
         .def("concat", &Handler::concat, policy::move)
+        .def("attentionKVCache", &Handler::attentionKVCache, policy::move)
         .def("split", &Handler::split, policy::move)
         .def("gather", &Handler::gather, policy::move)
         .def("gatherElements", &Handler::gatherElements, policy::move)
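[Reviewer note] For anyone exercising the front-end path added above: the new onnx.py branch dispatches on a node whose op_type is "AttentionKVCache", with six inputs (k_cache, v_cache, q, k, v, position_id) and one output. Below is a minimal sketch of building such a model with the standard onnx.helper API. The head dimension of 64, float32 data, and int32 position ids follow what the CUDA kernel later in this patch expects; the concrete batch/sequence sizes and the commented-out runtime wiring are illustrative assumptions, not something fixed by this patch.

# Sketch only: shapes and the commented-out runtime wiring are assumptions.
from onnx import TensorProto, helper

seq_len, head_dim = 4, 64  # the CUDA kernel in this patch asserts a head dimension of 64

def f32(name, shape):
    return helper.make_tensor_value_info(name, TensorProto.FLOAT, shape)

node = helper.make_node(
    "AttentionKVCache",  # op_type matched by the new branch in onnx.py
    ["k_cache", "v_cache", "q", "k", "v", "position_id"],
    ["out"],
)
graph = helper.make_graph(
    [node],
    "attention_kvcache_demo",
    [
        f32("k_cache", [1, 1, seq_len, head_dim]),
        f32("v_cache", [1, 1, seq_len, head_dim]),
        f32("q", [1, 1, 1, head_dim]),
        f32("k", [1, 1, 1, head_dim]),
        f32("v", [1, 1, 1, head_dim]),
        helper.make_tensor_value_info("position_id", TensorProto.INT32, [1, 1]),
    ],
    [f32("out", [1, 1, 1, head_dim])],
)
model = helper.make_model(graph)  # custom op type, so onnx.checker is not expected to pass

# Hypothetical wiring through the Python front end:
# from pyinfinitensor.onnx import OnnxStub, backend
# stub = OnnxStub(model, backend.cuda_runtime())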
diff --git a/src/kernels/cuda/attention_kvcache.cc b/src/kernels/cuda/attention_kvcache.cc
new file mode 100644
index 00000000..aef867e0
--- /dev/null
+++ b/src/kernels/cuda/attention_kvcache.cc
@@ -0,0 +1,51 @@
+#include "operators/attention_kvcache.h"
+#include "cuda/cuda_attention_kvcache.h"
+#include "cuda/cuda_kernel_wihtout_config.h"
+#include
+
+namespace infini {
+
+class AttentionKVCacheCompute {
+    void initAttentionKVCacheMetadata(AttentionKVCacheMetadata &metadata,
+                                      Tensor tensor) const {
+        int nDims = tensor->getRank();
+        auto strides = tensor->getStride();
+        IT_ASSERT(strides.size() == (size_t)nDims);
+        for (int i = 0; i < nDims; ++i) {
+            metadata.dimSize[i] = tensor->getDims().at(i);
+            metadata.stride[i] = strides.at(i);
+        }
+    }
+
+  public:
+    void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
+                    Tensor input_k, Tensor input_v, Tensor position_id,
+                    Tensor output_matmul) const {
+        AttentionKVCacheMetadata metadata;
+        initAttentionKVCacheMetadata(metadata, input_v_cache);
+
+        attention_kvcache_kernel(input_k_cache->getRawDataPtr<float *>(),
+                                 input_v_cache->getRawDataPtr<float *>(),
+                                 input_q->getRawDataPtr<float *>(),
+                                 input_k->getRawDataPtr<float *>(),
+                                 input_v->getRawDataPtr<float *>(),
+                                 position_id->getRawDataPtr<int *>(),
+                                 output_matmul->getRawDataPtr<float *>(),
+                                 metadata);
+    }
+};
+
+class AttentionKVCacheCuda : private AttentionKVCacheCompute,
+                             public CudaKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        do_compute(_op->getInputs()[0], _op->getInputs()[1],
+                   _op->getInputs()[2], _op->getInputs()[3],
+                   _op->getInputs()[4], _op->getInputs()[5],
+                   _op->getOutputs()[0]);
+    }
+};
+
+REGISTER_KERNEL(Device::CUDA, OpType::AttentionKVCache, DataType::Float32,
+                AttentionKVCacheCuda, "AttentionKVCache_CUDA_Float32");
+} // namespace infini
\ No newline at end of file
diff --git a/src/kernels/cuda/attention_kvcache.cu b/src/kernels/cuda/attention_kvcache.cu
new file mode 100644
index 00000000..1a2b1ab4
--- /dev/null
+++ b/src/kernels/cuda/attention_kvcache.cu
@@ -0,0 +1,128 @@
+#include "cuda/cuda_common.h"
+#include "cuda/cuda_attention_kvcache.h"
+#define WARP_SIZE 32
+#define BLOCKSIZE WARP_SIZE
+#define SEQ_UNIT 64
+
+__global__ void _attention_kvcache_kernel(float* input_k_cache,
+                                          float* input_v_cache,
+                                          float* input_q,
+                                          float* input_k,
+                                          float* input_v,
+                                          int* position_id,
+                                          float* output_matmul,
+                                          AttentionKVCacheMetadata compMeta) {
+    int lane_id = threadIdx.x % WARP_SIZE;
+    int group_id = threadIdx.x / WARP_SIZE;
+    int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
+
+    if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
+        return;
+
+    float ptr_V[SEQ_UNIT*2];
+    float ptr_K[SEQ_UNIT*2];
+    float ptr_Q[2];
+    float ptr_P[SEQ_UNIT];
+
+    float ptr_O[2];
+    float ptr_max[1];
+    float ptr_sum[1];
+
+    float ptr_max_last[1];
+    float ptr_sum_last[1];
+    float ptr_O_last[2];
+
+    (float2 &)ptr_Q[0] = (float2 &)input_q[(lane_id * 2) + (parallel_idx * 64)];
+
+    int SEQ_LENGTH = position_id[0] + 1;
+
+    int common_idx = (lane_id * 2) + (parallel_idx * compMeta.stride[1]);
+
+
+    for (int idx_seq = 0; idx_seq < SEQ_LENGTH; idx_seq += SEQ_UNIT){
+        ptr_max_last[0] = ptr_max[0];
+        ptr_sum_last[0] = ptr_sum[0];
+        (float2 &)ptr_O_last[0] = (float2 &)ptr_O[0];
+
+        #pragma unroll
+        for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
+            if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
+                (float2 &)ptr_K[idx_SEQ_UNIT * 2]
+                    = (float2 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
+            }
+            else{
+                (float2 &)ptr_K[idx_SEQ_UNIT * 2]
+                    = (float2 &) input_k[((lane_id * 2) + parallel_idx * compMeta.stride[2])];
+                (float2 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
+                    (float2 &)ptr_K[idx_SEQ_UNIT * 2];
+            }
+            ptr_K[idx_SEQ_UNIT * 2] = ptr_Q[0] * ptr_K[idx_SEQ_UNIT * 2];
+            ptr_K[idx_SEQ_UNIT * 2 + 1] = ptr_Q[1] * ptr_K[idx_SEQ_UNIT * 2 + 1];
+
+            #pragma unroll
+            for (int offset = 16; offset > 0; offset /= 2) {
+                ptr_K[idx_SEQ_UNIT * 2] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 2], offset);
+            }
+            ptr_P[idx_SEQ_UNIT] = ptr_K[idx_SEQ_UNIT * 2];
+            #pragma unroll
+            for (int offset = 16; offset > 0; offset /= 2){
+                ptr_K[((idx_SEQ_UNIT * 2) + 1)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 2) + 1)], offset);
+            }
+            ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 2) + 1)];
+        }
+
+        #pragma unroll
+        for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
+            ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
+            ptr_P[idx_SEQ_UNIT] /= 8;
+            ptr_max[0] = (idx_SEQ_UNIT == 0) ? ptr_P[0] : max(ptr_max[0], ptr_P[idx_SEQ_UNIT]);
+        }
+        ptr_max[0] = (idx_seq == 0) ? ptr_max[0] : max(ptr_max[0], ptr_max_last[0]);
+
+        ptr_sum[0] = 0;
+        #pragma unroll
+        for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
+            ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT] - ptr_max[0]);
+            ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
+        }
+        ptr_sum[0] = (idx_seq == 0) ? ptr_sum[0] : expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] + ptr_sum[0];
+
+        ptr_O[0] = 0;
+        ptr_O[1] = 0;
+        #pragma unroll
+        for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
+            if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
+                (float2 &)ptr_V[idx_SEQ_UNIT * 2]
+                    = (float2 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
+            }
+            else{
+                (float2 &)ptr_V[idx_SEQ_UNIT * 2]
+                    = (float2 &) input_v[((lane_id * 2) + parallel_idx * compMeta.stride[2])];
+                (float2 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
+                    (float2 &)ptr_V[idx_SEQ_UNIT * 2];
+            }
+
+            ptr_P[idx_SEQ_UNIT] /= ptr_sum[0];
+
+            ptr_O[0] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2)], ptr_O[0]);
+            ptr_O[1] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2) + 1], ptr_O[1]);
+        }
+        ptr_O[0] = (idx_seq == 0) ? ptr_O[0] : ptr_O[0] + ptr_O_last[0] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
+        ptr_O[1] = (idx_seq == 0) ? ptr_O[1] : ptr_O[1] + ptr_O_last[1] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
+    }
+    (float2 &)output_matmul[(lane_id * 2) + (parallel_idx * compMeta.dimSize[3])] = (float2 &)ptr_O[0];
+}
+
+namespace infini {
+void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache, float *input_q, float *input_k,
+                              float *input_v, int *position_id, float *output_matmul,
+                              const AttentionKVCacheMetadata &compMeta) {
+    IT_ASSERT(compMeta.dimSize[3] == 64);
+    dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), 1);
+    dim3 blockDim(BLOCKSIZE, 1);
+
+    _attention_kvcache_kernel<<<gridDim, blockDim>>>(
+        input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, output_matmul, compMeta);
+}
+
+} // namespace infini
\ No newline at end of file
diff --git a/src/operators/attention_kvcache.cc b/src/operators/attention_kvcache.cc
new file mode 100644
index 00000000..20e3ecfb
--- /dev/null
+++ b/src/operators/attention_kvcache.cc
@@ -0,0 +1,55 @@
+#include "operators/attention_kvcache.h"
+#include "utils/operator_utils.h"
+
+namespace infini {
+AttentionKVCacheObj::AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
+                                         Tensor input_v_cache, Tensor input_q,
+                                         Tensor input_k, Tensor input_v,
+                                         Tensor position_id,
+                                         Tensor output_matmul)
+    : OperatorObj(OpType::AttentionKVCache,
+                  TensorVec{input_k_cache, input_v_cache, input_q, input_k,
+                            input_v, position_id},
+                  {output_matmul}) {
+    int rank = inputs[0]->getRank();
+    IT_ASSERT(rank == 4);
+    dim = 2;
+    IT_ASSERT(checkValid(graph));
+}
+
+optional<vector<Shape>>
+AttentionKVCacheObj::inferShape(const TensorVec &inputs) const {
+    IT_ASSERT(inputs.size() == 6);
+    Shape dims = inputs[0]->getDims();
+    ShapeElem n = dims.at(dim);
+    dims[dim] = n + 1;
+    return {{inputs[2]->getDims()}};
+}
+
+std::string AttentionKVCacheObj::toString() const {
+    std::ostringstream os;
+    os << "AttentionKVCache[" << getGuid() << "]";
+    os << "(";
+    for (auto input : inputs)
+        os << vecToString(input->getDims()) << ",";
+    os << "dim=" << dim << ",";
+    os << "input=";
+    for (auto input : inputs)
+        os << input->getGuid() << ",";
+    os << "output=" << outputs[0]->getGuid() << ")";
+    return os.str();
+}
+
+vector<int> AttentionKVCacheObj::getWorkloadVector() const {
+    vector<int> ret = getOutputs()[0]->getDims();
+    ret.emplace(ret.begin(), (int)inputs.size());
+    ret.emplace(ret.begin(), dim);
+    ret.emplace(ret.begin(), type.underlying());
+    return ret;
+}
+
+vector<int> AttentionKVCacheObj::getOpAttrVector() const {
+    return {type.underlying(), dim};
+}
+
+} // namespace infini
\ No newline at end of file
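[Reviewer note] On the CUDA kernel above: each warp handles one of the dimSize[0] * dimSize[1] (batch, head) slices, each lane owns two of the 64 head elements, and the sequence is walked in SEQ_UNIT chunks while a running maximum, normalizer, and output are carried across chunks (the ptr_max_last, ptr_sum_last, and ptr_O_last rescaling); the newest key/value pair is appended into the cache at the last position. Written out, the cross-chunk update the code appears to implement is the usual online-softmax recurrence, with per-position scores s_j = (q . k_j) / 8 (the 1/8 matching 1/sqrt(64)):

    m' = \max\bigl(m, \max_j s_j\bigr)
    \ell' = e^{m - m'} \ell + \sum_j e^{s_j - m'}
    o' = \frac{e^{m - m'} \ell \, o + \sum_j e^{s_j - m'} v_j}{\ell'}

so after the final chunk o' equals softmax(q K^T / sqrt(64)) V over the full, freshly updated cache, which is what is written to output_matmul.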