diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index 61826893..c91c4901 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -64,6 +64,9 @@ class GraphHandlerObj {
     Tensor transpose(Tensor data, Tensor transposed, Shape perm);
     Tensor reshape(Tensor data, Tensor reshaped, Shape shape);
     Tensor concat(TensorVec inputs, Tensor output, int dim);
+    Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
+                            Tensor input_q, Tensor input_k, Tensor input_v,
+                            Tensor position_id, Tensor output_matmul);
     TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
                     int num_outputs);
     Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);
diff --git a/include/core/op_type.h b/include/core/op_type.h
index ad2e6acb..91a0b99a 100644
--- a/include/core/op_type.h
+++ b/include/core/op_type.h
@@ -25,6 +25,7 @@ struct OpType {
         Asinh,              // Unary
         Atan,               // Unary
         Atanh,              // Unary
+        AttentionKVCache,   // Fusion
         AveragePool,        // Pool
         BatchNormalization, //
         Bernoulli,          //
diff --git a/include/cuda/cuda_attention_kvcache.h b/include/cuda/cuda_attention_kvcache.h
new file mode 100644
index 00000000..880a814f
--- /dev/null
+++ b/include/cuda/cuda_attention_kvcache.h
@@ -0,0 +1,15 @@
+#pragma once
+#include
+
+struct AttentionKVCacheMetadata {
+    int dimSize[4];
+    int stride[4];
+};
+
+namespace infini {
+void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
+                              float *input_q, float *input_k, float *input_v,
+                              int *position_id, float *output_matmul,
+                              const AttentionKVCacheMetadata &compMeta);
+
+} // namespace infini
diff --git a/include/operators/attention_kvcache.h b/include/operators/attention_kvcache.h
new file mode 100644
index 00000000..f319eb6c
--- /dev/null
+++ b/include/operators/attention_kvcache.h
@@ -0,0 +1,43 @@
+#pragma once
+#include "core/operator.h"
+
+namespace infini {
+/**
+ * @brief Fused attention operator with KV-cache inputs. All input and output
+ * tensors should have the same rank, except for position_id.
+ *
+ */
+class AttentionKVCacheObj : public OperatorObj {
+    int dim;
+
+  public:
+    /**
+     * @brief Construct a new AttentionKVCache object.
+     *
+     * @param graph The computation graph that this operator belongs to.
+     * @param input_k_cache The k_cache input tensor.
+     * @param input_v_cache The v_cache input tensor.
+     * @param input_q The query input tensor.
+     * @param input_k The key input tensor.
+     * @param input_v The value input tensor.
+     * @param position_id The position id of the query.
+     * @param output_matmul The output tensor.
+     */
+    AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
+                        Tensor input_v_cache, Tensor input_q, Tensor input_k,
+                        Tensor input_v, Tensor position_id,
+                        Tensor output_matmul);
+    OP_CLONE(AttentionKVCacheObj);
+
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
+
+    std::string toString() const override;
+    int numInputs() const override { return 6; }
+    int numOutputs() const override { return 1; }
+    int getDim() const { return dim; }
+
+  private:
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+};
+} // namespace infini
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index cc5498f9..7360002a 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -46,6 +46,9 @@ class OnnxStub:
                 model = model_simp
         except ValidationError:
             pass
+        except RuntimeError:
+            pass
+
         self.inputs: Dict[str, backend.Tensor] = {}
         self.outputs: Dict[str, backend.Tensor] = {}
         self.initializer: Dict[int, TensorProto] = {}
@@ -560,6 +563,16 @@ class OnnxStub:
                         (attr.i for attr in node.attribute if attr.name == "axis")
                     ),
                 )
+            elif node.op_type == "AttentionKVCache":
+                tensors[node.output[0]] = self.handler.attentionKVCache(
+                    tensors[node.input[0]],
+                    tensors[node.input[1]],
+                    tensors[node.input[2]],
+                    tensors[node.input[3]],
+                    tensors[node.input[4]],
+                    tensors[node.input[5]],
+                    tensors.get(node.output[0]),
+                )
             elif node.op_type == "Split":
                 for name, tensor in zip(
                     node.output,
diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc
index ddf53884..32b99b63 100644
--- a/src/core/graph_handler.cc
+++ b/src/core/graph_handler.cc
@@ -1,6 +1,7 @@
 #include "core/graph_handler.h"
 #include "operators/all_gather.h"
 #include "operators/all_reduce.h"
+#include "operators/attention_kvcache.h"
 #include "operators/batch_norm.h"
 #include "operators/broadcast.h"
 #include "operators/concat.h"
@@ -239,6 +240,27 @@ Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) {
     }
 }
 
+Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache,
+                                         Tensor input_v_cache, Tensor input_q,
+                                         Tensor input_k, Tensor input_v,
+                                         Tensor position_id,
+                                         Tensor output_matmul) {
+    if (output_matmul) {
+        g->addOpWithOutputs<AttentionKVCacheObj>(
+            std::move(input_k_cache), std::move(input_v_cache),
+            std::move(input_q), std::move(input_k), std::move(input_v),
+            std::move(position_id), output_matmul);
+        return {output_matmul};
+    } else {
+        return g
+            ->addOp<AttentionKVCacheObj>(
+                std::move(input_k_cache), std::move(input_v_cache),
+                std::move(input_q), std::move(input_k), std::move(input_v),
+                std::move(position_id), output_matmul)
+            ->getOutput();
+    }
+}
+
 TensorVec GraphHandlerObj::split(Tensor input, std::optional<TensorVec> outputs,
                                  int axis, int num_outputs) {
     if (outputs) {
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index 3612269e..ca427dab 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -489,6 +489,7 @@ void init_graph_builder(py::module &m) {
         .def("depthToSpace", &Handler::depthToSpace, policy::move)
         .def("reshape", &Handler::reshape, policy::move)
         .def("concat", &Handler::concat, policy::move)
+        .def("attentionKVCache", &Handler::attentionKVCache, policy::move)
         .def("split", &Handler::split, policy::move)
         .def("gather", &Handler::gather, policy::move)
        .def("gatherElements", &Handler::gatherElements, policy::move)
diff --git a/src/kernels/cuda/attention_kvcache.cc b/src/kernels/cuda/attention_kvcache.cc
new file mode 100644
index 00000000..0d21603a
--- /dev/null
+++ b/src/kernels/cuda/attention_kvcache.cc
@@ -0,0 +1,52 @@
+#include "operators/attention_kvcache.h"
+#include "cuda/cuda_attention_kvcache.h"
+#include "cuda/cuda_kernel_wihtout_config.h"
+#include
+
+namespace infini {
+
+class AttentionKVCacheCompute {
+    void initAttentionKVCacheMetadata(AttentionKVCacheMetadata &metadata,
+                                      Tensor tensor) const {
+        int nDims = tensor->getRank();
+        auto strides = tensor->getStride();
+        IT_ASSERT(nDims == 4);
+        IT_ASSERT(strides.size() == (size_t)nDims);
+        for (int i = 0; i < nDims; ++i) {
+            metadata.dimSize[i] = tensor->getDims().at(i);
+            metadata.stride[i] = strides.at(i);
+        }
+    }
+
+  public:
+    void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
+                    Tensor input_k, Tensor input_v, Tensor position_id,
+                    Tensor output_matmul) const {
+        AttentionKVCacheMetadata metadata;
+        initAttentionKVCacheMetadata(metadata, input_v_cache);
+
+        attention_kvcache_kernel(input_k_cache->getRawDataPtr<float *>(),
+                                 input_v_cache->getRawDataPtr<float *>(),
+                                 input_q->getRawDataPtr<float *>(),
+                                 input_k->getRawDataPtr<float *>(),
+                                 input_v->getRawDataPtr<float *>(),
+                                 position_id->getRawDataPtr<int *>(),
+                                 output_matmul->getRawDataPtr<float *>(),
+                                 metadata);
+    }
+};
+
+class AttentionKVCacheCuda : private AttentionKVCacheCompute,
+                             public CudaKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        do_compute(_op->getInputs()[0], _op->getInputs()[1],
+                   _op->getInputs()[2], _op->getInputs()[3],
+                   _op->getInputs()[4], _op->getInputs()[5],
+                   _op->getOutputs()[0]);
+    }
+};
+
+REGISTER_KERNEL(Device::CUDA, OpType::AttentionKVCache, DataType::Float32,
+                AttentionKVCacheCuda, "AttentionKVCache_CUDA_Float32");
+} // namespace infini
diff --git a/src/kernels/cuda/attention_kvcache.cu b/src/kernels/cuda/attention_kvcache.cu
new file mode 100644
index 00000000..ece6659f
--- /dev/null
+++ b/src/kernels/cuda/attention_kvcache.cu
@@ -0,0 +1,128 @@
+#include "cuda/cuda_common.h"
+#include "cuda/cuda_attention_kvcache.h"
+#define WARP_SIZE 32
+#define BLOCKSIZE WARP_SIZE
+#define SEQ_UNIT 64
+
+__global__ void _attention_kvcache_kernel(float* input_k_cache,
+                                          float* input_v_cache,
+                                          float* input_q,
+                                          float* input_k,
+                                          float* input_v,
+                                          int* position_id,
+                                          float* output_matmul,
+                                          AttentionKVCacheMetadata compMeta) {
+    int lane_id = threadIdx.x % WARP_SIZE;
+    int group_id = threadIdx.x / WARP_SIZE;
+    int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
+
+    if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
+        return;
+
+    float ptr_V[SEQ_UNIT*2];
+    float ptr_K[SEQ_UNIT*2];
+    float ptr_Q[2];
+    float ptr_P[SEQ_UNIT];
+
+    float ptr_O[2];
+    float ptr_max[1];
+    float ptr_sum[1];
+
+    float ptr_max_last[1];
+    float ptr_sum_last[1];
+    float ptr_O_last[2];
+
+    (float2 &)ptr_Q[0] = (float2 &)input_q[(lane_id * 2) + (parallel_idx * 64)];
+
+    int SEQ_LENGTH = position_id[0] + 1;
+
+    int common_idx = (lane_id * 2) + (parallel_idx * compMeta.stride[1]);
+
+    for (int idx_seq = 0; idx_seq < SEQ_LENGTH; idx_seq += SEQ_UNIT){
+        ptr_max_last[0] = ptr_max[0];
+        ptr_sum_last[0] = ptr_sum[0];
+        (float2 &)ptr_O_last[0] = (float2 &)ptr_O[0];
+
+        #pragma unroll
+        for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
+            if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
+                (float2 &)ptr_K[idx_SEQ_UNIT * 2]
+                    = (float2 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
+            }
+            else{
+                (float2 &)ptr_K[idx_SEQ_UNIT * 2]
+                    = (float2 &) input_k[((lane_id * 2) + parallel_idx * compMeta.stride[2])];
+                (float2 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
+                    (float2 &)ptr_K[idx_SEQ_UNIT * 2];
+            }
+            ptr_K[idx_SEQ_UNIT * 2] = ptr_Q[0] * ptr_K[idx_SEQ_UNIT * 2];
+            ptr_K[idx_SEQ_UNIT * 2 + 1] = ptr_Q[1] * ptr_K[idx_SEQ_UNIT * 2 + 1];
+
+            #pragma unroll
+            for (int offset = 16; offset > 0; offset /= 2) {
+                ptr_K[idx_SEQ_UNIT * 2] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 2], offset);
+            }
+            ptr_P[idx_SEQ_UNIT] = ptr_K[idx_SEQ_UNIT * 2];
+            #pragma unroll
+            for (int offset = 16; offset > 0; offset /= 2){
+                ptr_K[((idx_SEQ_UNIT * 2) + 1)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 2) + 1)], offset);
+            }
+            ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 2) + 1)];
+        }
+
+        #pragma unroll
+        for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
+            ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
+            ptr_P[idx_SEQ_UNIT] /= 8;
+            ptr_max[0] = (idx_SEQ_UNIT == 0) ? ptr_P[0] : max(ptr_max[0], ptr_P[idx_SEQ_UNIT]);
+        }
+        ptr_max[0] = (idx_seq == 0) ? ptr_max[0] : max(ptr_max[0], ptr_max_last[0]);
+
+        ptr_sum[0] = 0;
+        #pragma unroll
+        for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
+            ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT] - ptr_max[0]);
+            ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
+        }
+        ptr_sum[0] = (idx_seq == 0) ? ptr_sum[0] : expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] + ptr_sum[0];
+
+        ptr_O[0] = 0;
+        ptr_O[1] = 0;
+        #pragma unroll
+        for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
+            if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
+                (float2 &)ptr_V[idx_SEQ_UNIT * 2]
+                    = (float2 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
+            }
+            else{
+                (float2 &)ptr_V[idx_SEQ_UNIT * 2]
+                    = (float2 &) input_v[((lane_id * 2) + parallel_idx * compMeta.stride[2])];
+                (float2 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
+                    (float2 &)ptr_V[idx_SEQ_UNIT * 2];
+            }
+
+            ptr_P[idx_SEQ_UNIT] /= ptr_sum[0];
+
+            ptr_O[0] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2)], ptr_O[0]);
+            ptr_O[1] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2) + 1], ptr_O[1]);
+        }
+        ptr_O[0] = (idx_seq == 0) ? ptr_O[0] : ptr_O[0] + ptr_O_last[0] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
+        ptr_O[1] = (idx_seq == 0) ? ptr_O[1] : ptr_O[1] + ptr_O_last[1] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
+    }
+    (float2 &)output_matmul[(lane_id * 2) + (parallel_idx * compMeta.dimSize[3])] = (float2 &)ptr_O[0];
+}
+
+namespace infini {
+void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache, float *input_q, float *input_k,
+                              float *input_v, int *position_id, float *output_matmul,
+                              const AttentionKVCacheMetadata &compMeta) {
+    IT_ASSERT(compMeta.dimSize[3] == 64);
+    dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), 1);
+    dim3 blockDim(BLOCKSIZE, 1);
+
+    _attention_kvcache_kernel<<<gridDim, blockDim>>>(
+        input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, output_matmul, compMeta);
+}
+
+} // namespace infini
diff --git a/src/operators/attention_kvcache.cc b/src/operators/attention_kvcache.cc
new file mode 100644
index 00000000..9893f509
--- /dev/null
+++ b/src/operators/attention_kvcache.cc
@@ -0,0 +1,55 @@
+#include "operators/attention_kvcache.h"
+#include "utils/operator_utils.h"
+
+namespace infini {
+AttentionKVCacheObj::AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
+                                         Tensor input_v_cache, Tensor input_q,
+                                         Tensor input_k, Tensor input_v,
+                                         Tensor position_id,
+                                         Tensor output_matmul)
+    : OperatorObj(OpType::AttentionKVCache,
+                  TensorVec{input_k_cache, input_v_cache, input_q, input_k,
+                            input_v, position_id},
+                  {output_matmul}) {
+    int rank = inputs[0]->getRank();
+    IT_ASSERT(rank == 4);
+    dim = 2;
+    IT_ASSERT(checkValid(graph));
+}
+
+optional<vector<Shape>>
+AttentionKVCacheObj::inferShape(const TensorVec &inputs) const {
+    IT_ASSERT(inputs.size() == 6);
+    Shape dims = inputs[0]->getDims();
+    ShapeElem n = dims.at(dim);
+    dims[dim] = n + 1;
+    return {{inputs[2]->getDims()}};
+}
+
+std::string AttentionKVCacheObj::toString() const {
+    std::ostringstream os;
+    os << "AttentionKVCache[" << getGuid() << "]";
+    os << "(";
+    for (auto input : inputs)
+        os << vecToString(input->getDims()) << ",";
+    os << "dim=" << dim << ",";
+    os << "input=";
+    for (auto input : inputs)
+        os << input->getGuid() << ",";
+    os << "output=" << outputs[0]->getGuid() << ")";
+    return os.str();
+}
+
+vector<int> AttentionKVCacheObj::getWorkloadVector() const {
+    vector<int> ret = getOutputs()[0]->getDims();
+    ret.emplace(ret.begin(), (int)inputs.size());
+    ret.emplace(ret.begin(), dim);
+    ret.emplace(ret.begin(), type.underlying());
+    return ret;
+}
+
+vector<int> AttentionKVCacheObj::getOpAttrVector() const {
+    return {type.underlying(), dim};
+}
+
+} // namespace infini
diff --git a/test/kernels/cuda/test_cuda_attention.cc b/test/kernels/cuda/test_cuda_attention.cc
new file mode 100644
index 00000000..3ccf861d
--- /dev/null
+++ b/test/kernels/cuda/test_cuda_attention.cc
@@ -0,0 +1,42 @@
+#include "core/graph.h"
+#include "core/runtime.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/cuda_utility.h"
+#include "operators/attention_kvcache.h"
+
+#include "test.h"
+
+namespace infini {
+TEST(AttentionKVCache, Cuda) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+
+    Graph gCpu = make_ref<GraphObj>(runtime);
+
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cudaRuntime);
+    auto input_k_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
+    auto input_v_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
+    auto input_q_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
+    auto input_k_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
+    auto input_v_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
+    auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32);
+
+    auto op = gCuda->addOp<AttentionKVCacheObj>(
+        input_k_cache_d, input_v_cache_d, input_q_d, input_k_d, input_v_d,
+        position_id_d, nullptr);
+    gCuda->dataMalloc();
+
+    input_q_d->setData(OneGenerator());
+    input_k_d->setData(OneGenerator());
+    input_v_d->setData(OneGenerator());
+    position_id_d->setData(IncrementalGenerator());
+    cudaRuntime->run(gCuda);
+
+    auto oCpu = gCpu->cloneTensor(op->getOutput());
+    EXPECT_TRUE(oCpu->equalData(vector<float>{
+        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}));
+}
+
+} // namespace infini
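
The onnx.py hunk above maps a custom "AttentionKVCache" ONNX node, with inputs ordered (k_cache, v_cache, q, k, v, position_id) and a single output, onto GraphHandlerObj::attentionKVCache. Below is a minimal sketch, not part of this patch, of how such a node could be assembled with the standard onnx helper API; the shapes mirror the CUDA unit test, and the commented-out OnnxStub/backend.cuda_runtime() call at the end is assumed usage rather than something exercised here.

import onnx
from onnx import TensorProto, helper

# One head, head_dim = 64, matching test/kernels/cuda/test_cuda_attention.cc.
qkv_shape = [1, 1, 1, 64]

node = helper.make_node(
    "AttentionKVCache",  # dispatched by the new branch added to onnx.py
    inputs=["k_cache", "v_cache", "q", "k", "v", "position_id"],
    outputs=["attn_out"],
)
graph = helper.make_graph(
    [node],
    "attention_kvcache_demo",
    inputs=[
        helper.make_tensor_value_info(name, TensorProto.FLOAT, qkv_shape)
        for name in ["k_cache", "v_cache", "q", "k", "v"]
    ]
    + [helper.make_tensor_value_info("position_id", TensorProto.INT32, [1, 1])],
    outputs=[helper.make_tensor_value_info("attn_out", TensorProto.FLOAT, qkv_shape)],
)
model = helper.make_model(graph)

# Assumed usage (not verified by this patch): load the model through the
# frontend so the node reaches Handler::attentionKVCache and the CUDA kernel.
# from pyinfinitensor.onnx import OnnxStub, backend
# stub = OnnxStub(model, backend.cuda_runtime())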