use workspace to optimize kvcache attention

xiaonans 2024-01-25 09:08:25 +08:00
parent 6a1bfd6c45
commit afed5d3c3d
10 changed files with 41 additions and 47 deletions

View File

@@ -67,9 +67,7 @@ def replace_onnx_with_attention_op():
             tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"],
             tmap["/model/layers." + str(i) + "/self_attn/Transpose_2_output_0"]]
         outputs = [
-            tmap["/model/layers." + str(i) + "/self_attn/MatMul_1_output_0"],
-            tmap[graph.outputs[1+i*2].name],
-            tmap[graph.outputs[2+i*2].name]]
+            tmap["/model/layers." + str(i) + "/self_attn/MatMul_1_output_0"]]
         inputs_added = [graph.inputs[1]]
         outputs_removed = []
@@ -77,6 +75,7 @@ def replace_onnx_with_attention_op():
         graph.replace_with_attention(
             inputs, outputs, inputs_added, outputs_removed)
+    graph.outputs = [tmap[graph.outputs[0].name]]
     graph.cleanup(True).toposort()
     onnx.save(gs.export_onnx(graph), ONNX_MODEL_PATH, save_as_external_data=True)

View File

@@ -74,10 +74,9 @@ class GraphHandlerObj {
     Tensor squeeze(Tensor input, Tensor output, Shape axes);
     Tensor unsqueeze(Tensor input, Tensor output, Shape axes);
     Tensor concat(TensorVec inputs, Tensor output, int dim);
-    TensorVec attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
-                               Tensor input_q, Tensor input_k, Tensor input_v,
-                               Tensor position_id, Tensor output_matmul,
-                               Tensor output_k_cache, Tensor output_v_cache);
+    Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
+                            Tensor input_q, Tensor input_k, Tensor input_v,
+                            Tensor position_id, Tensor output_matmul);
     TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
                     std::variant<int, vector<int>> numOrRatio);
     Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);

View File

@@ -1,4 +1,5 @@
 #pragma once
+#include "core/common.h"
 #include <cstdio>
 struct AttentionKVCacheMetadata {

View File

@ -22,21 +22,18 @@ class AttentionKVCacheObj : public OperatorObj {
* @param input_v The value input tensor. * @param input_v The value input tensor.
* @param position_id The positon id of the query, * @param position_id The positon id of the query,
* @param output_matmul The query output tensor. * @param output_matmul The query output tensor.
* @param output_k_cache The output k_cache tensor.
* @param output_v_cache The output v_cache tensor.
*/ */
AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache, AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
Tensor input_v_cache, Tensor input_q, Tensor input_k, Tensor input_v_cache, Tensor input_q, Tensor input_k,
Tensor input_v, Tensor position_id, Tensor input_v, Tensor position_id,
Tensor output_matmul, Tensor output_k_cache, Tensor output_matmul);
Tensor output_v_cache);
OP_CLONE(AttentionKVCacheObj); OP_CLONE(AttentionKVCacheObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 6; } int numInputs() const override { return 6; }
int numOutputs() const override { return 3; } int numOutputs() const override { return 1; }
int getDim() const { return dim; } int getDim() const { return dim; }
private: private:

View File

@@ -660,7 +660,7 @@ class OnnxStub:
                     next((attr.i for attr in node.attribute if attr.name == "axis")),
                 )
             elif node.op_type == "AttentionKVCache":
-                tensors[node.output[0]], tensors[node.output[1]], tensors[node.output[2]] = self.handler.attentionKVCache(
+                tensors[node.output[0]] = self.handler.attentionKVCache(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors[node.input[2]],
@@ -668,8 +668,6 @@ class OnnxStub:
                     tensors[node.input[4]],
                     tensors[node.input[5]],
                     tensors.get(node.output[0]),
-                    tensors.get(node.output[1]),
-                    tensors.get(node.output[2]),
                 )
             elif node.op_type == "Split":
                 split = (

View File

@@ -324,25 +324,24 @@ Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) {
     }
 }

-TensorVec GraphHandlerObj::attentionKVCache(
-    Tensor input_k_cache, Tensor input_v_cache, Tensor input_q, Tensor input_k,
-    Tensor input_v, Tensor position_id, Tensor output_matmul,
-    Tensor output_k_cache, Tensor output_v_cache) {
-    if (output_matmul && output_k_cache && output_v_cache) {
+Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache,
+                                         Tensor input_v_cache, Tensor input_q,
+                                         Tensor input_k, Tensor input_v,
+                                         Tensor position_id,
+                                         Tensor output_matmul) {
+    if (output_matmul) {
         g->addOpWithOutputs<AttentionKVCacheObj>(
             std::move(input_k_cache), std::move(input_v_cache),
             std::move(input_q), std::move(input_k), std::move(input_v),
-            std::move(position_id), output_matmul, output_k_cache,
-            output_v_cache);
-        return {output_matmul, output_k_cache, output_v_cache};
+            std::move(position_id), output_matmul);
+        return output_matmul;
     } else {
         return g
             ->addOp<AttentionKVCacheObj>(
                 std::move(input_k_cache), std::move(input_v_cache),
                 std::move(input_q), std::move(input_k), std::move(input_v),
-                std::move(position_id), output_matmul, output_k_cache,
-                output_v_cache)
-            ->getOutputs();
+                std::move(position_id), output_matmul)
+            ->getOutput();
     }
 }

View File

@@ -67,10 +67,8 @@ bool OperatorObj::checkValid(GraphObj *graph) {
     if (graph) { // if graph != nullptr, outputs should be created
         auto dataTypes = inferDataType();
         for (size_t i = 0; i < outputs.size(); i++) {
-            if (!outputs[i])
-                outputs[i] = graph->addTensor(shapes[i], dataTypes[i]);
-            else if (shapes[i] != outputs[i]->getDims())
-                return false;
+            IT_ASSERT(!outputs[i], "Find empty output while operator creation");
+            outputs[i] = graph->addTensor(shapes[i], dataTypes[i]);
         }
     } else { // if outputs have been created, check their shapes
         for (size_t i = 0; i < shapes.size(); ++i) {
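The checkValid change tightens an invariant rather than adding behavior: when a graph is supplied, every output slot must still be empty and is then allocated exactly once, instead of a pre-existing output with a mismatched shape falling through to a silent `return false`. A toy analogue of the new contract (names and types illustrative only, not from the repo):

#include <cassert>
#include <cstddef>
#include <vector>

// Toy analogue of the tightened loop in OperatorObj::checkValid: every
// output slot must be empty before the graph allocates it.
void allocateOutputs(std::vector<int *> &outputs, std::vector<int> &storage) {
    for (std::size_t i = 0; i < outputs.size(); i++) {
        assert(outputs[i] == nullptr && "output already created");
        outputs[i] = &storage[i]; // stands in for graph->addTensor(...)
    }
}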

View File

@@ -21,8 +21,7 @@ class AttentionKVCacheCompute {
   public:
     void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
                     Tensor input_k, Tensor input_v, Tensor position_id,
-                    Tensor output_matmul, Tensor output_temp_O,
-                    Tensor output_temp_sum) const {
+                    Tensor output_matmul, CudaPtr p_workspace) const {
         AttentionKVCacheMetadata metadata;
         initAttentionKVCacheMetadata(metadata, input_v_cache);
@@ -33,9 +32,8 @@ class AttentionKVCacheCompute {
             input_v->getRawDataPtr<float *>(),
             position_id->getRawDataPtr<int *>(),
             output_matmul->getRawDataPtr<float *>(),
-            metadata,
-            output_temp_O->getRawDataPtr<float *>(),
-            output_temp_sum->getRawDataPtr<float *>());
+            metadata, (float *)p_workspace,
+            (float *)(p_workspace + (1ll << 30)));
     }
 };
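The two temporaries that used to be extra graph tensors (output_temp_O and output_temp_sum) are now carved out of one workspace allocation, with the second buffer starting 1ll << 30 bytes (1 GiB) past the first. A self-contained sketch of that sub-buffer pattern, with illustrative names and assumed buffer roles:

#include <cstddef>
#include <cstdint>

// Illustrative only: split one raw allocation into two float scratch
// buffers at a fixed byte offset, as the do_compute wrapper above does.
struct Scratch {
    float *temp_o;   // assumed role: partial attention outputs
    float *temp_sum; // assumed role: softmax normalization sums
};

Scratch splitWorkspace(void *base, std::size_t offsetBytes) {
    auto *bytes = static_cast<std::uint8_t *>(base);
    return {reinterpret_cast<float *>(bytes),
            reinterpret_cast<float *>(bytes + offsetBytes)};
}
// e.g. splitWorkspace(p_workspace, 1ull << 30) mirrors the kernel call above.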
@@ -44,10 +42,14 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute,
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         IT_ASSERT(_op->getDType() == DataType::Float32);
-        do_compute(
-            _op->getInputs()[0], _op->getInputs()[1], _op->getInputs()[2],
-            _op->getInputs()[3], _op->getInputs()[4], _op->getInputs()[5],
-            _op->getOutputs()[0], _op->getOutputs()[1], _op->getOutputs()[2]);
+
+        size_t workspaceSize = 2ll << 30;
+        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+        CudaPtr idxWsData = context->getWorkspace(workspaceSize);
+        do_compute(_op->getInputs()[0], _op->getInputs()[1],
+                   _op->getInputs()[2], _op->getInputs()[3],
+                   _op->getInputs()[4], _op->getInputs()[5],
+                   _op->getOutputs()[0], idxWsData);
     }
 };
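compute now requests a 2 GiB (2ll << 30 bytes) workspace from the CUDA runtime on every invocation, which only pays off if the runtime reuses one allocation across calls rather than allocating fresh device memory each time. A hedged sketch of how such a getWorkspace is commonly implemented (an assumption about the pattern, not InfiniTensor's actual CudaRuntimeObj code):

#include <cuda_runtime.h>
#include <cstddef>

// Sketch of a lazily grown, reused device workspace; real runtimes add
// error checking, alignment, and stream-safety on top of this.
class WorkspaceCache {
    void *ptr_ = nullptr;
    std::size_t size_ = 0;

  public:
    void *get(std::size_t size) {
        if (size > size_) {      // grow only when a larger request arrives
            cudaFree(ptr_);      // cudaFree(nullptr) is a no-op
            cudaMalloc(&ptr_, size);
            size_ = size;
        }
        return ptr_;             // same allocation reused across launches
    }
    ~WorkspaceCache() { cudaFree(ptr_); }
};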

View File

@@ -2,14 +2,15 @@
 #include "utils/operator_utils.h"

 namespace infini {
-AttentionKVCacheObj::AttentionKVCacheObj(
-    GraphObj *graph, Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
-    Tensor input_k, Tensor input_v, Tensor position_id, Tensor output_matmul,
-    Tensor output_k_cache, Tensor output_v_cache)
+AttentionKVCacheObj::AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
+                                         Tensor input_v_cache, Tensor input_q,
+                                         Tensor input_k, Tensor input_v,
+                                         Tensor position_id,
+                                         Tensor output_matmul)
     : OperatorObj(OpType::AttentionKVCache,
                   TensorVec{input_k_cache, input_v_cache, input_q, input_k,
                             input_v, position_id},
-                  TensorVec{output_matmul, output_k_cache, output_v_cache}) {
+                  {output_matmul}) {
     int rank = inputs[0]->getRank();
     IT_ASSERT(rank == 4);
     dim = 2;
@@ -22,7 +23,7 @@ AttentionKVCacheObj::inferShape(const TensorVec &inputs) {
     Shape dims = inputs[0]->getDims();
     ShapeElem n = dims.at(dim);
     dims[dim] = n + 1;
-    return {{inputs[2]->getDims(), dims, dims}};
+    return {{inputs[2]->getDims()}};
 }

 std::string AttentionKVCacheObj::toString() const {
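To make the inferShape change concrete: the cache's sequence dimension still grows by one per decode step, but that growth now happens in place, so only the query-shaped matmul output is reported as a graph output. With illustrative [batch, heads, seq, head_dim] shapes (not taken from the repo's tests):

#include <vector>
using Shape = std::vector<int>;

// Illustrative decode step: caches are updated in place, so only the
// query-shaped result remains a graph output.
Shape k_cache{1, 32, 127, 128}; // becomes {1, 32, 128, 128} inside the op
Shape q{1, 32, 1, 128};
// old inferShape: {q_shape, grown_cache, grown_cache}
// new inferShape: {q_shape}
Shape output_shape = q;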

View File

@@ -23,7 +23,7 @@ TEST(AttentionKVCache, Cuda) {
     auto op = gCuda->addOp<AttentionKVCacheObj>(
         input_k_cache_d, input_v_cache_d, input_q_d, input_k_d, input_v_d,
-        position_id_d, nullptr, nullptr, nullptr);
+        position_id_d, nullptr);
     gCuda->dataMalloc();
     input_q_d->setData(OneGenerator());