forked from jiuyuan/InfiniTensor

commit afed5d3c3d (parent 6a1bfd6c45)

    use workspace to optimize kvcache attention
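In short: the AttentionKVCache operator previously declared three outputs (output_matmul plus the updated k and v caches), and the CUDA kernel wrote its intermediates into two dedicated temporary tensors (output_temp_O, output_temp_sum). This commit trims the operator to the single output_matmul and replaces the temporaries with one workspace buffer obtained from the CUDA runtime; the hunks below propagate the narrower signature through the ONNX export script, the graph handler, the operator definition, the kernel, and the test.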
@@ -67,9 +67,7 @@ def replace_onnx_with_attention_op():
                   tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"],
                   tmap["/model/layers." + str(i) + "/self_attn/Transpose_2_output_0"]]
         outputs = [
-            tmap["/model/layers." + str(i) + "/self_attn/MatMul_1_output_0"],
-            tmap[graph.outputs[1+i*2].name],
-            tmap[graph.outputs[2+i*2].name]]
+            tmap["/model/layers." + str(i) + "/self_attn/MatMul_1_output_0"]]
 
         inputs_added = [graph.inputs[1]]
         outputs_removed = []
@@ -77,6 +75,7 @@ def replace_onnx_with_attention_op():
         graph.replace_with_attention(
             inputs, outputs, inputs_added, outputs_removed)
 
+    graph.outputs = [tmap[graph.outputs[0].name]]
     graph.cleanup(True).toposort()
     onnx.save(gs.export_onnx(graph), ONNX_MODEL_PATH, save_as_external_data=True)
 
@@ -74,10 +74,9 @@ class GraphHandlerObj {
     Tensor squeeze(Tensor input, Tensor output, Shape axes);
     Tensor unsqueeze(Tensor input, Tensor output, Shape axes);
     Tensor concat(TensorVec inputs, Tensor output, int dim);
-    TensorVec attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
-                               Tensor input_q, Tensor input_k, Tensor input_v,
-                               Tensor position_id, Tensor output_matmul,
-                               Tensor output_k_cache, Tensor output_v_cache);
+    Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
+                            Tensor input_q, Tensor input_k, Tensor input_v,
+                            Tensor position_id, Tensor output_matmul);
     TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
                     std::variant<int, vector<int>> numOrRatio);
     Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);
@@ -1,4 +1,5 @@
 #pragma once
+#include "core/common.h"
 #include <cstdio>
 
 struct AttentionKVCacheMetadata {
@@ -22,21 +22,18 @@ class AttentionKVCacheObj : public OperatorObj {
     * @param input_v The value input tensor.
     * @param position_id The positon id of the query,
     * @param output_matmul The query output tensor.
-    * @param output_k_cache The output k_cache tensor.
-    * @param output_v_cache The output v_cache tensor.
     */
    AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
                        Tensor input_v_cache, Tensor input_q, Tensor input_k,
                        Tensor input_v, Tensor position_id,
-                        Tensor output_matmul, Tensor output_k_cache,
-                        Tensor output_v_cache);
+                        Tensor output_matmul);
    OP_CLONE(AttentionKVCacheObj);

    optional<vector<Shape>> inferShape(const TensorVec &inputs) override;

    std::string toString() const override;
    int numInputs() const override { return 6; }
-    int numOutputs() const override { return 3; }
+    int numOutputs() const override { return 1; }
    int getDim() const { return dim; }

  private:
@@ -660,7 +660,7 @@ class OnnxStub:
                     next((attr.i for attr in node.attribute if attr.name == "axis")),
                 )
             elif node.op_type == "AttentionKVCache":
-                tensors[node.output[0]], tensors[node.output[1]], tensors[node.output[2]] = self.handler.attentionKVCache(
+                tensors[node.output[0]] = self.handler.attentionKVCache(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors[node.input[2]],
@@ -668,8 +668,6 @@ class OnnxStub:
                     tensors[node.input[4]],
                     tensors[node.input[5]],
                     tensors.get(node.output[0]),
-                    tensors.get(node.output[1]),
-                    tensors.get(node.output[2]),
                 )
             elif node.op_type == "Split":
                 split = (
@@ -324,25 +324,24 @@ Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) {
     }
 }
 
-TensorVec GraphHandlerObj::attentionKVCache(
-    Tensor input_k_cache, Tensor input_v_cache, Tensor input_q, Tensor input_k,
-    Tensor input_v, Tensor position_id, Tensor output_matmul,
-    Tensor output_k_cache, Tensor output_v_cache) {
-    if (output_matmul && output_k_cache && output_v_cache) {
+Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache,
+                                         Tensor input_v_cache, Tensor input_q,
+                                         Tensor input_k, Tensor input_v,
+                                         Tensor position_id,
+                                         Tensor output_matmul) {
+    if (output_matmul) {
         g->addOpWithOutputs<AttentionKVCacheObj>(
             std::move(input_k_cache), std::move(input_v_cache),
             std::move(input_q), std::move(input_k), std::move(input_v),
-            std::move(position_id), output_matmul, output_k_cache,
-            output_v_cache);
-        return {output_matmul, output_k_cache, output_v_cache};
+            std::move(position_id), output_matmul);
+        return output_matmul;
     } else {
         return g
             ->addOp<AttentionKVCacheObj>(
                 std::move(input_k_cache), std::move(input_v_cache),
                 std::move(input_q), std::move(input_k), std::move(input_v),
-                std::move(position_id), output_matmul, output_k_cache,
-                output_v_cache)
-            ->getOutputs();
+                std::move(position_id), output_matmul)
+            ->getOutput();
     }
 }
 
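A minimal usage sketch of the reworked entry point (hypothetical caller code, not part of the commit; only the call shape follows from the hunk above):

```cpp
// Hypothetical caller (tensor and handler names are illustrative):
// the handler now returns the single attention output, and the k/v caches
// are updated in place rather than returned. Passing output_matmul == nullptr
// takes the addOp<> branch, so shape inference creates the output tensor.
Tensor out = handler.attentionKVCache(kCache, vCache, q, k, v, posId,
                                      /*output_matmul=*/nullptr);
```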
@@ -67,10 +67,8 @@ bool OperatorObj::checkValid(GraphObj *graph) {
     if (graph) { // if graph != nullptr, outputs should be created
         auto dataTypes = inferDataType();
         for (size_t i = 0; i < outputs.size(); i++) {
-            if (!outputs[i])
-                outputs[i] = graph->addTensor(shapes[i], dataTypes[i]);
-            else if (shapes[i] != outputs[i]->getDims())
-                return false;
+            IT_ASSERT(!outputs[i], "Find empty output while operator creation");
+            outputs[i] = graph->addTensor(shapes[i], dataTypes[i]);
         }
     } else { // if outputs have been created, check their shapes
         for (size_t i = 0; i < shapes.size(); ++i) {
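Note the semantic tightening here: when a graph is supplied, every output slot must now be empty and is always freshly allocated, whereas the old code tolerated a pre-created output as long as its shape matched. The assertion message is worded oddly; it fires on finding a non-empty output during operator creation.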
@@ -21,8 +21,7 @@ class AttentionKVCacheCompute {
   public:
     void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
                     Tensor input_k, Tensor input_v, Tensor position_id,
-                    Tensor output_matmul, Tensor output_temp_O,
-                    Tensor output_temp_sum) const {
+                    Tensor output_matmul, CudaPtr p_workspace) const {
         AttentionKVCacheMetadata metadata;
         initAttentionKVCacheMetadata(metadata, input_v_cache);
 
@@ -33,9 +32,8 @@ class AttentionKVCacheCompute {
             input_v->getRawDataPtr<float *>(),
             position_id->getRawDataPtr<int *>(),
             output_matmul->getRawDataPtr<float *>(),
-            metadata,
-            output_temp_O->getRawDataPtr<float *>(),
-            output_temp_sum->getRawDataPtr<float *>());
+            metadata, (float *)p_workspace,
+            (float *)(p_workspace + (1ll << 30)));
     }
 };
 
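The two casts carve the workspace into the regions that used to be separate output tensors. A minimal sketch of the layout this pointer arithmetic assumes (variable names are illustrative; that the second region holds the kernel's running sums is inferred from the old output_temp_sum name):

```cpp
// Assumed workspace layout inside do_compute, mirroring the hunk above
// (CudaPtr is treated as byte-addressable, as the kernel code does):
//   bytes [0, 1 GiB)     -> temporary attention output (was output_temp_O)
//   bytes [1 GiB, 2 GiB) -> temporary sums             (was output_temp_sum)
constexpr long long kHalf = 1ll << 30; // 1 GiB, the hard-coded offset above
float *p_temp_O = (float *)p_workspace;
float *p_temp_sum = (float *)(p_workspace + kHalf);
```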
@@ -44,10 +42,14 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute,
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         IT_ASSERT(_op->getDType() == DataType::Float32);
-        do_compute(
-            _op->getInputs()[0], _op->getInputs()[1], _op->getInputs()[2],
-            _op->getInputs()[3], _op->getInputs()[4], _op->getInputs()[5],
-            _op->getOutputs()[0], _op->getOutputs()[1], _op->getOutputs()[2]);
+
+        size_t workspaceSize = 2ll << 30;
+        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+        CudaPtr idxWsData = context->getWorkspace(workspaceSize);
+        do_compute(_op->getInputs()[0], _op->getInputs()[1],
+                   _op->getInputs()[2], _op->getInputs()[3],
+                   _op->getInputs()[4], _op->getInputs()[5],
+                   _op->getOutputs()[0], idxWsData);
     }
 };
 
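The requested 2ll << 30 (2 GiB) matches the two 1 GiB halves consumed by do_compute. Presumably getWorkspace hands back a pointer into a workspace buffer the runtime already owns, so this is a cheap per-launch lookup rather than a fresh allocation; the trade-off is that the temporaries now live within a fixed hard-coded budget instead of having shapes inferred per operator.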
@@ -2,14 +2,15 @@
 #include "utils/operator_utils.h"
 
 namespace infini {
-AttentionKVCacheObj::AttentionKVCacheObj(
-    GraphObj *graph, Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
-    Tensor input_k, Tensor input_v, Tensor position_id, Tensor output_matmul,
-    Tensor output_k_cache, Tensor output_v_cache)
+AttentionKVCacheObj::AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
+                                         Tensor input_v_cache, Tensor input_q,
+                                         Tensor input_k, Tensor input_v,
+                                         Tensor position_id,
+                                         Tensor output_matmul)
     : OperatorObj(OpType::AttentionKVCache,
                   TensorVec{input_k_cache, input_v_cache, input_q, input_k,
                             input_v, position_id},
-                  TensorVec{output_matmul, output_k_cache, output_v_cache}) {
+                  TensorVec{output_matmul}) {
     int rank = inputs[0]->getRank();
     IT_ASSERT(rank == 4);
     dim = 2;
@@ -22,7 +23,7 @@ AttentionKVCacheObj::inferShape(const TensorVec &inputs) {
     Shape dims = inputs[0]->getDims();
     ShapeElem n = dims.at(dim);
     dims[dim] = n + 1;
-    return {{inputs[2]->getDims(), dims, dims}};
+    return {{inputs[2]->getDims()}};
 }
 
 std::string AttentionKVCacheObj::toString() const {
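With one output, inferShape reports only the query shape for output_matmul. The grown-cache shape is still computed just above (dims[dim] = n + 1) but no longer appears in the returned list, consistent with the caches presumably being updated in place by the kernel.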
@@ -23,7 +23,7 @@ TEST(AttentionKVCache, Cuda) {
 
     auto op = gCuda->addOp<AttentionKVCacheObj>(
         input_k_cache_d, input_v_cache_d, input_q_d, input_k_d, input_v_d,
-        position_id_d, nullptr, nullptr, nullptr);
+        position_id_d, nullptr);
     gCuda->dataMalloc();
 
     input_q_d->setData(OneGenerator());
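In the test, the three output slots collapse to a single nullptr, so the operator creates its one output through shape inference and gCuda->dataMalloc() allocates it along with the inputs.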