From afed5d3c3d87841e4275bde9a461cb2d414d73c1 Mon Sep 17 00:00:00 2001
From: xiaonans
Date: Thu, 25 Jan 2024 09:08:25 +0800
Subject: [PATCH] use workspace to optimize kvcache attention

---
 examples/python/llama_kvcache_inference.py |  7 +++----
 include/core/graph_handler.h               |  7 +++----
 include/cuda/cuda_attention_kvcache.h      |  1 +
 include/operators/attention_kvcache.h      |  7 ++-----
 pyinfinitensor/src/pyinfinitensor/onnx.py  |  4 +---
 src/core/graph_handler.cc                  | 21 ++++++++++-----------
 src/core/operator.cc                       |  6 ++----
 src/kernels/cuda/attention_kvcache.cc      | 20 +++++++++++---------
 src/operators/attention_kvcache.cc         | 13 +++++++------
 test/kernels/cuda/test_cuda_attention.cc   |  2 +-
 10 files changed, 41 insertions(+), 47 deletions(-)

diff --git a/examples/python/llama_kvcache_inference.py b/examples/python/llama_kvcache_inference.py
index b05339b8..e6ba67ff 100644
--- a/examples/python/llama_kvcache_inference.py
+++ b/examples/python/llama_kvcache_inference.py
@@ -67,16 +67,15 @@ def replace_onnx_with_attention_op():
             tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"],
             tmap["/model/layers." + str(i) + "/self_attn/Transpose_2_output_0"]]
         outputs = [
-            tmap["/model/layers." + str(i) + "/self_attn/MatMul_1_output_0"],
-            tmap[graph.outputs[1+i*2].name],
-            tmap[graph.outputs[2+i*2].name]]
+            tmap["/model/layers." + str(i) + "/self_attn/MatMul_1_output_0"]]
 
         inputs_added = [graph.inputs[1]]
         outputs_removed = []
 
         graph.replace_with_attention(
             inputs, outputs, inputs_added, outputs_removed)
-
+
+    graph.outputs = [tmap[graph.outputs[0].name]]
     graph.cleanup(True).toposort()
     onnx.save(gs.export_onnx(graph), ONNX_MODEL_PATH, save_as_external_data=True)

diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index 75673d14..0e1472bb 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -74,10 +74,9 @@ class GraphHandlerObj {
     Tensor squeeze(Tensor input, Tensor output, Shape axes);
     Tensor unsqueeze(Tensor input, Tensor output, Shape axes);
     Tensor concat(TensorVec inputs, Tensor output, int dim);
-    TensorVec attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
-                               Tensor input_q, Tensor input_k, Tensor input_v,
-                               Tensor position_id, Tensor output_matmul,
-                               Tensor output_k_cache, Tensor output_v_cache);
+    Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
+                            Tensor input_q, Tensor input_k, Tensor input_v,
+                            Tensor position_id, Tensor output_matmul);
     TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
                     std::variant<int, vector<int>> numOrRatio);
     Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);

diff --git a/include/cuda/cuda_attention_kvcache.h b/include/cuda/cuda_attention_kvcache.h
index 74d356c9..91c65d21 100644
--- a/include/cuda/cuda_attention_kvcache.h
+++ b/include/cuda/cuda_attention_kvcache.h
@@ -1,4 +1,5 @@
 #pragma once
+#include "core/common.h"
 #include 
 
 struct AttentionKVCacheMetadata {

diff --git a/include/operators/attention_kvcache.h b/include/operators/attention_kvcache.h
index b4448511..0472b222 100644
--- a/include/operators/attention_kvcache.h
+++ b/include/operators/attention_kvcache.h
@@ -22,21 +22,18 @@ class AttentionKVCacheObj : public OperatorObj {
      * @param input_v The value input tensor.
      * @param position_id The positon id of the query,
      * @param output_matmul The query output tensor.
-     * @param output_k_cache The output k_cache tensor.
-     * @param output_v_cache The output v_cache tensor.
      */
     AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
                         Tensor input_v_cache, Tensor input_q, Tensor input_k,
                         Tensor input_v, Tensor position_id,
-                        Tensor output_matmul, Tensor output_k_cache,
-                        Tensor output_v_cache);
+                        Tensor output_matmul);
     OP_CLONE(AttentionKVCacheObj);
 
     optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
 
     std::string toString() const override;
     int numInputs() const override { return 6; }
-    int numOutputs() const override { return 3; }
+    int numOutputs() const override { return 1; }
     int getDim() const { return dim; }
 
   private:

diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index 5a6f62fc..79abb7f4 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -660,7 +660,7 @@ class OnnxStub:
                     next((attr.i for attr in node.attribute if attr.name == "axis")),
                 )
             elif node.op_type == "AttentionKVCache":
-                tensors[node.output[0]], tensors[node.output[1]], tensors[node.output[2]] = self.handler.attentionKVCache(
+                tensors[node.output[0]] = self.handler.attentionKVCache(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors[node.input[2]],
@@ -668,8 +668,6 @@ class OnnxStub:
                     tensors[node.input[4]],
                     tensors[node.input[5]],
                     tensors.get(node.output[0]),
-                    tensors.get(node.output[1]),
-                    tensors.get(node.output[2]),
                 )
             elif node.op_type == "Split":
                 split = (

diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc
index 18eb893f..cd62ed32 100644
--- a/src/core/graph_handler.cc
+++ b/src/core/graph_handler.cc
@@ -324,25 +324,24 @@ Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) {
     }
 }
 
-TensorVec GraphHandlerObj::attentionKVCache(
-    Tensor input_k_cache, Tensor input_v_cache, Tensor input_q, Tensor input_k,
-    Tensor input_v, Tensor position_id, Tensor output_matmul,
-    Tensor output_k_cache, Tensor output_v_cache) {
-    if (output_matmul && output_k_cache && output_v_cache) {
+Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache,
+                                         Tensor input_v_cache, Tensor input_q,
+                                         Tensor input_k, Tensor input_v,
+                                         Tensor position_id,
+                                         Tensor output_matmul) {
+    if (output_matmul) {
         g->addOpWithOutputs<AttentionKVCacheObj>(
             std::move(input_k_cache), std::move(input_v_cache),
             std::move(input_q), std::move(input_k), std::move(input_v),
-            std::move(position_id), output_matmul, output_k_cache,
-            output_v_cache);
-        return {output_matmul, output_k_cache, output_v_cache};
+            std::move(position_id), output_matmul);
+        return output_matmul;
     } else {
         return g
             ->addOp<AttentionKVCacheObj>(
                 std::move(input_k_cache), std::move(input_v_cache),
                 std::move(input_q), std::move(input_k), std::move(input_v),
-                std::move(position_id), output_matmul, output_k_cache,
-                output_v_cache)
-            ->getOutputs();
+                std::move(position_id), output_matmul)
+            ->getOutput();
     }
 }

diff --git a/src/core/operator.cc b/src/core/operator.cc
index 9a7cf6e0..4fd4e6de 100644
--- a/src/core/operator.cc
+++ b/src/core/operator.cc
@@ -67,10 +67,8 @@ bool OperatorObj::checkValid(GraphObj *graph) {
     if (graph) { // if graph != nullptr, outputs should be created
         auto dataTypes = inferDataType();
         for (size_t i = 0; i < outputs.size(); i++) {
-            if (!outputs[i])
-                outputs[i] = graph->addTensor(shapes[i], dataTypes[i]);
-            else if (shapes[i] != outputs[i]->getDims())
-                return false;
+            IT_ASSERT(!outputs[i], "Find empty output while operator creation");
+            outputs[i] = graph->addTensor(shapes[i], dataTypes[i]);
         }
     } else { // if outputs have been created, check their shapes
         for (size_t i = 0; i < shapes.size(); ++i) {

diff --git a/src/kernels/cuda/attention_kvcache.cc b/src/kernels/cuda/attention_kvcache.cc
index 8ecff414..d72e7838 100644
--- a/src/kernels/cuda/attention_kvcache.cc
+++ b/src/kernels/cuda/attention_kvcache.cc
@@ -21,8 +21,7 @@ class AttentionKVCacheCompute {
   public:
     void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
                     Tensor input_k, Tensor input_v, Tensor position_id,
-                    Tensor output_matmul, Tensor output_temp_O,
-                    Tensor output_temp_sum) const {
+                    Tensor output_matmul, CudaPtr p_workspace) const {
         AttentionKVCacheMetadata metadata;
         initAttentionKVCacheMetadata(metadata, input_v_cache);
 
@@ -33,9 +32,8 @@ class AttentionKVCacheCompute {
             input_v->getRawDataPtr(),
             position_id->getRawDataPtr(),
             output_matmul->getRawDataPtr(),
-            metadata,
-            output_temp_O->getRawDataPtr(),
-            output_temp_sum->getRawDataPtr());
+            metadata, (float *)p_workspace,
+            (float *)(p_workspace + (1ll << 30)));
     }
 };
 
@@ -44,10 +42,14 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute,
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         IT_ASSERT(_op->getDType() == DataType::Float32);
-        do_compute(
-            _op->getInputs()[0], _op->getInputs()[1], _op->getInputs()[2],
-            _op->getInputs()[3], _op->getInputs()[4], _op->getInputs()[5],
-            _op->getOutputs()[0], _op->getOutputs()[1], _op->getOutputs()[2]);
+
+        size_t workspaceSize = 2ll << 30;
+        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+        CudaPtr idxWsData = context->getWorkspace(workspaceSize);
+        do_compute(_op->getInputs()[0], _op->getInputs()[1],
+                   _op->getInputs()[2], _op->getInputs()[3],
+                   _op->getInputs()[4], _op->getInputs()[5],
+                   _op->getOutputs()[0], idxWsData);
     }
 };
 
diff --git a/src/operators/attention_kvcache.cc b/src/operators/attention_kvcache.cc
index 24c3ba2d..492a76f7 100644
--- a/src/operators/attention_kvcache.cc
+++ b/src/operators/attention_kvcache.cc
@@ -2,14 +2,15 @@
 #include "utils/operator_utils.h"
 
 namespace infini {
-AttentionKVCacheObj::AttentionKVCacheObj(
-    GraphObj *graph, Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
-    Tensor input_k, Tensor input_v, Tensor position_id, Tensor output_matmul,
-    Tensor output_k_cache, Tensor output_v_cache)
+AttentionKVCacheObj::AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
+                                         Tensor input_v_cache, Tensor input_q,
+                                         Tensor input_k, Tensor input_v,
+                                         Tensor position_id,
+                                         Tensor output_matmul)
     : OperatorObj(OpType::AttentionKVCache,
                   TensorVec{input_k_cache, input_v_cache, input_q, input_k,
                             input_v, position_id},
-                  TensorVec{output_matmul, output_k_cache, output_v_cache}) {
+                  {output_matmul}) {
     int rank = inputs[0]->getRank();
     IT_ASSERT(rank == 4);
     dim = 2;
@@ -22,7 +23,7 @@ AttentionKVCacheObj::inferShape(const TensorVec &inputs) {
     Shape dims = inputs[0]->getDims();
     ShapeElem n = dims.at(dim);
     dims[dim] = n + 1;
-    return {{inputs[2]->getDims(), dims, dims}};
+    return {{inputs[2]->getDims()}};
 }
 
 std::string AttentionKVCacheObj::toString() const {

diff --git a/test/kernels/cuda/test_cuda_attention.cc b/test/kernels/cuda/test_cuda_attention.cc
index b95470f4..3a9bff45 100644
--- a/test/kernels/cuda/test_cuda_attention.cc
+++ b/test/kernels/cuda/test_cuda_attention.cc
@@ -23,7 +23,7 @@ TEST(AttentionKVCache, Cuda) {
     auto op = gCuda->addOp<AttentionKVCacheObj>(
         input_k_cache_d, input_v_cache_d, input_q_d, input_k_d, input_v_d,
-        position_id_d, nullptr, nullptr, nullptr);
+        position_id_d, nullptr);
     gCuda->dataMalloc();
     input_q_d->setData(OneGenerator());
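
The pattern the patch adopts can be stated apart from the diff: instead of materializing the kernel's temporaries (the former output_temp_O and output_temp_sum tensors) as extra operator outputs, the CUDA kernel now asks the runtime for one reusable scratch buffer (getWorkspace) and splits it at a fixed offset. The sketch below is a minimal, self-contained illustration of that idea; the Workspace class and compute_with_workspace function are hypothetical stand-ins rather than InfiniTensor APIs, and small host buffers stand in for the 2 GiB device workspace (2ll << 30) split at a 1 GiB offset (1ll << 30) used in the kernel above.

#include <cstddef>
#include <vector>

// Hypothetical runtime-owned scratch buffer: allocated once, grown on demand,
// and reused by every kernel invocation instead of creating per-operator
// temporary tensors.
class Workspace {
    std::vector<char> buf;

  public:
    char *get(size_t bytes) {
        if (buf.size() < bytes)
            buf.resize(bytes); // grow to the largest request, then reuse
        return buf.data();
    }
};

// Mirrors the structure of the patched compute path: request one workspace
// and split it into the two temporary regions that used to be separate
// outputs (output_temp_O / output_temp_sum).
void compute_with_workspace(Workspace &ws, size_t regionBytes) {
    char *p = ws.get(2 * regionBytes);
    float *temp_O = reinterpret_cast<float *>(p);
    float *temp_sum = reinterpret_cast<float *>(p + regionBytes);
    // ... hand temp_O / temp_sum to the attention kernel here ...
    (void)temp_O;
    (void)temp_sum;
}

int main() {
    Workspace ws;
    compute_with_workspace(ws, 1 << 20); // 1 MiB per region for the demo
}

Because the workspace is owned by the runtime and grown once to the largest request, repeated decode steps reuse the same buffer instead of allocating and freeing two temporary tensors on every call.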