[feature] support kvcache with static graph

This commit is contained in:
xiaonans 2024-01-17 11:26:05 +08:00
parent 51086d2b8d
commit 6a1bfd6c45
11 changed files with 335 additions and 133 deletions

View File

@ -0,0 +1,146 @@
import os
from pyinfinitensor.onnx import OnnxStub, backend
import numpy as np
import onnx
import torch
from transformers import LlamaModel, LlamaForCausalLM
from tqdm import tqdm
import onnx_graphsurgeon as gs
from onnxsim import simplify
import argparse
parser = argparse.ArgumentParser(description='')
parser.add_argument('--batchsize', dest='batchsize', type=int, default=1)
parser.add_argument('--layer', dest='n_layers', type=int, default=2)
parser.add_argument('--iter', dest='n_iter', type=int, default=1)
parser.add_argument('--n_max_length', dest='n_max_length', type=int, default=1024)
parser.add_argument('--pretrained_llama_path', dest='pretrained_llama_path', type=str,
default="/data0/shared/data/public/opensource_models/meta-llama/Llama-2-7b-hf/")
parser.add_argument('--onnx_model_path', dest='onnx_model_path', type=str,
default="/data1/shared/llama")
args = parser.parse_args()
ONNX_MODEL_PATH = "{}/llama_bs{}_layer{}.onnx".format(args.onnx_model_path, args.batchsize, args.n_layers)
ONNX_WEIGHT_PATH = "./llama_bs{}_layer{}.pb".format(args.batchsize, args.n_layers)
def export_onnx(model: LlamaModel, ONNX_MODEL_PATH):
param = torch.zeros(
(args.batchsize, 1024), dtype=torch.long)
logits = model(param, past_key_values=None)
param_kvcache = torch.zeros((args.batchsize, 1), dtype=torch.long)
torch.onnx.export(model, (param_kvcache, {"past_key_values": logits.past_key_values,
"position_ids": param_kvcache}), ONNX_MODEL_PATH, verbose=False,
do_constant_folding=True,)
onnx_model = onnx.load(ONNX_MODEL_PATH)
print("simplifing onnx model")
onnx_model, check = simplify(onnx_model, skipped_optimizers=[
'eliminate_duplicate_initializer'])
assert check
onnx.save(onnx_model, ONNX_MODEL_PATH, save_as_external_data=True, location=ONNX_WEIGHT_PATH)
print("simlifing finished.")
@gs.Graph.register()
def replace_with_attention(self, inputs, outputs, inputs_added, outputs_removed):
for inp in inputs:
inp.outputs.clear()
for out in outputs:
out.inputs.clear()
for inp in inputs_added:
inputs.append(inp)
for out in outputs_removed:
out.inputs.clear()
return self.layer(op="AttentionKVCache", inputs=inputs, outputs=outputs)
def replace_onnx_with_attention_op():
graph = gs.import_onnx(
onnx.load(ONNX_MODEL_PATH))
tmap = graph.tensors()
for i in range(args.n_layers):
inputs = [
tmap["onnx::Concat_" + str((i+1)*2)],
tmap["onnx::Concat_" + str((i+1)*2+1)],
tmap["/model/layers." + str(i) + "/self_attn/Add_output_0"],
tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"],
tmap["/model/layers." + str(i) + "/self_attn/Transpose_2_output_0"]]
outputs = [
tmap["/model/layers." + str(i) + "/self_attn/MatMul_1_output_0"],
tmap[graph.outputs[1+i*2].name],
tmap[graph.outputs[2+i*2].name]]
inputs_added = [graph.inputs[1]]
outputs_removed = []
graph.replace_with_attention(
inputs, outputs, inputs_added, outputs_removed)
graph.cleanup(True).toposort()
onnx.save(gs.export_onnx(graph), ONNX_MODEL_PATH, save_as_external_data=True)
if __name__ == "__main__":
kvcache_torch = None
torch_model = LlamaForCausalLM.from_pretrained(
args.pretrained_llama_path, num_hidden_layers=int(args.n_layers)).eval()
n_heads = torch_model.config.num_attention_heads
n_dims = torch_model.config.hidden_size // n_heads
if not os.path.exists(ONNX_MODEL_PATH):
print("exporting onnx graph")
export_onnx(torch_model, ONNX_MODEL_PATH)
replace_onnx_with_attention_op()
else:
print("will use exsiting onnx graph")
onnx_model = onnx.load(ONNX_MODEL_PATH)
stub = OnnxStub(onnx_model, backend.cuda_runtime())
count_wrong = 0
for i in tqdm(range(0, args.n_max_length)):
query = np.random.randint(
torch_model.config.vocab_size, size=(args.batchsize, 1), dtype=np.int32)
position_id = i*np.ones((args.batchsize, 1), dtype=np.int32)
####################################
# pytorch
####################################
outputs_torch = torch_model(
torch.tensor(query), past_key_values=kvcache_torch)
logit_torch = outputs_torch['logits']
kvcache_torch = outputs_torch['past_key_values']
####################################
# infinitensor
####################################
# copyin input
(list(stub.inputs.items()))[0][1].copyin_int64(
query.reshape(-1).tolist())
(list(stub.inputs.items()))[1][1].copyin_int64(
position_id.reshape(-1).tolist())
stub.run()
####################################
# validation
####################################
# copyout output
logits_it = np.array((list(stub.outputs.items()))
[0][1].copyout_float())
try:
np.testing.assert_allclose(
logit_torch[:, -1, :].detach().cpu().numpy().flatten(), logits_it, rtol=1e-3, atol=1e-3)
except Exception as e:
try:
np.testing.assert_allclose(
np.argmax(logit_torch[:, -1, :].detach().cpu().numpy().flatten()), np.argmax(logits_it), rtol=1e-3, atol=1e-3)
except:
count_wrong = count_wrong + 1
result = "{}/{} failed.".format(count_wrong, args.n_max_length)
print(result)
del stub

View File

@ -74,9 +74,10 @@ class GraphHandlerObj {
Tensor squeeze(Tensor input, Tensor output, Shape axes); Tensor squeeze(Tensor input, Tensor output, Shape axes);
Tensor unsqueeze(Tensor input, Tensor output, Shape axes); Tensor unsqueeze(Tensor input, Tensor output, Shape axes);
Tensor concat(TensorVec inputs, Tensor output, int dim); Tensor concat(TensorVec inputs, Tensor output, int dim);
Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache, TensorVec attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
Tensor input_q, Tensor input_k, Tensor input_v, Tensor input_q, Tensor input_k, Tensor input_v,
Tensor position_id, Tensor output_matmul); Tensor position_id, Tensor output_matmul,
Tensor output_k_cache, Tensor output_v_cache);
TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis, TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
std::variant<int, vector<int>> numOrRatio); std::variant<int, vector<int>> numOrRatio);
Tensor gather(Tensor data, Tensor indices, Tensor output, int axis); Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);

View File

@ -10,6 +10,7 @@ namespace infini {
void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache, void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
float *input_q, float *input_k, float *input_v, float *input_q, float *input_k, float *input_v,
int *position_id, float *output_matmul, int *position_id, float *output_matmul,
const AttentionKVCacheMetadata &compMeta); const AttentionKVCacheMetadata &compMeta,
float *output_O_temp, float *output_sum_temp);
} // namespace infini } // namespace infini

View File

@ -22,18 +22,21 @@ class AttentionKVCacheObj : public OperatorObj {
* @param input_v The value input tensor. * @param input_v The value input tensor.
* @param position_id The positon id of the query, * @param position_id The positon id of the query,
* @param output_matmul The query output tensor. * @param output_matmul The query output tensor.
* @param output_k_cache The output k_cache tensor.
* @param output_v_cache The output v_cache tensor.
*/ */
AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache, AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
Tensor input_v_cache, Tensor input_q, Tensor input_k, Tensor input_v_cache, Tensor input_q, Tensor input_k,
Tensor input_v, Tensor position_id, Tensor input_v, Tensor position_id,
Tensor output_matmul); Tensor output_matmul, Tensor output_k_cache,
Tensor output_v_cache);
OP_CLONE(AttentionKVCacheObj); OP_CLONE(AttentionKVCacheObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 6; } int numInputs() const override { return 6; }
int numOutputs() const override { return 1; } int numOutputs() const override { return 3; }
int getDim() const { return dim; } int getDim() const { return dim; }
private: private:

View File

@ -660,7 +660,7 @@ class OnnxStub:
next((attr.i for attr in node.attribute if attr.name == "axis")), next((attr.i for attr in node.attribute if attr.name == "axis")),
) )
elif node.op_type == "AttentionKVCache": elif node.op_type == "AttentionKVCache":
tensors[node.output[0]] = self.handler.attentionKVCache( tensors[node.output[0]], tensors[node.output[1]], tensors[node.output[2]] = self.handler.attentionKVCache(
tensors[node.input[0]], tensors[node.input[0]],
tensors[node.input[1]], tensors[node.input[1]],
tensors[node.input[2]], tensors[node.input[2]],
@ -668,6 +668,8 @@ class OnnxStub:
tensors[node.input[4]], tensors[node.input[4]],
tensors[node.input[5]], tensors[node.input[5]],
tensors.get(node.output[0]), tensors.get(node.output[0]),
tensors.get(node.output[1]),
tensors.get(node.output[2]),
) )
elif node.op_type == "Split": elif node.op_type == "Split":
split = ( split = (

View File

@ -324,24 +324,25 @@ Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) {
} }
} }
Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache, TensorVec GraphHandlerObj::attentionKVCache(
Tensor input_v_cache, Tensor input_q, Tensor input_k_cache, Tensor input_v_cache, Tensor input_q, Tensor input_k,
Tensor input_k, Tensor input_v, Tensor input_v, Tensor position_id, Tensor output_matmul,
Tensor position_id, Tensor output_k_cache, Tensor output_v_cache) {
Tensor output_matmul) { if (output_matmul && output_k_cache && output_v_cache) {
if (output_matmul) {
g->addOpWithOutputs<AttentionKVCacheObj>( g->addOpWithOutputs<AttentionKVCacheObj>(
std::move(input_k_cache), std::move(input_v_cache), std::move(input_k_cache), std::move(input_v_cache),
std::move(input_q), std::move(input_k), std::move(input_v), std::move(input_q), std::move(input_k), std::move(input_v),
std::move(position_id), output_matmul); std::move(position_id), output_matmul, output_k_cache,
return {output_matmul}; output_v_cache);
return {output_matmul, output_k_cache, output_v_cache};
} else { } else {
return g return g
->addOp<AttentionKVCacheObj>( ->addOp<AttentionKVCacheObj>(
std::move(input_k_cache), std::move(input_v_cache), std::move(input_k_cache), std::move(input_v_cache),
std::move(input_q), std::move(input_k), std::move(input_v), std::move(input_q), std::move(input_k), std::move(input_v),
std::move(position_id), output_matmul) std::move(position_id), output_matmul, output_k_cache,
->getOutput(); output_v_cache)
->getOutputs();
} }
} }

View File

@ -67,8 +67,10 @@ bool OperatorObj::checkValid(GraphObj *graph) {
if (graph) { // if graph != nullptr, outputs should be created if (graph) { // if graph != nullptr, outputs should be created
auto dataTypes = inferDataType(); auto dataTypes = inferDataType();
for (size_t i = 0; i < outputs.size(); i++) { for (size_t i = 0; i < outputs.size(); i++) {
IT_ASSERT(!outputs[i], "Find empty output while operator creation"); if (!outputs[i])
outputs[i] = graph->addTensor(shapes[i], dataTypes[i]); outputs[i] = graph->addTensor(shapes[i], dataTypes[i]);
else if (shapes[i] != outputs[i]->getDims())
return false;
} }
} else { // if outputs have been created, check their shapes } else { // if outputs have been created, check their shapes
for (size_t i = 0; i < shapes.size(); ++i) { for (size_t i = 0; i < shapes.size(); ++i) {

View File

@ -21,7 +21,8 @@ class AttentionKVCacheCompute {
public: public:
void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q, void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
Tensor input_k, Tensor input_v, Tensor position_id, Tensor input_k, Tensor input_v, Tensor position_id,
Tensor output_matmul) const { Tensor output_matmul, Tensor output_temp_O,
Tensor output_temp_sum) const {
AttentionKVCacheMetadata metadata; AttentionKVCacheMetadata metadata;
initAttentionKVCacheMetadata(metadata, input_v_cache); initAttentionKVCacheMetadata(metadata, input_v_cache);
@ -32,7 +33,9 @@ class AttentionKVCacheCompute {
input_v->getRawDataPtr<float *>(), input_v->getRawDataPtr<float *>(),
position_id->getRawDataPtr<int *>(), position_id->getRawDataPtr<int *>(),
output_matmul->getRawDataPtr<float *>(), output_matmul->getRawDataPtr<float *>(),
metadata); metadata,
output_temp_O->getRawDataPtr<float *>(),
output_temp_sum->getRawDataPtr<float *>());
} }
}; };
@ -41,10 +44,10 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute,
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
IT_ASSERT(_op->getDType() == DataType::Float32); IT_ASSERT(_op->getDType() == DataType::Float32);
do_compute(_op->getInputs()[0], _op->getInputs()[1], do_compute(
_op->getInputs()[2], _op->getInputs()[3], _op->getInputs()[0], _op->getInputs()[1], _op->getInputs()[2],
_op->getInputs()[4], _op->getInputs()[5], _op->getInputs()[3], _op->getInputs()[4], _op->getInputs()[5],
_op->getOutputs()[0]); _op->getOutputs()[0], _op->getOutputs()[1], _op->getOutputs()[2]);
} }
}; };

View File

@ -2,127 +2,168 @@
#include "cuda/cuda_attention_kvcache.h" #include "cuda/cuda_attention_kvcache.h"
#define WARP_SIZE 32 #define WARP_SIZE 32
#define BLOCKSIZE WARP_SIZE #define BLOCKSIZE WARP_SIZE
#define SEQ_UNIT 64 #define SEQ_UNIT 32
__global__ void _attention_kvcache_kernel(float* input_k_cache, // ASSUME SEQ_LEN OF Q IS 1
__global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
float* input_v_cache, float* input_v_cache,
float* input_q, float* input_q,
float* input_k, float* input_k,
float* input_v, float* input_v,
int* position_id, int* position_id,
float* output_matmul, AttentionKVCacheMetadata compMeta,
AttentionKVCacheMetadata compMeta) { float* output_O_temp,
float* output_sum_temp) {
int seq_length = position_id[0] + 1;
int stride = (seq_length + SEQ_UNIT - 1) / SEQ_UNIT;
if(blockIdx.y >= stride)
return;
int lane_id = threadIdx.x % WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE;
int group_id = threadIdx.x / WARP_SIZE; int group_id = threadIdx.x / WARP_SIZE;
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id; int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
int idx_seq = blockIdx.y * SEQ_UNIT;
if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1]) if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
return; return;
float ptr_V[SEQ_UNIT*2]; float ptr_V[SEQ_UNIT*4]; // V
float ptr_K[SEQ_UNIT*2]; float ptr_K[SEQ_UNIT*4]; // K
float ptr_Q[2]; float ptr_Q[4]; // Q
float ptr_P[SEQ_UNIT]; float ptr_P[SEQ_UNIT] = {0};
float ptr_O[2]; float ptr_O[4] = {0};
float ptr_max[1]; float ptr_sum[1] = {0};
float ptr_sum[1];
float ptr_max_last[1]; // readin Q
float ptr_sum_last[1]; (float4 &)ptr_Q[0] = (float4 &)input_q[(lane_id * 4) + (parallel_idx * 128)];
float ptr_O_last[2]; int common_idx = (lane_id * 4) + (parallel_idx * compMeta.stride[1]);
(float2 &)ptr_Q[0] = (float2 &)input_q[(lane_id * 2) + (parallel_idx * 64)];
int SEQ_LENGTH = position_id[0] + 1;
int common_idx = (lane_id * 2) + (parallel_idx * compMeta.stride[1]);
for (int idx_seq = 0; idx_seq < SEQ_LENGTH; idx_seq += SEQ_UNIT){
ptr_max_last[0] = ptr_max[0];
ptr_sum_last[0] = ptr_sum[0];
(float2 &)ptr_O_last[0] = (float2 &)ptr_O[0];
// Q*K
#pragma unroll #pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) { for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){ if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
(float2 &)ptr_K[idx_SEQ_UNIT * 2] (float4 &)ptr_K[idx_SEQ_UNIT * 4]
= (float2 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]; = (float4 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
} }
else{ else{
(float2 &)ptr_K[idx_SEQ_UNIT * 2] (float4 &)ptr_K[idx_SEQ_UNIT * 4]
= (float2 &) input_k[((lane_id * 2) + parallel_idx * compMeta.stride[2])]; = (float4 &) input_k[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
(float2 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] = (float4 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
(float2 &)ptr_K[idx_SEQ_UNIT * 2]; (float4 &)ptr_K[idx_SEQ_UNIT * 4];
}
ptr_K[idx_SEQ_UNIT * 2] = ptr_Q[0] * ptr_K[idx_SEQ_UNIT * 2];
ptr_K[idx_SEQ_UNIT * 2 + 1] = ptr_Q[1] * ptr_K[idx_SEQ_UNIT * 2 + 1];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
ptr_K[idx_SEQ_UNIT * 2] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 2], offset);
}
ptr_P[idx_SEQ_UNIT] = ptr_K[idx_SEQ_UNIT * 2];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2){
ptr_K[((idx_SEQ_UNIT * 2) + 1)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 2) + 1)], offset);
}
ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 2) + 1)];
} }
#pragma unroll #pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) { for (int i = 0; i < 4; i ++){
ptr_K[idx_SEQ_UNIT * 4 + i] = ptr_Q[i] * ptr_K[idx_SEQ_UNIT * 4 + i];
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
ptr_K[idx_SEQ_UNIT * 4 + i] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 4 + i], offset);
}
ptr_P[idx_SEQ_UNIT] += ptr_K[idx_SEQ_UNIT * 4 + i];
}
}
// div sqrt(d)
#pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0); ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
ptr_P[idx_SEQ_UNIT] /= 8; ptr_P[idx_SEQ_UNIT] /= sqrt(128.0);
ptr_max[0] = (idx_SEQ_UNIT == 0) ? ptr_P[0] : max(ptr_max[0], ptr_P[idx_SEQ_UNIT]);
} }
ptr_max[0] = (idx_seq == 0) ? ptr_max[0] : max(ptr_max[0], ptr_max_last[0]);
ptr_sum[0] = 0; // softmax
#pragma unroll #pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) { for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT] - ptr_max[0]); ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT]);
ptr_sum[0] += ptr_P[idx_SEQ_UNIT]; ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
} }
ptr_sum[0] = (idx_seq == 0) ? ptr_sum[0] : expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] + ptr_sum[0];
ptr_O[0] = 0; // * V
ptr_O[1] = 0;
#pragma unroll #pragma unroll
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) { for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){ if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
(float2 &)ptr_V[idx_SEQ_UNIT * 2] (float4 &)ptr_V[idx_SEQ_UNIT * 4]
= (float2 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]; = (float4 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
} }
else{ else{
(float2 &)ptr_V[idx_SEQ_UNIT * 2] (float4 &)ptr_V[idx_SEQ_UNIT * 4]
= (float2 &) input_v[((lane_id * 2) + parallel_idx * compMeta.stride[2])]; = (float4 &) input_v[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
(float2 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] = (float4 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]
(float2 &)ptr_V[idx_SEQ_UNIT * 2]; = (float4 &)ptr_V[idx_SEQ_UNIT * 4];
} }
ptr_P[idx_SEQ_UNIT] /= ptr_sum[0]; #pragma unroll
for (int i = 0; i < 4; i ++)
ptr_O[i] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4 + i)], ptr_O[i]);
}
ptr_O[0] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2)], ptr_O[0]); #pragma unroll
ptr_O[1] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2) + 1], ptr_O[1]); for (int i = 0; i < 4; i ++)
ptr_O[i] /= ptr_sum[0];
(float4 &)output_O_temp[(lane_id * 4) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = (float4 &)ptr_O[0];
if(threadIdx.x == 0){
output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0];
} }
ptr_O[0] = (idx_seq == 0) ? ptr_O[0] : ptr_O[0] + ptr_O_last[0] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
ptr_O[1] = (idx_seq == 0) ? ptr_O[1] : ptr_O[1] + ptr_O_last[1] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
} }
(float2 &)output_matmul[(lane_id * 2) + (parallel_idx * compMeta.dimSize[3])] = (float2 &)ptr_O[0];
__global__ void _attention_kvcache_kernel_128_2(int* position_id,
float* output_matmul,
AttentionKVCacheMetadata compMeta,
float* output_O_temp,
float* output_sum_temp) {
int lane_id = threadIdx.x % WARP_SIZE;
int group_id = threadIdx.x / WARP_SIZE;
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
float ptr_O[4] = {0};
float ptr_O_sum[4] = {0};
float ptr_sum = 0;
float ptr_sum_temp;
int size = (position_id[0] + SEQ_UNIT) / SEQ_UNIT;
#pragma unroll
for(int i = 0; i < size; i ++){
(float4 &)ptr_O[0]
= (float4 &)output_O_temp[(lane_id * 4) + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size];
ptr_sum_temp = output_sum_temp[i + parallel_idx * size];
#pragma unroll
for(int k = 0; k < 4; k ++)
ptr_O_sum[k] += ptr_O[k] * ptr_sum_temp;
ptr_sum += ptr_sum_temp;
} }
#pragma unroll
for(int k = 0; k < 4; k ++)
ptr_O_sum[k] = ptr_O_sum[k] / ptr_sum;
(float4 &)output_matmul[(lane_id * 4) + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O_sum[0];
}
namespace infini { namespace infini {
void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache, float *input_q, float *input_k, void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
float *input_q, float *input_k,
float *input_v, int *position_id, float *output_matmul, float *input_v, int *position_id, float *output_matmul,
const AttentionKVCacheMetadata &compMeta) { const AttentionKVCacheMetadata &compMeta,
IT_ASSERT(compMeta.dimSize[3] == 64); float *output_O_temp, float *output_sum_temp) {
dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), 1); IT_ASSERT(compMeta.dimSize[3] == 128);
int gridsize_y = (compMeta.dimSize[2] - 1 + SEQ_UNIT) / SEQ_UNIT;
dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), gridsize_y);
dim3 blockDim(BLOCKSIZE, 1); dim3 blockDim(BLOCKSIZE, 1);
_attention_kvcache_kernel<<<gridDim, blockDim>>>( assert(compMeta.dimSize[3] == 128);
input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, output_matmul, compMeta); _attention_kvcache_kernel_128_1<<<gridDim, blockDim>>>(
input_k_cache, input_v_cache, input_q, input_k, input_v, position_id,
compMeta, output_O_temp, output_sum_temp);
_attention_kvcache_kernel_128_2<<<compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), WARP_SIZE>>>(
position_id, output_matmul, compMeta, output_O_temp, output_sum_temp);
} }
} // namespace infini } // namespace infini

View File

@ -2,15 +2,14 @@
#include "utils/operator_utils.h" #include "utils/operator_utils.h"
namespace infini { namespace infini {
AttentionKVCacheObj::AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache, AttentionKVCacheObj::AttentionKVCacheObj(
Tensor input_v_cache, Tensor input_q, GraphObj *graph, Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
Tensor input_k, Tensor input_v, Tensor input_k, Tensor input_v, Tensor position_id, Tensor output_matmul,
Tensor position_id, Tensor output_k_cache, Tensor output_v_cache)
Tensor output_matmul)
: OperatorObj(OpType::AttentionKVCache, : OperatorObj(OpType::AttentionKVCache,
TensorVec{input_k_cache, input_v_cache, input_q, input_k, TensorVec{input_k_cache, input_v_cache, input_q, input_k,
input_v, position_id}, input_v, position_id},
{output_matmul}) { TensorVec{output_matmul, output_k_cache, output_v_cache}) {
int rank = inputs[0]->getRank(); int rank = inputs[0]->getRank();
IT_ASSERT(rank == 4); IT_ASSERT(rank == 4);
dim = 2; dim = 2;
@ -23,7 +22,7 @@ AttentionKVCacheObj::inferShape(const TensorVec &inputs) {
Shape dims = inputs[0]->getDims(); Shape dims = inputs[0]->getDims();
ShapeElem n = dims.at(dim); ShapeElem n = dims.at(dim);
dims[dim] = n + 1; dims[dim] = n + 1;
return {{inputs[2]->getDims()}}; return {{inputs[2]->getDims(), dims, dims}};
} }
std::string AttentionKVCacheObj::toString() const { std::string AttentionKVCacheObj::toString() const {

View File

@ -14,16 +14,16 @@ TEST(AttentionKVCache, Cuda) {
auto cudaRuntime = make_ref<CudaRuntimeObj>(); auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime); Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto input_k_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32); auto input_k_cache_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
auto input_v_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32); auto input_v_cache_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
auto input_q_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32); auto input_q_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
auto input_k_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32); auto input_k_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
auto input_v_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32); auto input_v_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32); auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32);
auto op = gCuda->addOp<AttentionKVCacheObj>( auto op = gCuda->addOp<AttentionKVCacheObj>(
input_k_cache_d, input_v_cache_d, input_q_d, input_k_d, input_v_d, input_k_cache_d, input_v_cache_d, input_q_d, input_k_d, input_v_d,
position_id_d, nullptr); position_id_d, nullptr, nullptr, nullptr);
gCuda->dataMalloc(); gCuda->dataMalloc();
input_q_d->setData(OneGenerator()); input_q_d->setData(OneGenerator());
@ -32,11 +32,14 @@ TEST(AttentionKVCache, Cuda) {
position_id_d->setData(IncrementalGenerator()); position_id_d->setData(IncrementalGenerator());
cudaRuntime->run(gCuda); cudaRuntime->run(gCuda);
auto oCpu = gCpu->cloneTensor(op->getOutput()); auto oCpu = gCpu->cloneTensor(op->getOutputs()[0]);
EXPECT_TRUE(oCpu->equalData(vector<float>{ EXPECT_TRUE(oCpu->equalData(vector<float>{
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})); 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}));
} }
} // namespace infini } // namespace infini