forked from jiuyuan/InfiniTensor
[feature] add fused attention_kvcache operator support (#179)
* [feature] add fused attention_kvcache operator support
* add test to attention_kvcache op
* Add space line at EOF

Co-authored-by: Haojie Wang <haojie0429@gmail.com>
parent f22fa2766e
commit 965df4e294
@@ -64,6 +64,9 @@ class GraphHandlerObj {
     Tensor transpose(Tensor data, Tensor transposed, Shape perm);
     Tensor reshape(Tensor data, Tensor reshaped, Shape shape);
     Tensor concat(TensorVec inputs, Tensor output, int dim);
+    Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
+                            Tensor input_q, Tensor input_k, Tensor input_v,
+                            Tensor position_id, Tensor output_matmul);
     TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
                     int num_outputs);
     Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);
@@ -25,6 +25,7 @@ struct OpType {
        Asinh,              // Unary
        Atan,               // Unary
        Atanh,              // Unary
+       AttentionKVCache,   // Fusion
        AveragePool,        // Pool
        BatchNormalization, //
        Bernoulli,          //
@@ -0,0 +1,15 @@
+#pragma once
+#include <cstdio>
+
+struct AttentionKVCacheMetadata {
+    int dimSize[4];
+    int stride[4];
+};
+
+namespace infini {
+void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
+                              float *input_q, float *input_k, float *input_v,
+                              int *position_id, float *output_matmul,
+                              const AttentionKVCacheMetadata &compMeta);
+
+} // namespace infini
@@ -0,0 +1,43 @@
+#pragma once
+#include "core/operator.h"
+
+namespace infini {
+/**
+ * @brief Fused Attention with KVCache input operator. All the input and output
+ * tensors should have the same rank except for the position_id.
+ *
+ */
+class AttentionKVCacheObj : public OperatorObj {
+    int dim;
+
+  public:
+    /**
+     * @brief Construct a new AttentionKVCache object.
+     *
+     * @param graph The computation graph that this operator belongs to.
+     * @param input_k_cache The k_cache input tensor.
+     * @param input_v_cache The v_cache input tensor.
+     * @param input_q The query input tensor.
+     * @param input_k The key input tensor.
+     * @param input_v The value input tensor.
+     * @param position_id The position id of the query.
+     * @param output_matmul The attention output tensor.
+     */
+    AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
+                        Tensor input_v_cache, Tensor input_q, Tensor input_k,
+                        Tensor input_v, Tensor position_id,
+                        Tensor output_matmul);
+    OP_CLONE(AttentionKVCacheObj);
+
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
+
+    std::string toString() const override;
+    int numInputs() const override { return 6; }
+    int numOutputs() const override { return 1; }
+    int getDim() const { return dim; }
+
+  private:
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+};
+} // namespace infini
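The header does not spell out the math. Judging from the CUDA kernel added later in this commit, the fused operator appends the new key/value rows to the caches at the current position and computes scaled dot-product attention over the cached sequence. Roughly, with notation of my own (the kernel fixes the head dimension d = 64, hence its scale factor 1/8):

\[
K' = [\,K_{\text{cache}};\ k\,], \qquad V' = [\,V_{\text{cache}};\ v\,], \qquad
o = \operatorname{softmax}\!\Bigl(\tfrac{1}{\sqrt{d}}\, q\, K'^{\top}\Bigr) V', \qquad d = 64.
\]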
@@ -46,6 +46,9 @@ class OnnxStub:
                 model = model_simp
         except ValidationError:
             pass
+        except RuntimeError:
+            pass
+
         self.inputs: Dict[str, backend.Tensor] = {}
         self.outputs: Dict[str, backend.Tensor] = {}
         self.initializer: Dict[int, TensorProto] = {}
@@ -560,6 +563,16 @@ class OnnxStub:
                         (attr.i for attr in node.attribute if attr.name == "axis")
                     ),
                 )
+            elif node.op_type == "AttentionKVCache":
+                tensors[node.output[0]] = self.handler.attentionKVCache(
+                    tensors[node.input[0]],
+                    tensors[node.input[1]],
+                    tensors[node.input[2]],
+                    tensors[node.input[3]],
+                    tensors[node.input[4]],
+                    tensors[node.input[5]],
+                    tensors.get(node.output[0]),
+                )
             elif node.op_type == "Split":
                 for name, tensor in zip(
                     node.output,
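AttentionKVCache is not a standard ONNX operator, so a model has to carry it as a custom node whose six inputs arrive in exactly the order the importer above indexes them (k_cache, v_cache, q, k, v, position_id). A minimal sketch of building such a node with onnx.helper; the tensor names are made up for illustration:

import onnx.helper as helper

# Hypothetical node; the input order must match the importer above:
# k_cache, v_cache, q, k, v, position_id -> one attention output.
node = helper.make_node(
    "AttentionKVCache",
    inputs=["k_cache", "v_cache", "q", "k", "v", "position_id"],
    outputs=["attn_output"],
)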
@@ -1,6 +1,7 @@
 #include "core/graph_handler.h"
 #include "operators/all_gather.h"
 #include "operators/all_reduce.h"
+#include "operators/attention_kvcache.h"
 #include "operators/batch_norm.h"
 #include "operators/broadcast.h"
 #include "operators/concat.h"
@@ -239,6 +240,27 @@ Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) {
     }
 }
 
+Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache,
+                                         Tensor input_v_cache, Tensor input_q,
+                                         Tensor input_k, Tensor input_v,
+                                         Tensor position_id,
+                                         Tensor output_matmul) {
+    if (output_matmul) {
+        g->addOpWithOutputs<AttentionKVCacheObj>(
+            std::move(input_k_cache), std::move(input_v_cache),
+            std::move(input_q), std::move(input_k), std::move(input_v),
+            std::move(position_id), output_matmul);
+        return {output_matmul};
+    } else {
+        return g
+            ->addOp<AttentionKVCacheObj>(
+                std::move(input_k_cache), std::move(input_v_cache),
+                std::move(input_q), std::move(input_k), std::move(input_v),
+                std::move(position_id), output_matmul)
+            ->getOutput();
+    }
+}
+
 TensorVec GraphHandlerObj::split(Tensor input, std::optional<TensorVec> outputs,
                                  int axis, int num_outputs) {
     if (outputs) {
@@ -489,6 +489,7 @@ void init_graph_builder(py::module &m) {
         .def("depthToSpace", &Handler::depthToSpace, policy::move)
         .def("reshape", &Handler::reshape, policy::move)
         .def("concat", &Handler::concat, policy::move)
+        .def("attentionKVCache", &Handler::attentionKVCache, policy::move)
         .def("split", &Handler::split, policy::move)
         .def("gather", &Handler::gather, policy::move)
         .def("gatherElements", &Handler::gatherElements, policy::move)
@@ -0,0 +1,52 @@
+#include "operators/attention_kvcache.h"
+#include "cuda/cuda_attention_kvcache.h"
+#include "cuda/cuda_kernel_wihtout_config.h"
+#include <functional>
+
+namespace infini {
+
+class AttentionKVCacheCompute {
+    void initAttentionKVCacheMetadata(AttentionKVCacheMetadata &metadata,
+                                      Tensor tensor) const {
+        int nDims = tensor->getRank();
+        auto strides = tensor->getStride();
+        IT_ASSERT(nDims == 4);
+        IT_ASSERT(strides.size() == (size_t)nDims);
+        for (int i = 0; i < nDims; ++i) {
+            metadata.dimSize[i] = tensor->getDims().at(i);
+            metadata.stride[i] = strides.at(i);
+        }
+    }
+
+  public:
+    void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
+                    Tensor input_k, Tensor input_v, Tensor position_id,
+                    Tensor output_matmul) const {
+        AttentionKVCacheMetadata metadata;
+        initAttentionKVCacheMetadata(metadata, input_v_cache);
+
+        attention_kvcache_kernel(input_k_cache->getRawDataPtr<float *>(),
+                                 input_v_cache->getRawDataPtr<float *>(),
+                                 input_q->getRawDataPtr<float *>(),
+                                 input_k->getRawDataPtr<float *>(),
+                                 input_v->getRawDataPtr<float *>(),
+                                 position_id->getRawDataPtr<int *>(),
+                                 output_matmul->getRawDataPtr<float *>(),
+                                 metadata);
+    }
+};
+
+class AttentionKVCacheCuda : private AttentionKVCacheCompute,
+                             public CudaKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        do_compute(_op->getInputs()[0], _op->getInputs()[1],
+                   _op->getInputs()[2], _op->getInputs()[3],
+                   _op->getInputs()[4], _op->getInputs()[5],
+                   _op->getOutputs()[0]);
+    }
+};
+
+REGISTER_KERNEL(Device::CUDA, OpType::AttentionKVCache, DataType::Float32,
+                AttentionKVCacheCuda, "AttentionKVCache_CUDA_Float32");
+} // namespace infini
@@ -0,0 +1,128 @@
+#include "cuda/cuda_common.h"
+#include "cuda/cuda_attention_kvcache.h"
+#define WARP_SIZE 32
+#define BLOCKSIZE WARP_SIZE
+#define SEQ_UNIT 64
+
+__global__ void _attention_kvcache_kernel(float* input_k_cache,
+                                          float* input_v_cache,
+                                          float* input_q,
+                                          float* input_k,
+                                          float* input_v,
+                                          int* position_id,
+                                          float* output_matmul,
+                                          AttentionKVCacheMetadata compMeta) {
+    int lane_id = threadIdx.x % WARP_SIZE;
+    int group_id = threadIdx.x / WARP_SIZE;
+    int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
+
+    if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
+        return;
+
+    float ptr_V[SEQ_UNIT*2];
+    float ptr_K[SEQ_UNIT*2];
+    float ptr_Q[2];
+    float ptr_P[SEQ_UNIT];
+
+    float ptr_O[2];
+    float ptr_max[1];
+    float ptr_sum[1];
+
+    float ptr_max_last[1];
+    float ptr_sum_last[1];
+    float ptr_O_last[2];
+
+    (float2 &)ptr_Q[0] = (float2 &)input_q[(lane_id * 2) + (parallel_idx * 64)];
+
+    int SEQ_LENGTH = position_id[0] + 1;
+
+    int common_idx = (lane_id * 2) + (parallel_idx * compMeta.stride[1]);
+
+    for (int idx_seq = 0; idx_seq < SEQ_LENGTH; idx_seq += SEQ_UNIT){
+        ptr_max_last[0] = ptr_max[0];
+        ptr_sum_last[0] = ptr_sum[0];
+        (float2 &)ptr_O_last[0] = (float2 &)ptr_O[0];
+
+        #pragma unroll
+        for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
+            if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
+                (float2 &)ptr_K[idx_SEQ_UNIT * 2]
+                    = (float2 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
+            }
+            else{
+                (float2 &)ptr_K[idx_SEQ_UNIT * 2]
+                    = (float2 &) input_k[((lane_id * 2) + parallel_idx * compMeta.stride[2])];
+                (float2 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
+                    (float2 &)ptr_K[idx_SEQ_UNIT * 2];
+            }
+            ptr_K[idx_SEQ_UNIT * 2] = ptr_Q[0] * ptr_K[idx_SEQ_UNIT * 2];
+            ptr_K[idx_SEQ_UNIT * 2 + 1] = ptr_Q[1] * ptr_K[idx_SEQ_UNIT * 2 + 1];
+
+            #pragma unroll
+            for (int offset = 16; offset > 0; offset /= 2) {
+                ptr_K[idx_SEQ_UNIT * 2] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 2], offset);
+            }
+            ptr_P[idx_SEQ_UNIT] = ptr_K[idx_SEQ_UNIT * 2];
+            #pragma unroll
+            for (int offset = 16; offset > 0; offset /= 2){
+                ptr_K[((idx_SEQ_UNIT * 2) + 1)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 2) + 1)], offset);
+            }
+            ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 2) + 1)];
+        }
+
+        #pragma unroll
+        for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
+            ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
+            ptr_P[idx_SEQ_UNIT] /= 8;
+            ptr_max[0] = (idx_SEQ_UNIT == 0) ? ptr_P[0] : max(ptr_max[0], ptr_P[idx_SEQ_UNIT]);
+        }
+        ptr_max[0] = (idx_seq == 0) ? ptr_max[0] : max(ptr_max[0], ptr_max_last[0]);
+
+        ptr_sum[0] = 0;
+        #pragma unroll
+        for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
+            ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT] - ptr_max[0]);
+            ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
+        }
+        ptr_sum[0] = (idx_seq == 0) ? ptr_sum[0] : expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] + ptr_sum[0];
+
+        ptr_O[0] = 0;
+        ptr_O[1] = 0;
+        #pragma unroll
+        for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
+            if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
+                (float2 &)ptr_V[idx_SEQ_UNIT * 2]
+                    = (float2 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
+            }
+            else{
+                (float2 &)ptr_V[idx_SEQ_UNIT * 2]
+                    = (float2 &) input_v[((lane_id * 2) + parallel_idx * compMeta.stride[2])];
+                (float2 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
+                    (float2 &)ptr_V[idx_SEQ_UNIT * 2];
+            }
+
+            ptr_P[idx_SEQ_UNIT] /= ptr_sum[0];
+
+            ptr_O[0] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2)], ptr_O[0]);
+            ptr_O[1] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2) + 1], ptr_O[1]);
+        }
+        ptr_O[0] = (idx_seq == 0) ? ptr_O[0] : ptr_O[0] + ptr_O_last[0] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
+        ptr_O[1] = (idx_seq == 0) ? ptr_O[1] : ptr_O[1] + ptr_O_last[1] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
+    }
+    (float2 &)output_matmul[(lane_id * 2) + (parallel_idx * compMeta.dimSize[3])] = (float2 &)ptr_O[0];
+}
+
+namespace infini {
+void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache, float *input_q, float *input_k,
+                              float *input_v, int *position_id, float *output_matmul,
+                              const AttentionKVCacheMetadata &compMeta) {
+    IT_ASSERT(compMeta.dimSize[3] == 64);
+    dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), 1);
+    dim3 blockDim(BLOCKSIZE, 1);
+
+    _attention_kvcache_kernel<<<gridDim, blockDim>>>(
+        input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, output_matmul, compMeta);
+}
+
+} // namespace infini
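A reading note on the main loop above, in my own notation rather than anything stated in the source: each warp handles one (batch, head) slice, streams the cached keys/values in chunks of SEQ_UNIT timesteps, and keeps a running maximum m, a running normalizer s, and a normalized partial output o, rescaling the previous partials whenever the maximum grows (the standard online-softmax recurrence). Writing p_j = q·k_j / 8 for the scores of chunk i:

\[
m_i = \max\bigl(m_{i-1}, \max_j p_j\bigr), \qquad
s_i = e^{\,m_{i-1}-m_i}\, s_{i-1} + \sum_j e^{\,p_j - m_i}, \qquad
o_i = \frac{e^{\,m_{i-1}-m_i}\, s_{i-1}}{s_i}\, o_{i-1} + \sum_j \frac{e^{\,p_j - m_i}}{s_i}\, v_j .
\]

The divisor 8 corresponds to 1/sqrt(64), matching the head dimension the launcher asserts via compMeta.dimSize[3] == 64.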
@@ -0,0 +1,55 @@
+#include "operators/attention_kvcache.h"
+#include "utils/operator_utils.h"
+
+namespace infini {
+AttentionKVCacheObj::AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
+                                         Tensor input_v_cache, Tensor input_q,
+                                         Tensor input_k, Tensor input_v,
+                                         Tensor position_id,
+                                         Tensor output_matmul)
+    : OperatorObj(OpType::AttentionKVCache,
+                  TensorVec{input_k_cache, input_v_cache, input_q, input_k,
+                            input_v, position_id},
+                  {output_matmul}) {
+    int rank = inputs[0]->getRank();
+    IT_ASSERT(rank == 4);
+    dim = 2;
+    IT_ASSERT(checkValid(graph));
+}
+
+optional<vector<Shape>>
+AttentionKVCacheObj::inferShape(const TensorVec &inputs) const {
+    IT_ASSERT(inputs.size() == 6);
+    Shape dims = inputs[0]->getDims();
+    ShapeElem n = dims.at(dim);
+    dims[dim] = n + 1;
+    return {{inputs[2]->getDims()}};
+}
+
+std::string AttentionKVCacheObj::toString() const {
+    std::ostringstream os;
+    os << "AttentionKVCache[" << getGuid() << "]";
+    os << "(";
+    for (auto input : inputs)
+        os << vecToString(input->getDims()) << ",";
+    os << "dim=" << dim << ",";
+    os << "input=";
+    for (auto input : inputs)
+        os << input->getGuid() << ",";
+    os << "output=" << outputs[0]->getGuid() << ")";
+    return os.str();
+}
+
+vector<int> AttentionKVCacheObj::getWorkloadVector() const {
+    vector<int> ret = getOutputs()[0]->getDims();
+    ret.emplace(ret.begin(), (int)inputs.size());
+    ret.emplace(ret.begin(), dim);
+    ret.emplace(ret.begin(), type.underlying());
+    return ret;
+}
+
+vector<int> AttentionKVCacheObj::getOpAttrVector() const {
+    return {type.underlying(), dim};
+}
+
+} // namespace infini
@@ -0,0 +1,42 @@
+#include "core/graph.h"
+#include "core/runtime.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/cuda_utility.h"
+#include "operators/attention_kvcache.h"
+
+#include "test.h"
+
+namespace infini {
+TEST(AttentionKVCache, Cuda) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+
+    Graph gCpu = make_ref<GraphObj>(runtime);
+
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cudaRuntime);
+    auto input_k_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
+    auto input_v_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
+    auto input_q_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
+    auto input_k_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
+    auto input_v_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
+    auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32);
+
+    auto op = gCuda->addOp<AttentionKVCacheObj>(
+        input_k_cache_d, input_v_cache_d, input_q_d, input_k_d, input_v_d,
+        position_id_d, nullptr);
+    gCuda->dataMalloc();
+
+    input_q_d->setData(OneGenerator());
+    input_k_d->setData(OneGenerator());
+    input_v_d->setData(OneGenerator());
+    position_id_d->setData(IncrementalGenerator());
+    cudaRuntime->run(gCuda);
+
+    auto oCpu = gCpu->cloneTensor(op->getOutput());
+    EXPECT_TRUE(oCpu->equalData(vector<float>{
+        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}));
+}
+
+} // namespace infini
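As a sanity check on the expected data: IncrementalGenerator() fills the single-element position_id with 0, so the kernel sees a sequence length of 1 and attends to exactly one timestep. The softmax weight over a single score is 1, so the output reproduces the all-ones value vector, which is why the test expects 64 ones.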