From d1a90ba3e22906b1e7b0160acc7ec4a0eff639b1 Mon Sep 17 00:00:00 2001
From: xiaonans <51065160+xiaonans@users.noreply.github.com>
Date: Thu, 25 Jan 2024 14:20:43 +0800
Subject: [PATCH] [feature] support kvcache with static graph (#209)

* [feature] support kvcache with static graph

* use workspace to optimize kvcache attention

---------

Co-authored-by: Haojie Wang
---
 examples/python/llama_kvcache_inference.py | 145 +++++++++++++
 include/cuda/cuda_attention_kvcache.h      |   4 +-
 src/core/graph_handler.cc                  |   2 +-
 src/kernels/cuda/attention_kvcache.cc      |  11 +-
 src/kernels/cuda/attention_kvcache.cu      | 227 ++++++++++++---------
 test/kernels/cuda/test_cuda_attention.cc   |  17 +-
 6 files changed, 301 insertions(+), 105 deletions(-)
 create mode 100644 examples/python/llama_kvcache_inference.py

diff --git a/examples/python/llama_kvcache_inference.py b/examples/python/llama_kvcache_inference.py
new file mode 100644
index 00000000..e6ba67ff
--- /dev/null
+++ b/examples/python/llama_kvcache_inference.py
@@ -0,0 +1,145 @@
+import os
+from pyinfinitensor.onnx import OnnxStub, backend
+import numpy as np
+import onnx
+import torch
+from transformers import LlamaModel, LlamaForCausalLM
+from tqdm import tqdm
+import onnx_graphsurgeon as gs
+from onnxsim import simplify
+import argparse
+
+parser = argparse.ArgumentParser(description='')
+parser.add_argument('--batchsize', dest='batchsize', type=int, default=1)
+parser.add_argument('--layer', dest='n_layers', type=int, default=2)
+parser.add_argument('--iter', dest='n_iter', type=int, default=1)
+parser.add_argument('--n_max_length', dest='n_max_length', type=int, default=1024)
+parser.add_argument('--pretrained_llama_path', dest='pretrained_llama_path', type=str,
+                    default="/data0/shared/data/public/opensource_models/meta-llama/Llama-2-7b-hf/")
+parser.add_argument('--onnx_model_path', dest='onnx_model_path', type=str,
+                    default="/data1/shared/llama")
+args = parser.parse_args()
+
+ONNX_MODEL_PATH = "{}/llama_bs{}_layer{}.onnx".format(args.onnx_model_path, args.batchsize, args.n_layers)
+ONNX_WEIGHT_PATH = "./llama_bs{}_layer{}.pb".format(args.batchsize, args.n_layers)
+
+
+def export_onnx(model: LlamaModel, ONNX_MODEL_PATH):
+    param = torch.zeros(
+        (args.batchsize, 1024), dtype=torch.long)
+    logits = model(param, past_key_values=None)
+    param_kvcache = torch.zeros((args.batchsize, 1), dtype=torch.long)
+
+    torch.onnx.export(model, (param_kvcache, {"past_key_values": logits.past_key_values,
+                                              "position_ids": param_kvcache}), ONNX_MODEL_PATH, verbose=False,
+                      do_constant_folding=True,)
+    onnx_model = onnx.load(ONNX_MODEL_PATH)
+    print("simplifying onnx model")
+    onnx_model, check = simplify(onnx_model, skipped_optimizers=[
+                                 'eliminate_duplicate_initializer'])
+    assert check
+    onnx.save(onnx_model, ONNX_MODEL_PATH, save_as_external_data=True, location=ONNX_WEIGHT_PATH)
+    print("simplifying finished.")
+
+
+@gs.Graph.register()
+def replace_with_attention(self, inputs, outputs, inputs_added, outputs_removed):
+    for inp in inputs:
+        inp.outputs.clear()
+    for out in outputs:
+        out.inputs.clear()
+    for inp in inputs_added:
+        inputs.append(inp)
+    for out in outputs_removed:
+        out.inputs.clear()
+    return self.layer(op="AttentionKVCache", inputs=inputs, outputs=outputs)
+
+
+def replace_onnx_with_attention_op():
+    graph = gs.import_onnx(
+        onnx.load(ONNX_MODEL_PATH))
+    tmap = graph.tensors()
+    for i in range(args.n_layers):
+        inputs = [
+            tmap["onnx::Concat_" + str((i+1)*2)],
+            tmap["onnx::Concat_" + str((i+1)*2+1)],
+            tmap["/model/layers." + str(i) + "/self_attn/Add_output_0"],
+            tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"],
+            tmap["/model/layers." + str(i) + "/self_attn/Transpose_2_output_0"]]
+        outputs = [
+            tmap["/model/layers." + str(i) + "/self_attn/MatMul_1_output_0"]]
+
+        inputs_added = [graph.inputs[1]]
+        outputs_removed = []
+
+        graph.replace_with_attention(
+            inputs, outputs, inputs_added, outputs_removed)
+
+    graph.outputs = [tmap[graph.outputs[0].name]]
+    graph.cleanup(True).toposort()
+    onnx.save(gs.export_onnx(graph), ONNX_MODEL_PATH, save_as_external_data=True)
+
+
+if __name__ == "__main__":
+    kvcache_torch = None
+    torch_model = LlamaForCausalLM.from_pretrained(
+        args.pretrained_llama_path, num_hidden_layers=int(args.n_layers)).eval()
+
+    n_heads = torch_model.config.num_attention_heads
+    n_dims = torch_model.config.hidden_size // n_heads
+
+    if not os.path.exists(ONNX_MODEL_PATH):
+        print("exporting onnx graph")
+        export_onnx(torch_model, ONNX_MODEL_PATH)
+        replace_onnx_with_attention_op()
+    else:
+        print("will use existing onnx graph")
+
+    onnx_model = onnx.load(ONNX_MODEL_PATH)
+    stub = OnnxStub(onnx_model, backend.cuda_runtime())
+
+    count_wrong = 0
+    for i in tqdm(range(0, args.n_max_length)):
+        query = np.random.randint(
+            torch_model.config.vocab_size, size=(args.batchsize, 1), dtype=np.int32)
+        position_id = i*np.ones((args.batchsize, 1), dtype=np.int32)
+
+        ####################################
+        # pytorch
+        ####################################
+        outputs_torch = torch_model(
+            torch.tensor(query), past_key_values=kvcache_torch)
+        logit_torch = outputs_torch['logits']
+        kvcache_torch = outputs_torch['past_key_values']
+
+        ####################################
+        # infinitensor
+        ####################################
+        # copyin input
+        (list(stub.inputs.items()))[0][1].copyin_int64(
+            query.reshape(-1).tolist())
+        (list(stub.inputs.items()))[1][1].copyin_int64(
+            position_id.reshape(-1).tolist())
+
+        stub.run()
+
+        ####################################
+        # validation
+        ####################################
+        # copyout output
+        logits_it = np.array((list(stub.outputs.items()))
+                             [0][1].copyout_float())
+
+        try:
+            np.testing.assert_allclose(
+                logit_torch[:, -1, :].detach().cpu().numpy().flatten(), logits_it, rtol=1e-3, atol=1e-3)
+        except Exception as e:
+            try:
+                np.testing.assert_allclose(
+                    np.argmax(logit_torch[:, -1, :].detach().cpu().numpy().flatten()), np.argmax(logits_it), rtol=1e-3, atol=1e-3)
+            except:
+                count_wrong = count_wrong + 1
+
+    result = "{}/{} failed.".format(count_wrong, args.n_max_length)
+    print(result)
+    del stub
diff --git a/include/cuda/cuda_attention_kvcache.h b/include/cuda/cuda_attention_kvcache.h
index 880a814f..91c65d21 100644
--- a/include/cuda/cuda_attention_kvcache.h
+++ b/include/cuda/cuda_attention_kvcache.h
@@ -1,4 +1,5 @@
 #pragma once
+#include "core/common.h"
 #include <cstdio>
 
 struct AttentionKVCacheMetadata {
@@ -10,6 +11,7 @@ namespace infini {
 void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
                               float *input_q, float *input_k, float *input_v,
                               int *position_id, float *output_matmul,
-                              const AttentionKVCacheMetadata &compMeta);
+                              const AttentionKVCacheMetadata &compMeta,
+                              float *output_O_temp, float *output_sum_temp);
 
 } // namespace infini
diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc
index 415ea947..cd62ed32 100644
--- a/src/core/graph_handler.cc
+++ b/src/core/graph_handler.cc
@@ -334,7 +334,7 @@ Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache,
             std::move(input_k_cache), std::move(input_v_cache),
             std::move(input_q), std::move(input_k), std::move(input_v),
             std::move(position_id), output_matmul);
-        return {output_matmul};
+        return output_matmul;
     } else {
         return g
             ->addOp<AttentionKVCacheObj>(
diff --git a/src/kernels/cuda/attention_kvcache.cc b/src/kernels/cuda/attention_kvcache.cc
index 52356d8d..d72e7838 100644
--- a/src/kernels/cuda/attention_kvcache.cc
+++ b/src/kernels/cuda/attention_kvcache.cc
@@ -21,7 +21,7 @@ class AttentionKVCacheCompute {
   public:
     void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q,
                     Tensor input_k, Tensor input_v, Tensor position_id,
-                    Tensor output_matmul) const {
+                    Tensor output_matmul, CudaPtr p_workspace) const {
         AttentionKVCacheMetadata metadata;
         initAttentionKVCacheMetadata(metadata, input_v_cache);
 
@@ -32,7 +32,8 @@ class AttentionKVCacheCompute {
                                  input_v->getRawDataPtr<float *>(),
                                  position_id->getRawDataPtr<int *>(),
                                  output_matmul->getRawDataPtr<float *>(),
-                                 metadata);
+                                 metadata, (float *)p_workspace,
+                                 (float *)(p_workspace + (1ll << 30)));
     }
 };
 
@@ -41,10 +42,14 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute,
     void compute(const Operator &_op,
                  const RuntimeObj *_context) const override {
         IT_ASSERT(_op->getDType() == DataType::Float32);
+
+        size_t workspaceSize = 2ll << 30;
+        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+        CudaPtr idxWsData = context->getWorkspace(workspaceSize);
         do_compute(_op->getInputs()[0], _op->getInputs()[1],
                    _op->getInputs()[2], _op->getInputs()[3],
                    _op->getInputs()[4], _op->getInputs()[5],
-                   _op->getOutputs()[0]);
+                   _op->getOutputs()[0], idxWsData);
     }
 };
diff --git a/src/kernels/cuda/attention_kvcache.cu b/src/kernels/cuda/attention_kvcache.cu
index ece6659f..f169a4b1 100644
--- a/src/kernels/cuda/attention_kvcache.cu
+++ b/src/kernels/cuda/attention_kvcache.cu
@@ -2,127 +2,168 @@
 #include "cuda/cuda_attention_kvcache.h"
 #define WARP_SIZE 32
 #define BLOCKSIZE WARP_SIZE
-#define SEQ_UNIT 64
+#define SEQ_UNIT 32
 
-__global__ void _attention_kvcache_kernel(float* input_k_cache,
+// ASSUME SEQ_LEN OF Q IS 1
+__global__ void _attention_kvcache_kernel_128_1(float* input_k_cache,
                                               float* input_v_cache,
                                               float* input_q,
                                               float* input_k,
                                               float* input_v,
                                               int* position_id,
-                                              float* output_matmul,
-                                              AttentionKVCacheMetadata compMeta) {
+                                              AttentionKVCacheMetadata compMeta,
+                                              float* output_O_temp,
+                                              float* output_sum_temp) {
+    int seq_length = position_id[0] + 1;
+    int stride = (seq_length + SEQ_UNIT - 1) / SEQ_UNIT;
+    if(blockIdx.y >= stride)
+        return;
+
     int lane_id = threadIdx.x % WARP_SIZE;
     int group_id = threadIdx.x / WARP_SIZE;
     int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
+    int idx_seq = blockIdx.y * SEQ_UNIT;
 
     if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1])
         return;
 
-    float ptr_V[SEQ_UNIT*2];
-    float ptr_K[SEQ_UNIT*2];
-    float ptr_Q[2];
-    float ptr_P[SEQ_UNIT];
+    float ptr_V[SEQ_UNIT*4]; // V
+    float ptr_K[SEQ_UNIT*4]; // K
+    float ptr_Q[4];          // Q
+    float ptr_P[SEQ_UNIT] = {0};
 
-    float ptr_O[2];
-    float ptr_max[1];
-    float ptr_sum[1];
+    float ptr_O[4] = {0};
+    float ptr_sum[1] = {0};
 
-    float ptr_max_last[1];
-    float ptr_sum_last[1];
-    float ptr_O_last[2];
+    // readin Q
+    (float4 &)ptr_Q[0] = (float4 &)input_q[(lane_id * 4) + (parallel_idx * 128)];
+    int common_idx = (lane_id * 4) + (parallel_idx * compMeta.stride[1]);
 
-    (float2 &)ptr_Q[0] = (float2 &)input_q[(lane_id * 2) + (parallel_idx * 64)];
-
-    int SEQ_LENGTH = position_id[0] + 1;
-
-    int common_idx = (lane_id * 2) + (parallel_idx * compMeta.stride[1]);
-
-
-    for (int idx_seq = 0; idx_seq < SEQ_LENGTH; idx_seq += SEQ_UNIT){
-        ptr_max_last[0] = ptr_max[0];
-        ptr_sum_last[0] = ptr_sum[0];
-        (float2 &)ptr_O_last[0] = (float2 &)ptr_O[0];
+    // Q*K
+    #pragma unroll
+    for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
+        if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
+            (float4 &)ptr_K[idx_SEQ_UNIT * 4]
+                = (float4 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
+        }
+        else{
+            (float4 &)ptr_K[idx_SEQ_UNIT * 4]
+                = (float4 &) input_k[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
+            (float4 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
+                (float4 &)ptr_K[idx_SEQ_UNIT * 4];
+        }
+
         #pragma unroll
-        for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
-            if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
-                (float2 &)ptr_K[idx_SEQ_UNIT * 2]
-                    = (float2 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
-            }
-            else{
-                (float2 &)ptr_K[idx_SEQ_UNIT * 2]
-                    = (float2 &) input_k[((lane_id * 2) + parallel_idx * compMeta.stride[2])];
-                (float2 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
-                    (float2 &)ptr_K[idx_SEQ_UNIT * 2];
-            }
-            ptr_K[idx_SEQ_UNIT * 2] = ptr_Q[0] * ptr_K[idx_SEQ_UNIT * 2];
-            ptr_K[idx_SEQ_UNIT * 2 + 1] = ptr_Q[1] * ptr_K[idx_SEQ_UNIT * 2 + 1];
-
+        for (int i = 0; i < 4; i ++){
+            ptr_K[idx_SEQ_UNIT * 4 + i] = ptr_Q[i] * ptr_K[idx_SEQ_UNIT * 4 + i];
             #pragma unroll
             for (int offset = 16; offset > 0; offset /= 2) {
-                ptr_K[idx_SEQ_UNIT * 2] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 2], offset);
+                ptr_K[idx_SEQ_UNIT * 4 + i] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 4 + i], offset);
             }
-            ptr_P[idx_SEQ_UNIT] = ptr_K[idx_SEQ_UNIT * 2];
-            #pragma unroll
-            for (int offset = 16; offset > 0; offset /= 2){
-                ptr_K[((idx_SEQ_UNIT * 2) + 1)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 2) + 1)], offset);
-            }
-            ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 2) + 1)];
+            ptr_P[idx_SEQ_UNIT] += ptr_K[idx_SEQ_UNIT * 4 + i];
         }
-
-        #pragma unroll
-        for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
-            ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
-            ptr_P[idx_SEQ_UNIT] /= 8;
-            ptr_max[0] = (idx_SEQ_UNIT == 0) ? ptr_P[0] : max(ptr_max[0], ptr_P[idx_SEQ_UNIT]);
-        }
-        ptr_max[0] = (idx_seq == 0) ? ptr_max[0] : max(ptr_max[0], ptr_max_last[0]);
-
-        ptr_sum[0] = 0;
-        #pragma unroll
-        for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
-            ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT] - ptr_max[0]);
-            ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
-        }
-        ptr_sum[0] = (idx_seq == 0) ? ptr_sum[0] : expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] + ptr_sum[0];
-
-        ptr_O[0] = 0;
-        ptr_O[1] = 0;
-        #pragma unroll
-        for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) {
-            if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){
-                (float2 &)ptr_V[idx_SEQ_UNIT * 2]
-                    = (float2 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
-            }
-            else{
-                (float2 &)ptr_V[idx_SEQ_UNIT * 2]
-                    = (float2 &) input_v[((lane_id * 2) + parallel_idx * compMeta.stride[2])];
-                (float2 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] =
-                    (float2 &)ptr_V[idx_SEQ_UNIT * 2];
-            }
-
-            ptr_P[idx_SEQ_UNIT] /= ptr_sum[0];
-
-            ptr_O[0] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2)], ptr_O[0]);
-            ptr_O[1] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2) + 1], ptr_O[1]);
-        }
-        ptr_O[0] = (idx_seq == 0) ? ptr_O[0] : ptr_O[0] + ptr_O_last[0] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
-        ptr_O[1] = (idx_seq == 0) ? ptr_O[1] : ptr_O[1] + ptr_O_last[1] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0];
     }
-    (float2 &)output_matmul[(lane_id * 2) + (parallel_idx * compMeta.dimSize[3])] = (float2 &)ptr_O[0];
+
+    // div sqrt(d)
+    #pragma unroll
+    for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
+        ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
+        ptr_P[idx_SEQ_UNIT] /= sqrt(128.0);
+    }
+
+    // softmax
+    #pragma unroll
+    for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
+        ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT]);
+        ptr_sum[0] += ptr_P[idx_SEQ_UNIT];
+    }
+
+    // * V
+    #pragma unroll
+    for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) {
+        if(idx_SEQ_UNIT + idx_seq < seq_length - 1){
+            (float4 &)ptr_V[idx_SEQ_UNIT * 4]
+                = (float4 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])];
+        }
+        else{
+            (float4 &)ptr_V[idx_SEQ_UNIT * 4]
+                = (float4 &) input_v[((lane_id * 4) + parallel_idx * compMeta.stride[2])];
+            (float4 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]
+                = (float4 &)ptr_V[idx_SEQ_UNIT * 4];
+        }
+
+        #pragma unroll
+        for (int i = 0; i < 4; i ++)
+            ptr_O[i] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4 + i)], ptr_O[i]);
+    }
+
+    #pragma unroll
+    for (int i = 0; i < 4; i ++)
+        ptr_O[i] /= ptr_sum[0];
+
+    (float4 &)output_O_temp[(lane_id * 4) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = (float4 &)ptr_O[0];
+    if(threadIdx.x == 0){
+        output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0];
+    }
 }
 
+__global__ void _attention_kvcache_kernel_128_2(int* position_id,
+                                              float* output_matmul,
+                                              AttentionKVCacheMetadata compMeta,
+                                              float* output_O_temp,
+                                              float* output_sum_temp) {
+    int lane_id = threadIdx.x % WARP_SIZE;
+    int group_id = threadIdx.x / WARP_SIZE;
+    int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id;
+
+    float ptr_O[4] = {0};
+    float ptr_O_sum[4] = {0};
+    float ptr_sum = 0;
+    float ptr_sum_temp;
+    int size = (position_id[0] + SEQ_UNIT) / SEQ_UNIT;
+
+    #pragma unroll
+    for(int i = 0; i < size; i ++){
+        (float4 &)ptr_O[0]
+            = (float4 &)output_O_temp[(lane_id * 4) + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size];
+        ptr_sum_temp = output_sum_temp[i + parallel_idx * size];
+
+        #pragma unroll
+        for(int k = 0; k < 4; k ++)
+            ptr_O_sum[k] += ptr_O[k] * ptr_sum_temp;
+        ptr_sum += ptr_sum_temp;
+    }
+
+    #pragma unroll
+    for(int k = 0; k < 4; k ++)
+        ptr_O_sum[k] = ptr_O_sum[k] / ptr_sum;
+
+    (float4 &)output_matmul[(lane_id * 4) + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O_sum[0];
+
+}
+
 namespace infini {
-void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache, float *input_q, float *input_k,
-                              float *input_v, int *position_id, float *output_matmul,
-                              const AttentionKVCacheMetadata &compMeta) {
-    IT_ASSERT(compMeta.dimSize[3] == 64);
-    dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), 1);
+void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache,
+                              float *input_q, float *input_k,
+                              float *input_v, int *position_id, float *output_matmul,
+                              const AttentionKVCacheMetadata &compMeta,
+                              float *output_O_temp, float *output_sum_temp) {
+    IT_ASSERT(compMeta.dimSize[3] == 128);
+
+    int gridsize_y = (compMeta.dimSize[2] - 1 + SEQ_UNIT) / SEQ_UNIT;
+    dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), gridsize_y);
     dim3 blockDim(BLOCKSIZE, 1);
 
-    _attention_kvcache_kernel<<<gridDim, blockDim>>>(
-        input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, output_matmul, compMeta);
+    assert(compMeta.dimSize[3] == 128);
+    _attention_kvcache_kernel_128_1<<<gridDim, blockDim>>>(
+        input_k_cache, input_v_cache, input_q, input_k, input_v, position_id,
+        compMeta, output_O_temp, output_sum_temp);
+    _attention_kvcache_kernel_128_2<<<compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), WARP_SIZE>>>(
+        position_id, output_matmul, compMeta, output_O_temp, output_sum_temp);
+
 }
 } // namespace infini
diff --git a/test/kernels/cuda/test_cuda_attention.cc b/test/kernels/cuda/test_cuda_attention.cc
index 3ccf861d..3a9bff45 100644
--- a/test/kernels/cuda/test_cuda_attention.cc
+++ b/test/kernels/cuda/test_cuda_attention.cc
@@ -14,11 +14,11 @@ TEST(AttentionKVCache, Cuda) {
     auto cudaRuntime = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cudaRuntime);
 
-    auto input_k_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
-    auto input_v_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
-    auto input_q_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
-    auto input_k_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
-    auto input_v_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32);
+    auto input_k_cache_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
+    auto input_v_cache_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
+    auto input_q_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
+    auto input_k_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
+    auto input_v_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32);
     auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32);
 
     auto op = gCuda->addOp<AttentionKVCacheObj>(
@@ -32,11 +32,14 @@ TEST(AttentionKVCache, Cuda) {
     position_id_d->setData(IncrementalGenerator());
     cudaRuntime->run(gCuda);
 
-    auto oCpu = gCpu->cloneTensor(op->getOutput());
+    auto oCpu = gCpu->cloneTensor(op->getOutputs()[0]);
     EXPECT_TRUE(oCpu->equalData(vector<float>{
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
-        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}));
+        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}));
 }
 } // namespace infini
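
Note for reviewers (not part of the patch): the reworked CUDA path computes attention in two passes. _attention_kvcache_kernel_128_1 walks the cached sequence in SEQ_UNIT-sized chunks (one chunk per blockIdx.y), writing each chunk's locally normalized output and its softmax denominator into the workspace (output_O_temp, output_sum_temp); _attention_kvcache_kernel_128_2 then re-weights each chunk by its denominator and renormalizes. The NumPy sketch below only illustrates why that decomposition reproduces ordinary softmax attention; the names are illustrative and, like the new kernels, it skips the max-subtraction stabilization.

# Illustrative sketch of the chunked softmax-attention reduction
# (per head, single query token, head_dim = 128). Not the kernel's API.
import numpy as np

SEQ_UNIT = 32
head_dim = 128
seq_len = 100

q = np.random.randn(head_dim).astype(np.float32)
k = np.random.randn(seq_len, head_dim).astype(np.float32)
v = np.random.randn(seq_len, head_dim).astype(np.float32)

# Reference: plain softmax attention for one query token.
scores = k @ q / np.sqrt(head_dim)
ref = (np.exp(scores) / np.exp(scores).sum()) @ v

# Pass 1: each SEQ_UNIT chunk yields a locally normalized output and its
# softmax denominator (the roles of output_O_temp / output_sum_temp).
chunk_outputs, chunk_sums = [], []
for start in range(0, seq_len, SEQ_UNIT):
    e = np.exp(scores[start:start + SEQ_UNIT])
    chunk_sums.append(e.sum())
    chunk_outputs.append((e @ v[start:start + SEQ_UNIT]) / e.sum())

# Pass 2: re-weight each chunk by its denominator and normalize globally.
total = sum(chunk_sums)
out = sum(o * s for o, s in zip(chunk_outputs, chunk_sums)) / total

np.testing.assert_allclose(out, ref, rtol=1e-4, atol=1e-4)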