forked from jiuyuan/InfiniTensor
[feature] support kvcache with static graph (#209)
* [feature] support kvcache with static graph * use workspace to optimize kvcache attention --------- Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent
a5062f3f89
commit
d1a90ba3e2
|
@ -0,0 +1,145 @@
|
||||||
|
import os
|
||||||
|
from pyinfinitensor.onnx import OnnxStub, backend
|
||||||
|
import numpy as np
|
||||||
|
import onnx
|
||||||
|
import torch
|
||||||
|
from transformers import LlamaModel, LlamaForCausalLM
|
||||||
|
from tqdm import tqdm
|
||||||
|
import onnx_graphsurgeon as gs
|
||||||
|
from onnxsim import simplify
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='')
|
||||||
|
parser.add_argument('--batchsize', dest='batchsize', type=int, default=1)
|
||||||
|
parser.add_argument('--layer', dest='n_layers', type=int, default=2)
|
||||||
|
parser.add_argument('--iter', dest='n_iter', type=int, default=1)
|
||||||
|
parser.add_argument('--n_max_length', dest='n_max_length', type=int, default=1024)
|
||||||
|
parser.add_argument('--pretrained_llama_path', dest='pretrained_llama_path', type=str,
|
||||||
|
default="/data0/shared/data/public/opensource_models/meta-llama/Llama-2-7b-hf/")
|
||||||
|
parser.add_argument('--onnx_model_path', dest='onnx_model_path', type=str,
|
||||||
|
default="/data1/shared/llama")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
ONNX_MODEL_PATH = "{}/llama_bs{}_layer{}.onnx".format(args.onnx_model_path, args.batchsize, args.n_layers)
|
||||||
|
ONNX_WEIGHT_PATH = "./llama_bs{}_layer{}.pb".format(args.batchsize, args.n_layers)
|
||||||
|
|
||||||
|
def export_onnx(model: LlamaModel, ONNX_MODEL_PATH):
|
||||||
|
param = torch.zeros(
|
||||||
|
(args.batchsize, 1024), dtype=torch.long)
|
||||||
|
logits = model(param, past_key_values=None)
|
||||||
|
param_kvcache = torch.zeros((args.batchsize, 1), dtype=torch.long)
|
||||||
|
|
||||||
|
torch.onnx.export(model, (param_kvcache, {"past_key_values": logits.past_key_values,
|
||||||
|
"position_ids": param_kvcache}), ONNX_MODEL_PATH, verbose=False,
|
||||||
|
do_constant_folding=True,)
|
||||||
|
onnx_model = onnx.load(ONNX_MODEL_PATH)
|
||||||
|
print("simplifing onnx model")
|
||||||
|
onnx_model, check = simplify(onnx_model, skipped_optimizers=[
|
||||||
|
'eliminate_duplicate_initializer'])
|
||||||
|
assert check
|
||||||
|
|
||||||
|
onnx.save(onnx_model, ONNX_MODEL_PATH, save_as_external_data=True, location=ONNX_WEIGHT_PATH)
|
||||||
|
print("simlifing finished.")
|
||||||
|
|
||||||
|
|
||||||
|
@gs.Graph.register()
|
||||||
|
def replace_with_attention(self, inputs, outputs, inputs_added, outputs_removed):
|
||||||
|
for inp in inputs:
|
||||||
|
inp.outputs.clear()
|
||||||
|
for out in outputs:
|
||||||
|
out.inputs.clear()
|
||||||
|
for inp in inputs_added:
|
||||||
|
inputs.append(inp)
|
||||||
|
for out in outputs_removed:
|
||||||
|
out.inputs.clear()
|
||||||
|
return self.layer(op="AttentionKVCache", inputs=inputs, outputs=outputs)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_onnx_with_attention_op():
|
||||||
|
graph = gs.import_onnx(
|
||||||
|
onnx.load(ONNX_MODEL_PATH))
|
||||||
|
tmap = graph.tensors()
|
||||||
|
for i in range(args.n_layers):
|
||||||
|
inputs = [
|
||||||
|
tmap["onnx::Concat_" + str((i+1)*2)],
|
||||||
|
tmap["onnx::Concat_" + str((i+1)*2+1)],
|
||||||
|
tmap["/model/layers." + str(i) + "/self_attn/Add_output_0"],
|
||||||
|
tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"],
|
||||||
|
tmap["/model/layers." + str(i) + "/self_attn/Transpose_2_output_0"]]
|
||||||
|
outputs = [
|
||||||
|
tmap["/model/layers." + str(i) + "/self_attn/MatMul_1_output_0"]]
|
||||||
|
|
||||||
|
inputs_added = [graph.inputs[1]]
|
||||||
|
outputs_removed = []
|
||||||
|
|
||||||
|
graph.replace_with_attention(
|
||||||
|
inputs, outputs, inputs_added, outputs_removed)
|
||||||
|
|
||||||
|
graph.outputs = [tmap[graph.outputs[0].name]]
|
||||||
|
graph.cleanup(True).toposort()
|
||||||
|
onnx.save(gs.export_onnx(graph), ONNX_MODEL_PATH, save_as_external_data=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
kvcache_torch = None
|
||||||
|
torch_model = LlamaForCausalLM.from_pretrained(
|
||||||
|
args.pretrained_llama_path, num_hidden_layers=int(args.n_layers)).eval()
|
||||||
|
|
||||||
|
n_heads = torch_model.config.num_attention_heads
|
||||||
|
n_dims = torch_model.config.hidden_size // n_heads
|
||||||
|
|
||||||
|
if not os.path.exists(ONNX_MODEL_PATH):
|
||||||
|
print("exporting onnx graph")
|
||||||
|
export_onnx(torch_model, ONNX_MODEL_PATH)
|
||||||
|
replace_onnx_with_attention_op()
|
||||||
|
else:
|
||||||
|
print("will use exsiting onnx graph")
|
||||||
|
|
||||||
|
onnx_model = onnx.load(ONNX_MODEL_PATH)
|
||||||
|
stub = OnnxStub(onnx_model, backend.cuda_runtime())
|
||||||
|
|
||||||
|
count_wrong = 0
|
||||||
|
for i in tqdm(range(0, args.n_max_length)):
|
||||||
|
query = np.random.randint(
|
||||||
|
torch_model.config.vocab_size, size=(args.batchsize, 1), dtype=np.int32)
|
||||||
|
position_id = i*np.ones((args.batchsize, 1), dtype=np.int32)
|
||||||
|
|
||||||
|
####################################
|
||||||
|
# pytorch
|
||||||
|
####################################
|
||||||
|
outputs_torch = torch_model(
|
||||||
|
torch.tensor(query), past_key_values=kvcache_torch)
|
||||||
|
logit_torch = outputs_torch['logits']
|
||||||
|
kvcache_torch = outputs_torch['past_key_values']
|
||||||
|
|
||||||
|
####################################
|
||||||
|
# infinitensor
|
||||||
|
####################################
|
||||||
|
# copyin input
|
||||||
|
(list(stub.inputs.items()))[0][1].copyin_int64(
|
||||||
|
query.reshape(-1).tolist())
|
||||||
|
(list(stub.inputs.items()))[1][1].copyin_int64(
|
||||||
|
position_id.reshape(-1).tolist())
|
||||||
|
|
||||||
|
stub.run()
|
||||||
|
|
||||||
|
####################################
|
||||||
|
# validation
|
||||||
|
####################################
|
||||||
|
# copyout output
|
||||||
|
logits_it = np.array((list(stub.outputs.items()))
|
||||||
|
[0][1].copyout_float())
|
||||||
|
|
||||||
|
try:
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
logit_torch[:, -1, :].detach().cpu().numpy().flatten(), logits_it, rtol=1e-3, atol=1e-3)
|
||||||
|
except Exception as e:
|
||||||
|
try:
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
np.argmax(logit_torch[:, -1, :].detach().cpu().numpy().flatten()), np.argmax(logits_it), rtol=1e-3, atol=1e-3)
|
||||||
|
except:
|
||||||
|
count_wrong = count_wrong + 1
|
||||||
|
|
||||||
|
result = "{}/{} failed.".format(count_wrong, args.n_max_length)
|
||||||
|
print(result)
|
||||||
|
del stub
|
|
@ -1,4 +1,5 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
#include "core/common.h"
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
|
|
||||||
struct AttentionKVCacheMetadata {
|
struct AttentionKVCacheMetadata {
|
||||||
|
@ -10,6 +11,7 @@ namespace infini {
|
||||||
void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
|
void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
|
||||||
float *input_q, float *input_k, float *input_v,
|
float *input_q, float *input_k, float *input_v,
|
||||||
int *position_id, float *output_matmul,
|
int *position_id, float *output_matmul,
|
||||||
const AttentionKVCacheMetadata &compMeta);
|
const AttentionKVCacheMetadata &compMeta,
|
||||||
|
float *output_O_temp, float *output_sum_temp);
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -334,7 +334,7 @@ Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache,
|
||||||
std::move(input_k_cache), std::move(input_v_cache),
|
std::move(input_k_cache), std::move(input_v_cache),
|
||||||
std::move(input_q), std::move(input_k), std::move(input_v),
|
std::move(input_q), std::move(input_k), std::move(input_v),
|
||||||
std::move(position_id), output_matmul);
|
std::move(position_id), output_matmul);
|
||||||
return {output_matmul};
|
return output_matmul;
|
||||||
} else {
|
} else {
|
||||||
return g
|
return g
|
||||||
->addOp<AttentionKVCacheObj>(
|
->addOp<AttentionKVCacheObj>(
|
||||||
|
|
|
@ -21,7 +21,7 @@ class AttentionKVCacheCompute {
|
||||||
public:
|
public:
|
||||||
void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
|
void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
|
||||||
Tensor input_k, Tensor input_v, Tensor position_id,
|
Tensor input_k, Tensor input_v, Tensor position_id,
|
||||||
Tensor output_matmul) const {
|
Tensor output_matmul, CudaPtr p_workspace) const {
|
||||||
AttentionKVCacheMetadata metadata;
|
AttentionKVCacheMetadata metadata;
|
||||||
initAttentionKVCacheMetadata(metadata, input_v_cache);
|
initAttentionKVCacheMetadata(metadata, input_v_cache);
|
||||||
|
|
||||||
|
@ -32,7 +32,8 @@ class AttentionKVCacheCompute {
|
||||||
input_v->getRawDataPtr<float *>(),
|
input_v->getRawDataPtr<float *>(),
|
||||||
position_id->getRawDataPtr<int *>(),
|
position_id->getRawDataPtr<int *>(),
|
||||||
output_matmul->getRawDataPtr<float *>(),
|
output_matmul->getRawDataPtr<float *>(),
|
||||||
metadata);
|
metadata, (float *)p_workspace,
|
||||||
|
(float *)(p_workspace + (1ll << 30)));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -41,10 +42,14 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute,
|
||||||
void compute(const Operator &_op,
|
void compute(const Operator &_op,
|
||||||
const RuntimeObj *_context) const override {
|
const RuntimeObj *_context) const override {
|
||||||
IT_ASSERT(_op->getDType() == DataType::Float32);
|
IT_ASSERT(_op->getDType() == DataType::Float32);
|
||||||
|
|
||||||
|
size_t workspaceSize = 2ll << 30;
|
||||||
|
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||||
|
CudaPtr idxWsData = context->getWorkspace(workspaceSize);
|
||||||
do_compute(_op->getInputs()[0], _op->getInputs()[1],
|
do_compute(_op->getInputs()[0], _op->getInputs()[1],
|
||||||
_op->getInputs()[2], _op->getInputs()[3],
|
_op->getInputs()[2], _op->getInputs()[3],
|
||||||
_op->getInputs()[4], _op->getInputs()[5],
|
_op->getInputs()[4], _op->getInputs()[5],
|
||||||
_op->getOutputs()[0]);
|
_op->getOutputs()[0], idxWsData);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -2,127 +2,168 @@
|
||||||
#include "cuda/cuda_attention_kvcache.h"
|
#include "cuda/cuda_attention_kvcache.h"
|
||||||
#define WARP_SIZE 32
|
#define WARP_SIZE 32
|
||||||
#define BLOCKSIZE WARP_SIZE
|
#define BLOCKSIZE WARP_SIZE
|
||||||
#define SEQ_UNIT 64
|
#define SEQ_UNIT 32
|
||||||
|
|
||||||
__global__ void _attention_kvcache_kernel(float* input_k_cache,
|
// ASSUME SEQ_LEN OF Q IS 1
|
||||||
|
__global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
|
||||||
float* input_v_cache,
|
float* input_v_cache,
|
||||||
float* input_q,
|
float* input_q,
|
||||||
float* input_k,
|
float* input_k,
|
||||||
float* input_v,
|
float* input_v,
|
||||||
int* position_id,
|
int* position_id,
|
||||||
float* output_matmul,
|
AttentionKVCacheMetadata compMeta,
|
||||||
AttentionKVCacheMetadata compMeta) {
|
float* output_O_temp,
|
||||||
|
float* output_sum_temp) {
|
||||||
|
int seq_length = position_id[0] + 1;
|
||||||
|
int stride = (seq_length + SEQ_UNIT - 1) / SEQ_UNIT;
|
||||||
|
if(blockIdx.y >= stride)
|
||||||
|
return;
|
||||||
|
|
||||||
int lane_id = threadIdx.x % WARP_SIZE;
|
int lane_id = threadIdx.x % WARP_SIZE;
|
||||||
int group_id = threadIdx.x / WARP_SIZE;
|
int group_id = threadIdx.x / WARP_SIZE;
|
||||||
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
|
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
|
||||||
|
int idx_seq = blockIdx.y * SEQ_UNIT;
|
||||||
|
|
||||||
if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
|
if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
|
||||||
return;
|
return;
|
||||||
|
|
||||||
float ptr_V[SEQ_UNIT*2];
|
float ptr_V[SEQ_UNIT*4]; // V
|
||||||
float ptr_K[SEQ_UNIT*2];
|
float ptr_K[SEQ_UNIT*4]; // K
|
||||||
float ptr_Q[2];
|
float ptr_Q[4]; // Q
|
||||||
float ptr_P[SEQ_UNIT];
|
float ptr_P[SEQ_UNIT] = {0};
|
||||||
|
|
||||||
float ptr_O[2];
|
float ptr_O[4] = {0};
|
||||||
float ptr_max[1];
|
float ptr_sum[1] = {0};
|
||||||
float ptr_sum[1];
|
|
||||||
|
|
||||||
float ptr_max_last[1];
|
// readin Q
|
||||||
float ptr_sum_last[1];
|
(float4 &)ptr_Q[0] = (float4 &)input_q[(lane_id * 4) + (parallel_idx * 128)];
|
||||||
float ptr_O_last[2];
|
int common_idx = (lane_id * 4) + (parallel_idx * compMeta.stride[1]);
|
||||||
|
|
||||||
(float2 &)ptr_Q[0] = (float2 &)input_q[(lane_id * 2) + (parallel_idx * 64)];
|
// Q*K
|
||||||
|
#pragma unroll
|
||||||
|
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
|
||||||
|
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
|
||||||
|
(float4 &)ptr_K[idx_SEQ_UNIT * 4]
|
||||||
|
= (float4 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
(float4 &)ptr_K[idx_SEQ_UNIT * 4]
|
||||||
|
= (float4 &) input_k[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
|
||||||
|
(float4 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
|
||||||
|
(float4 &)ptr_K[idx_SEQ_UNIT * 4];
|
||||||
|
}
|
||||||
|
|
||||||
int SEQ_LENGTH = position_id[0] + 1;
|
|
||||||
|
|
||||||
int common_idx = (lane_id * 2) + (parallel_idx * compMeta.stride[1]);
|
|
||||||
|
|
||||||
|
|
||||||
for (int idx_seq = 0; idx_seq < SEQ_LENGTH; idx_seq += SEQ_UNIT){
|
|
||||||
ptr_max_last[0] = ptr_max[0];
|
|
||||||
ptr_sum_last[0] = ptr_sum[0];
|
|
||||||
(float2 &)ptr_O_last[0] = (float2 &)ptr_O[0];
|
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
|
for (int i = 0; i < 4; i ++){
|
||||||
if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
|
ptr_K[idx_SEQ_UNIT * 4 + i] = ptr_Q[i] * ptr_K[idx_SEQ_UNIT * 4 + i];
|
||||||
(float2 &)ptr_K[idx_SEQ_UNIT * 2]
|
|
||||||
= (float2 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
(float2 &)ptr_K[idx_SEQ_UNIT * 2]
|
|
||||||
= (float2 &) input_k[((lane_id * 2) + parallel_idx * compMeta.stride[2])];
|
|
||||||
(float2 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
|
|
||||||
(float2 &)ptr_K[idx_SEQ_UNIT * 2];
|
|
||||||
}
|
|
||||||
ptr_K[idx_SEQ_UNIT * 2] = ptr_Q[0] * ptr_K[idx_SEQ_UNIT * 2];
|
|
||||||
ptr_K[idx_SEQ_UNIT * 2 + 1] = ptr_Q[1] * ptr_K[idx_SEQ_UNIT * 2 + 1];
|
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int offset = 16; offset > 0; offset /= 2) {
|
for (int offset = 16; offset > 0; offset /= 2) {
|
||||||
ptr_K[idx_SEQ_UNIT * 2] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 2], offset);
|
ptr_K[idx_SEQ_UNIT * 4 + i] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 4 + i], offset);
|
||||||
}
|
}
|
||||||
ptr_P[idx_SEQ_UNIT] = ptr_K[idx_SEQ_UNIT * 2];
|
ptr_P[idx_SEQ_UNIT] += ptr_K[idx_SEQ_UNIT * 4 + i];
|
||||||
#pragma unroll
|
|
||||||
for (int offset = 16; offset > 0; offset /= 2){
|
|
||||||
ptr_K[((idx_SEQ_UNIT * 2) + 1)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 2) + 1)], offset);
|
|
||||||
}
|
|
||||||
ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 2) + 1)];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
|
|
||||||
ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
|
|
||||||
ptr_P[idx_SEQ_UNIT] /= 8;
|
|
||||||
ptr_max[0] = (idx_SEQ_UNIT == 0) ? ptr_P[0] : max(ptr_max[0], ptr_P[idx_SEQ_UNIT]);
|
|
||||||
}
|
|
||||||
ptr_max[0] = (idx_seq == 0) ? ptr_max[0] : max(ptr_max[0], ptr_max_last[0]);
|
|
||||||
|
|
||||||
ptr_sum[0] = 0;
|
|
||||||
#pragma unroll
|
|
||||||
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
|
|
||||||
ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT] - ptr_max[0]);
|
|
||||||
ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
|
|
||||||
}
|
|
||||||
ptr_sum[0] = (idx_seq == 0) ? ptr_sum[0] : expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] + ptr_sum[0];
|
|
||||||
|
|
||||||
ptr_O[0] = 0;
|
|
||||||
ptr_O[1] = 0;
|
|
||||||
#pragma unroll
|
|
||||||
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
|
|
||||||
if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
|
|
||||||
(float2 &)ptr_V[idx_SEQ_UNIT * 2]
|
|
||||||
= (float2 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
(float2 &)ptr_V[idx_SEQ_UNIT * 2]
|
|
||||||
= (float2 &) input_v[((lane_id * 2) + parallel_idx * compMeta.stride[2])];
|
|
||||||
(float2 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
|
|
||||||
(float2 &)ptr_V[idx_SEQ_UNIT * 2];
|
|
||||||
}
|
|
||||||
|
|
||||||
ptr_P[idx_SEQ_UNIT] /= ptr_sum[0];
|
|
||||||
|
|
||||||
ptr_O[0] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2)], ptr_O[0]);
|
|
||||||
ptr_O[1] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2) + 1], ptr_O[1]);
|
|
||||||
}
|
|
||||||
ptr_O[0] = (idx_seq == 0) ? ptr_O[0] : ptr_O[0] + ptr_O_last[0] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
|
|
||||||
ptr_O[1] = (idx_seq == 0) ? ptr_O[1] : ptr_O[1] + ptr_O_last[1] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
|
|
||||||
}
|
}
|
||||||
(float2 &)output_matmul[(lane_id * 2) + (parallel_idx * compMeta.dimSize[3])] = (float2 &)ptr_O[0];
|
|
||||||
|
// div sqrt(d)
|
||||||
|
#pragma unroll
|
||||||
|
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
|
||||||
|
ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
|
||||||
|
ptr_P[idx_SEQ_UNIT] /= sqrt(128.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// softmax
|
||||||
|
#pragma unroll
|
||||||
|
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
|
||||||
|
ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT]);
|
||||||
|
ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
|
||||||
|
}
|
||||||
|
|
||||||
|
// * V
|
||||||
|
#pragma unroll
|
||||||
|
for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
|
||||||
|
if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
|
||||||
|
(float4 &)ptr_V[idx_SEQ_UNIT * 4]
|
||||||
|
= (float4 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
(float4 &)ptr_V[idx_SEQ_UNIT * 4]
|
||||||
|
= (float4 &) input_v[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
|
||||||
|
(float4 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]
|
||||||
|
= (float4 &)ptr_V[idx_SEQ_UNIT * 4];
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; i ++)
|
||||||
|
ptr_O[i] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4 + i)], ptr_O[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; i ++)
|
||||||
|
ptr_O[i] /= ptr_sum[0];
|
||||||
|
|
||||||
|
(float4 &)output_O_temp[(lane_id * 4) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = (float4 &)ptr_O[0];
|
||||||
|
if(threadIdx.x == 0){
|
||||||
|
output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0];
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__global__ void _attention_kvcache_kernel_128_2(int* position_id,
|
||||||
|
float* output_matmul,
|
||||||
|
AttentionKVCacheMetadata compMeta,
|
||||||
|
float* output_O_temp,
|
||||||
|
float* output_sum_temp) {
|
||||||
|
int lane_id = threadIdx.x % WARP_SIZE;
|
||||||
|
int group_id = threadIdx.x / WARP_SIZE;
|
||||||
|
int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
|
||||||
|
|
||||||
|
float ptr_O[4] = {0};
|
||||||
|
float ptr_O_sum[4] = {0};
|
||||||
|
float ptr_sum = 0;
|
||||||
|
float ptr_sum_temp;
|
||||||
|
int size = (position_id[0] + SEQ_UNIT) / SEQ_UNIT;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for(int i = 0; i < size; i ++){
|
||||||
|
(float4 &)ptr_O[0]
|
||||||
|
= (float4 &)output_O_temp[(lane_id * 4) + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size];
|
||||||
|
ptr_sum_temp = output_sum_temp[i + parallel_idx * size];
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for(int k = 0; k < 4; k ++)
|
||||||
|
ptr_O_sum[k] += ptr_O[k] * ptr_sum_temp;
|
||||||
|
ptr_sum += ptr_sum_temp;
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for(int k = 0; k < 4; k ++)
|
||||||
|
ptr_O_sum[k] = ptr_O_sum[k] / ptr_sum;
|
||||||
|
|
||||||
|
(float4 &)output_matmul[(lane_id * 4) + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O_sum[0];
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache, float *input_q, float *input_k,
|
void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
|
||||||
|
float *input_q, float *input_k,
|
||||||
float *input_v, int *position_id, float *output_matmul,
|
float *input_v, int *position_id, float *output_matmul,
|
||||||
const AttentionKVCacheMetadata &compMeta) {
|
const AttentionKVCacheMetadata &compMeta,
|
||||||
IT_ASSERT(compMeta.dimSize[3] == 64);
|
float *output_O_temp, float *output_sum_temp) {
|
||||||
dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), 1);
|
IT_ASSERT(compMeta.dimSize[3] == 128);
|
||||||
|
|
||||||
|
int gridsize_y = (compMeta.dimSize[2] - 1 + SEQ_UNIT) / SEQ_UNIT;
|
||||||
|
dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), gridsize_y);
|
||||||
dim3 blockDim(BLOCKSIZE, 1);
|
dim3 blockDim(BLOCKSIZE, 1);
|
||||||
|
|
||||||
_attention_kvcache_kernel<<<gridDim, blockDim>>>(
|
assert(compMeta.dimSize[3] == 128);
|
||||||
input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, output_matmul, compMeta);
|
_attention_kvcache_kernel_128_1<<<gridDim, blockDim>>>(
|
||||||
|
input_k_cache, input_v_cache, input_q, input_k, input_v, position_id,
|
||||||
|
compMeta, output_O_temp, output_sum_temp);
|
||||||
|
_attention_kvcache_kernel_128_2<<<compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), WARP_SIZE>>>(
|
||||||
|
position_id, output_matmul, compMeta, output_O_temp, output_sum_temp);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -14,11 +14,11 @@ TEST(AttentionKVCache, Cuda) {
|
||||||
|
|
||||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||||
auto input_k_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
|
auto input_k_cache_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
|
||||||
auto input_v_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
|
auto input_v_cache_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
|
||||||
auto input_q_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
|
auto input_q_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
|
||||||
auto input_k_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
|
auto input_k_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
|
||||||
auto input_v_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
|
auto input_v_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
|
||||||
auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32);
|
auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32);
|
||||||
|
|
||||||
auto op = gCuda->addOp<AttentionKVCacheObj>(
|
auto op = gCuda->addOp<AttentionKVCacheObj>(
|
||||||
|
@ -32,11 +32,14 @@ TEST(AttentionKVCache, Cuda) {
|
||||||
position_id_d->setData(IncrementalGenerator());
|
position_id_d->setData(IncrementalGenerator());
|
||||||
cudaRuntime->run(gCuda);
|
cudaRuntime->run(gCuda);
|
||||||
|
|
||||||
auto oCpu = gCpu->cloneTensor(op->getOutput());
|
auto oCpu = gCpu->cloneTensor(op->getOutputs()[0]);
|
||||||
EXPECT_TRUE(oCpu->equalData(vector<float>{
|
EXPECT_TRUE(oCpu->equalData(vector<float>{
|
||||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}));
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
Loading…
Reference in New Issue