use workspace to optimize kvcache attention

This commit is contained in:
xiaonans 2024-01-25 09:08:25 +08:00
parent 6a1bfd6c45
commit afed5d3c3d
10 changed files with 41 additions and 47 deletions

View File

@ -67,16 +67,15 @@ def replace_onnx_with_attention_op():
tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"],
tmap["/model/layers." + str(i) + "/self_attn/Transpose_2_output_0"]]
outputs = [
tmap["/model/layers." + str(i) + "/self_attn/MatMul_1_output_0"],
tmap[graph.outputs[1+i*2].name],
tmap[graph.outputs[2+i*2].name]]
tmap["/model/layers." + str(i) + "/self_attn/MatMul_1_output_0"]]
inputs_added = [graph.inputs[1]]
outputs_removed = []
graph.replace_with_attention(
inputs, outputs, inputs_added, outputs_removed)
graph.outputs = [tmap[graph.outputs[0].name]]
graph.cleanup(True).toposort()
onnx.save(gs.export_onnx(graph), ONNX_MODEL_PATH, save_as_external_data=True)

View File

@ -74,10 +74,9 @@ class GraphHandlerObj {
Tensor squeeze(Tensor input, Tensor output, Shape axes);
Tensor unsqueeze(Tensor input, Tensor output, Shape axes);
Tensor concat(TensorVec inputs, Tensor output, int dim);
TensorVec attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
Tensor input_q, Tensor input_k, Tensor input_v,
Tensor position_id, Tensor output_matmul,
Tensor output_k_cache, Tensor output_v_cache);
Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
Tensor input_q, Tensor input_k, Tensor input_v,
Tensor position_id, Tensor output_matmul);
TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
std::variant<int, vector<int>> numOrRatio);
Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);

View File

@ -1,4 +1,5 @@
#pragma once
#include "core/common.h"
#include <cstdio>
struct AttentionKVCacheMetadata {

View File

@ -22,21 +22,18 @@ class AttentionKVCacheObj : public OperatorObj {
* @param input_v The value input tensor.
* @param position_id The positon id of the query,
* @param output_matmul The query output tensor.
* @param output_k_cache The output k_cache tensor.
* @param output_v_cache The output v_cache tensor.
*/
AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
Tensor input_v_cache, Tensor input_q, Tensor input_k,
Tensor input_v, Tensor position_id,
Tensor output_matmul, Tensor output_k_cache,
Tensor output_v_cache);
Tensor output_matmul);
OP_CLONE(AttentionKVCacheObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override;
int numInputs() const override { return 6; }
int numOutputs() const override { return 3; }
int numOutputs() const override { return 1; }
int getDim() const { return dim; }
private:

View File

@ -660,7 +660,7 @@ class OnnxStub:
next((attr.i for attr in node.attribute if attr.name == "axis")),
)
elif node.op_type == "AttentionKVCache":
tensors[node.output[0]], tensors[node.output[1]], tensors[node.output[2]] = self.handler.attentionKVCache(
tensors[node.output[0]] = self.handler.attentionKVCache(
tensors[node.input[0]],
tensors[node.input[1]],
tensors[node.input[2]],
@ -668,8 +668,6 @@ class OnnxStub:
tensors[node.input[4]],
tensors[node.input[5]],
tensors.get(node.output[0]),
tensors.get(node.output[1]),
tensors.get(node.output[2]),
)
elif node.op_type == "Split":
split = (

View File

@ -324,25 +324,24 @@ Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) {
}
}
TensorVec GraphHandlerObj::attentionKVCache(
Tensor input_k_cache, Tensor input_v_cache, Tensor input_q, Tensor input_k,
Tensor input_v, Tensor position_id, Tensor output_matmul,
Tensor output_k_cache, Tensor output_v_cache) {
if (output_matmul && output_k_cache && output_v_cache) {
Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache,
Tensor input_v_cache, Tensor input_q,
Tensor input_k, Tensor input_v,
Tensor position_id,
Tensor output_matmul) {
if (output_matmul) {
g->addOpWithOutputs<AttentionKVCacheObj>(
std::move(input_k_cache), std::move(input_v_cache),
std::move(input_q), std::move(input_k), std::move(input_v),
std::move(position_id), output_matmul, output_k_cache,
output_v_cache);
return {output_matmul, output_k_cache, output_v_cache};
std::move(position_id), output_matmul);
return output_matmul;
} else {
return g
->addOp<AttentionKVCacheObj>(
std::move(input_k_cache), std::move(input_v_cache),
std::move(input_q), std::move(input_k), std::move(input_v),
std::move(position_id), output_matmul, output_k_cache,
output_v_cache)
->getOutputs();
std::move(position_id), output_matmul)
->getOutput();
}
}

View File

@ -67,10 +67,8 @@ bool OperatorObj::checkValid(GraphObj *graph) {
if (graph) { // if graph != nullptr, outputs should be created
auto dataTypes = inferDataType();
for (size_t i = 0; i < outputs.size(); i++) {
if (!outputs[i])
outputs[i] = graph->addTensor(shapes[i], dataTypes[i]);
else if (shapes[i] != outputs[i]->getDims())
return false;
IT_ASSERT(!outputs[i], "Find empty output while operator creation");
outputs[i] = graph->addTensor(shapes[i], dataTypes[i]);
}
} else { // if outputs have been created, check their shapes
for (size_t i = 0; i < shapes.size(); ++i) {

View File

@ -21,8 +21,7 @@ class AttentionKVCacheCompute {
public:
void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
Tensor input_k, Tensor input_v, Tensor position_id,
Tensor output_matmul, Tensor output_temp_O,
Tensor output_temp_sum) const {
Tensor output_matmul, CudaPtr p_workspace) const {
AttentionKVCacheMetadata metadata;
initAttentionKVCacheMetadata(metadata, input_v_cache);
@ -33,9 +32,8 @@ class AttentionKVCacheCompute {
input_v->getRawDataPtr<float *>(),
position_id->getRawDataPtr<int *>(),
output_matmul->getRawDataPtr<float *>(),
metadata,
output_temp_O->getRawDataPtr<float *>(),
output_temp_sum->getRawDataPtr<float *>());
metadata, (float *)p_workspace,
(float *)(p_workspace + (1ll << 30)));
}
};
@ -44,10 +42,14 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute,
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
IT_ASSERT(_op->getDType() == DataType::Float32);
do_compute(
_op->getInputs()[0], _op->getInputs()[1], _op->getInputs()[2],
_op->getInputs()[3], _op->getInputs()[4], _op->getInputs()[5],
_op->getOutputs()[0], _op->getOutputs()[1], _op->getOutputs()[2]);
size_t workspaceSize = 2ll << 30;
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
CudaPtr idxWsData = context->getWorkspace(workspaceSize);
do_compute(_op->getInputs()[0], _op->getInputs()[1],
_op->getInputs()[2], _op->getInputs()[3],
_op->getInputs()[4], _op->getInputs()[5],
_op->getOutputs()[0], idxWsData);
}
};

View File

@ -2,14 +2,15 @@
#include "utils/operator_utils.h"
namespace infini {
AttentionKVCacheObj::AttentionKVCacheObj(
GraphObj *graph, Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
Tensor input_k, Tensor input_v, Tensor position_id, Tensor output_matmul,
Tensor output_k_cache, Tensor output_v_cache)
AttentionKVCacheObj::AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
Tensor input_v_cache, Tensor input_q,
Tensor input_k, Tensor input_v,
Tensor position_id,
Tensor output_matmul)
: OperatorObj(OpType::AttentionKVCache,
TensorVec{input_k_cache, input_v_cache, input_q, input_k,
input_v, position_id},
TensorVec{output_matmul, output_k_cache, output_v_cache}) {
{output_matmul}) {
int rank = inputs[0]->getRank();
IT_ASSERT(rank == 4);
dim = 2;
@ -22,7 +23,7 @@ AttentionKVCacheObj::inferShape(const TensorVec &inputs) {
Shape dims = inputs[0]->getDims();
ShapeElem n = dims.at(dim);
dims[dim] = n + 1;
return {{inputs[2]->getDims(), dims, dims}};
return {{inputs[2]->getDims()}};
}
std::string AttentionKVCacheObj::toString() const {

View File

@ -23,7 +23,7 @@ TEST(AttentionKVCache, Cuda) {
auto op = gCuda->addOp<AttentionKVCacheObj>(
input_k_cache_d, input_v_cache_d, input_q_d, input_k_d, input_v_d,
position_id_d, nullptr, nullptr, nullptr);
position_id_d, nullptr);
gCuda->dataMalloc();
input_q_d->setData(OneGenerator());