forked from jiuyuan/InfiniTensor
use workspace to optimize kvcache attention
This commit is contained in:
parent
6a1bfd6c45
commit
afed5d3c3d
@@ -67,16 +67,15 @@ def replace_onnx_with_attention_op():
             tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"],
             tmap["/model/layers." + str(i) + "/self_attn/Transpose_2_output_0"]]
         outputs = [
-            tmap["/model/layers." + str(i) + "/self_attn/MatMul_1_output_0"],
-            tmap[graph.outputs[1+i*2].name],
-            tmap[graph.outputs[2+i*2].name]]
+            tmap["/model/layers." + str(i) + "/self_attn/MatMul_1_output_0"]]

         inputs_added = [graph.inputs[1]]
         outputs_removed = []

         graph.replace_with_attention(
             inputs, outputs, inputs_added, outputs_removed)

+    graph.outputs = [tmap[graph.outputs[0].name]]
     graph.cleanup(True).toposort()
     onnx.save(gs.export_onnx(graph), ONNX_MODEL_PATH, save_as_external_data=True)

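Note: the rewritten graph now registers only the attention MatMul result as the op's output; the per-layer K/V cache tensors are dropped from `outputs` and, via the new `graph.outputs = [...]` line, from the model outputs as well. This matches the operator change below: the caches are presumably updated in place by the kernel, so they no longer need to be materialized as graph outputs.
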
@@ -74,10 +74,9 @@ class GraphHandlerObj {
     Tensor squeeze(Tensor input, Tensor output, Shape axes);
     Tensor unsqueeze(Tensor input, Tensor output, Shape axes);
     Tensor concat(TensorVec inputs, Tensor output, int dim);
-    TensorVec attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
-                               Tensor input_q, Tensor input_k, Tensor input_v,
-                               Tensor position_id, Tensor output_matmul,
-                               Tensor output_k_cache, Tensor output_v_cache);
+    Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
+                            Tensor input_q, Tensor input_k, Tensor input_v,
+                            Tensor position_id, Tensor output_matmul);
     TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
                     std::variant<int, vector<int>> numOrRatio);
     Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);

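With the declaration changed from `TensorVec` to `Tensor`, callers receive just the attention output. A hypothetical call site under the new signature (identifiers here are illustrative, not from this commit):

    // Sketch: the handler returns the single attention output; passing nullptr
    // for output_matmul lets the graph allocate it from the inferred shape.
    Tensor out = handler.attentionKVCache(kCache, vCache, q, k, v, posId,
                                          /*output_matmul=*/nullptr);
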
@@ -1,4 +1,5 @@
 #pragma once
+#include "core/common.h"
 #include <cstdio>

 struct AttentionKVCacheMetadata {

@@ -22,21 +22,18 @@ class AttentionKVCacheObj : public OperatorObj {
      * @param input_v The value input tensor.
      * @param position_id The positon id of the query,
      * @param output_matmul The query output tensor.
-     * @param output_k_cache The output k_cache tensor.
-     * @param output_v_cache The output v_cache tensor.
      */
     AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
                         Tensor input_v_cache, Tensor input_q, Tensor input_k,
                         Tensor input_v, Tensor position_id,
-                        Tensor output_matmul, Tensor output_k_cache,
-                        Tensor output_v_cache);
+                        Tensor output_matmul);
     OP_CLONE(AttentionKVCacheObj);

     optional<vector<Shape>> inferShape(const TensorVec &inputs) override;

     std::string toString() const override;
     int numInputs() const override { return 6; }
-    int numOutputs() const override { return 3; }
+    int numOutputs() const override { return 1; }
     int getDim() const { return dim; }

   private:

@@ -660,7 +660,7 @@ class OnnxStub:
                     next((attr.i for attr in node.attribute if attr.name == "axis")),
                 )
             elif node.op_type == "AttentionKVCache":
-                tensors[node.output[0]], tensors[node.output[1]], tensors[node.output[2]] = self.handler.attentionKVCache(
+                tensors[node.output[0]] = self.handler.attentionKVCache(
                     tensors[node.input[0]],
                     tensors[node.input[1]],
                     tensors[node.input[2]],

@@ -668,8 +668,6 @@ class OnnxStub:
                     tensors[node.input[4]],
                     tensors[node.input[5]],
                     tensors.get(node.output[0]),
-                    tensors.get(node.output[1]),
-                    tensors.get(node.output[2]),
                 )
             elif node.op_type == "Split":
                 split = (

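The frontend stub accordingly unpacks a single return value and forwards only `tensors.get(node.output[0])` as the optional output, matching the one-output `AttentionKVCache` nodes produced by the script above.
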
@@ -324,25 +324,24 @@ Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) {
     }
 }

-TensorVec GraphHandlerObj::attentionKVCache(
-    Tensor input_k_cache, Tensor input_v_cache, Tensor input_q, Tensor input_k,
-    Tensor input_v, Tensor position_id, Tensor output_matmul,
-    Tensor output_k_cache, Tensor output_v_cache) {
-    if (output_matmul && output_k_cache && output_v_cache) {
+Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache,
+                                         Tensor input_v_cache, Tensor input_q,
+                                         Tensor input_k, Tensor input_v,
+                                         Tensor position_id,
+                                         Tensor output_matmul) {
+    if (output_matmul) {
         g->addOpWithOutputs<AttentionKVCacheObj>(
             std::move(input_k_cache), std::move(input_v_cache),
             std::move(input_q), std::move(input_k), std::move(input_v),
-            std::move(position_id), output_matmul, output_k_cache,
-            output_v_cache);
-        return {output_matmul, output_k_cache, output_v_cache};
+            std::move(position_id), output_matmul);
+        return output_matmul;
     } else {
         return g
             ->addOp<AttentionKVCacheObj>(
                 std::move(input_k_cache), std::move(input_v_cache),
                 std::move(input_q), std::move(input_k), std::move(input_v),
-                std::move(position_id), output_matmul, output_k_cache,
-                output_v_cache)
-            ->getOutputs();
+                std::move(position_id), output_matmul)
+            ->getOutput();
     }
 }

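Both branches now yield a single tensor: `addOpWithOutputs` when the caller supplies `output_matmul`, and the `addOp` path otherwise, where the graph allocates the output and the singular `getOutput()` replaces `getOutputs()`.
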
@@ -67,10 +67,8 @@ bool OperatorObj::checkValid(GraphObj *graph) {
     if (graph) { // if graph != nullptr, outputs should be created
         auto dataTypes = inferDataType();
         for (size_t i = 0; i < outputs.size(); i++) {
-            if (!outputs[i])
-                outputs[i] = graph->addTensor(shapes[i], dataTypes[i]);
-            else if (shapes[i] != outputs[i]->getDims())
-                return false;
+            IT_ASSERT(!outputs[i], "Find empty output while operator creation");
+            outputs[i] = graph->addTensor(shapes[i], dataTypes[i]);
         }
     } else { // if outputs have been created, check their shapes
         for (size_t i = 0; i < shapes.size(); ++i) {

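Note the behavioral change here: when an operator is created through a graph, a pre-populated output slot is now a hard assertion failure rather than a tolerated case with a shape check. Every output must be empty so `addTensor` can allocate it from the inferred shape and data type.
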
@@ -21,8 +21,7 @@ class AttentionKVCacheCompute {
   public:
     void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
                     Tensor input_k, Tensor input_v, Tensor position_id,
-                    Tensor output_matmul, Tensor output_temp_O,
-                    Tensor output_temp_sum) const {
+                    Tensor output_matmul, CudaPtr p_workspace) const {
         AttentionKVCacheMetadata metadata;
         initAttentionKVCacheMetadata(metadata, input_v_cache);

@@ -33,9 +32,8 @@ class AttentionKVCacheCompute {
             input_v->getRawDataPtr<float *>(),
             position_id->getRawDataPtr<int *>(),
             output_matmul->getRawDataPtr<float *>(),
-            metadata,
-            output_temp_O->getRawDataPtr<float *>(),
-            output_temp_sum->getRawDataPtr<float *>());
+            metadata, (float *)p_workspace,
+            (float *)(p_workspace + (1ll << 30)));
     }
 };

@@ -44,10 +42,14 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute,
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         IT_ASSERT(_op->getDType() == DataType::Float32);
-        do_compute(
-            _op->getInputs()[0], _op->getInputs()[1], _op->getInputs()[2],
-            _op->getInputs()[3], _op->getInputs()[4], _op->getInputs()[5],
-            _op->getOutputs()[0], _op->getOutputs()[1], _op->getOutputs()[2]);
+        size_t workspaceSize = 2ll << 30;
+        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+        CudaPtr idxWsData = context->getWorkspace(workspaceSize);
+        do_compute(_op->getInputs()[0], _op->getInputs()[1],
+                   _op->getInputs()[2], _op->getInputs()[3],
+                   _op->getInputs()[4], _op->getInputs()[5],
+                   _op->getOutputs()[0], idxWsData);
     }
 };

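This is the core of the commit: the two temporary tensors `output_temp_O` and `output_temp_sum` are replaced by one 2 GiB runtime workspace (`2ll << 30` bytes), with the second scratch region starting at the 1 GiB offset (`1ll << 30`). A minimal sketch of that layout, assuming `CudaPtr` is byte-addressed as the pointer arithmetic above suggests (types and names here are illustrative, not from this commit):

    #include <cstdint>

    using CudaPtr = std::uint8_t *; // stand-in for the runtime's workspace pointer

    struct WorkspaceView {
        float *temp_O;   // was output_temp_O: partial attention outputs
        float *temp_sum; // was output_temp_sum: softmax running sums
    };

    // Split the single 2 GiB workspace in half, mirroring p_workspace and
    // p_workspace + (1ll << 30) in the kernel wrapper above.
    inline WorkspaceView carveWorkspace(CudaPtr p_workspace) {
        return {reinterpret_cast<float *>(p_workspace),
                reinterpret_cast<float *>(p_workspace + (1ll << 30))};
    }

Reusing the runtime workspace avoids allocating and shape-inferring two extra graph tensors per attention op, at the cost of a fixed 2 GiB reservation.
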
@@ -2,14 +2,15 @@
 #include "utils/operator_utils.h"

 namespace infini {
-AttentionKVCacheObj::AttentionKVCacheObj(
-    GraphObj *graph, Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
-    Tensor input_k, Tensor input_v, Tensor position_id, Tensor output_matmul,
-    Tensor output_k_cache, Tensor output_v_cache)
+AttentionKVCacheObj::AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
+                                         Tensor input_v_cache, Tensor input_q,
+                                         Tensor input_k, Tensor input_v,
+                                         Tensor position_id,
+                                         Tensor output_matmul)
     : OperatorObj(OpType::AttentionKVCache,
                   TensorVec{input_k_cache, input_v_cache, input_q, input_k,
                             input_v, position_id},
-                  TensorVec{output_matmul, output_k_cache, output_v_cache}) {
+                  {output_matmul}) {
     int rank = inputs[0]->getRank();
     IT_ASSERT(rank == 4);
     dim = 2;

@@ -22,7 +23,7 @@ AttentionKVCacheObj::inferShape(const TensorVec &inputs) {
     Shape dims = inputs[0]->getDims();
     ShapeElem n = dims.at(dim);
     dims[dim] = n + 1;
-    return {{inputs[2]->getDims(), dims, dims}};
+    return {{inputs[2]->getDims()}};
 }

 std::string AttentionKVCacheObj::toString() const {

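`inferShape` now reports only the query-shaped output. The cache-growth computation (`dims[dim] = n + 1`) is still performed but its result is no longer returned, which presumably leaves it as harmless dead code after this change.
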
@@ -23,7 +23,7 @@ TEST(AttentionKVCache, Cuda) {

     auto op = gCuda->addOp<AttentionKVCacheObj>(
         input_k_cache_d, input_v_cache_d, input_q_d, input_k_d, input_v_d,
-        position_id_d, nullptr, nullptr, nullptr);
+        position_id_d, nullptr);
     gCuda->dataMalloc();

     input_q_d->setData(OneGenerator());
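
In the updated test, passing a single `nullptr` for `output_matmul` lets `addOp` create the output from `inferShape` during validation; the two extra `nullptr` arguments for the cache outputs are gone along with the outputs themselves.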