From 6a1bfd6c45e185c1d3458e30d61a4bb7285ad46b Mon Sep 17 00:00:00 2001 From: xiaonans Date: Wed, 17 Jan 2024 11:26:05 +0800 Subject: [PATCH 1/6] [feature] support kvcache with static graph --- examples/python/llama_kvcache_inference.py | 146 +++++++++++++ include/core/graph_handler.h | 7 +- include/cuda/cuda_attention_kvcache.h | 3 +- include/operators/attention_kvcache.h | 7 +- pyinfinitensor/src/pyinfinitensor/onnx.py | 4 +- src/core/graph_handler.cc | 21 +- src/core/operator.cc | 6 +- src/kernels/cuda/attention_kvcache.cc | 15 +- src/kernels/cuda/attention_kvcache.cu | 227 ++++++++++++--------- src/operators/attention_kvcache.cc | 13 +- test/kernels/cuda/test_cuda_attention.cc | 19 +- 11 files changed, 335 insertions(+), 133 deletions(-) create mode 100644 examples/python/llama_kvcache_inference.py diff --git a/examples/python/llama_kvcache_inference.py b/examples/python/llama_kvcache_inference.py new file mode 100644 index 00000000..b05339b8 --- /dev/null +++ b/examples/python/llama_kvcache_inference.py @@ -0,0 +1,146 @@ +import os +from pyinfinitensor.onnx import OnnxStub, backend +import numpy as np +import onnx +import torch +from transformers import LlamaModel, LlamaForCausalLM +from tqdm import tqdm +import onnx_graphsurgeon as gs +from onnxsim import simplify +import argparse + +parser = argparse.ArgumentParser(description='') +parser.add_argument('--batchsize', dest='batchsize', type=int, default=1) +parser.add_argument('--layer', dest='n_layers', type=int, default=2) +parser.add_argument('--iter', dest='n_iter', type=int, default=1) +parser.add_argument('--n_max_length', dest='n_max_length', type=int, default=1024) +parser.add_argument('--pretrained_llama_path', dest='pretrained_llama_path', type=str, + default="/data0/shared/data/public/opensource_models/meta-llama/Llama-2-7b-hf/") +parser.add_argument('--onnx_model_path', dest='onnx_model_path', type=str, + default="/data1/shared/llama") +args = parser.parse_args() + +ONNX_MODEL_PATH = "{}/llama_bs{}_layer{}.onnx".format(args.onnx_model_path, args.batchsize, args.n_layers) +ONNX_WEIGHT_PATH = "./llama_bs{}_layer{}.pb".format(args.batchsize, args.n_layers) + +def export_onnx(model: LlamaModel, ONNX_MODEL_PATH): + param = torch.zeros( + (args.batchsize, 1024), dtype=torch.long) + logits = model(param, past_key_values=None) + param_kvcache = torch.zeros((args.batchsize, 1), dtype=torch.long) + + torch.onnx.export(model, (param_kvcache, {"past_key_values": logits.past_key_values, + "position_ids": param_kvcache}), ONNX_MODEL_PATH, verbose=False, + do_constant_folding=True,) + onnx_model = onnx.load(ONNX_MODEL_PATH) + print("simplifing onnx model") + onnx_model, check = simplify(onnx_model, skipped_optimizers=[ + 'eliminate_duplicate_initializer']) + assert check + + onnx.save(onnx_model, ONNX_MODEL_PATH, save_as_external_data=True, location=ONNX_WEIGHT_PATH) + print("simlifing finished.") + + +@gs.Graph.register() +def replace_with_attention(self, inputs, outputs, inputs_added, outputs_removed): + for inp in inputs: + inp.outputs.clear() + for out in outputs: + out.inputs.clear() + for inp in inputs_added: + inputs.append(inp) + for out in outputs_removed: + out.inputs.clear() + return self.layer(op="AttentionKVCache", inputs=inputs, outputs=outputs) + + +def replace_onnx_with_attention_op(): + graph = gs.import_onnx( + onnx.load(ONNX_MODEL_PATH)) + tmap = graph.tensors() + for i in range(args.n_layers): + inputs = [ + tmap["onnx::Concat_" + str((i+1)*2)], + tmap["onnx::Concat_" + str((i+1)*2+1)], + tmap["/model/layers." 
+ str(i) + "/self_attn/Add_output_0"], + tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"], + tmap["/model/layers." + str(i) + "/self_attn/Transpose_2_output_0"]] + outputs = [ + tmap["/model/layers." + str(i) + "/self_attn/MatMul_1_output_0"], + tmap[graph.outputs[1+i*2].name], + tmap[graph.outputs[2+i*2].name]] + + inputs_added = [graph.inputs[1]] + outputs_removed = [] + + graph.replace_with_attention( + inputs, outputs, inputs_added, outputs_removed) + + graph.cleanup(True).toposort() + onnx.save(gs.export_onnx(graph), ONNX_MODEL_PATH, save_as_external_data=True) + + +if __name__ == "__main__": + kvcache_torch = None + torch_model = LlamaForCausalLM.from_pretrained( + args.pretrained_llama_path, num_hidden_layers=int(args.n_layers)).eval() + + n_heads = torch_model.config.num_attention_heads + n_dims = torch_model.config.hidden_size // n_heads + + if not os.path.exists(ONNX_MODEL_PATH): + print("exporting onnx graph") + export_onnx(torch_model, ONNX_MODEL_PATH) + replace_onnx_with_attention_op() + else: + print("will use exsiting onnx graph") + + onnx_model = onnx.load(ONNX_MODEL_PATH) + stub = OnnxStub(onnx_model, backend.cuda_runtime()) + + count_wrong = 0 + for i in tqdm(range(0, args.n_max_length)): + query = np.random.randint( + torch_model.config.vocab_size, size=(args.batchsize, 1), dtype=np.int32) + position_id = i*np.ones((args.batchsize, 1), dtype=np.int32) + + #################################### + # pytorch + #################################### + outputs_torch = torch_model( + torch.tensor(query), past_key_values=kvcache_torch) + logit_torch = outputs_torch['logits'] + kvcache_torch = outputs_torch['past_key_values'] + + #################################### + # infinitensor + #################################### + # copyin input + (list(stub.inputs.items()))[0][1].copyin_int64( + query.reshape(-1).tolist()) + (list(stub.inputs.items()))[1][1].copyin_int64( + position_id.reshape(-1).tolist()) + + stub.run() + + #################################### + # validation + #################################### + # copyout output + logits_it = np.array((list(stub.outputs.items())) + [0][1].copyout_float()) + + try: + np.testing.assert_allclose( + logit_torch[:, -1, :].detach().cpu().numpy().flatten(), logits_it, rtol=1e-3, atol=1e-3) + except Exception as e: + try: + np.testing.assert_allclose( + np.argmax(logit_torch[:, -1, :].detach().cpu().numpy().flatten()), np.argmax(logits_it), rtol=1e-3, atol=1e-3) + except: + count_wrong = count_wrong + 1 + + result = "{}/{} failed.".format(count_wrong, args.n_max_length) + print(result) + del stub diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index 0e1472bb..75673d14 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -74,9 +74,10 @@ class GraphHandlerObj { Tensor squeeze(Tensor input, Tensor output, Shape axes); Tensor unsqueeze(Tensor input, Tensor output, Shape axes); Tensor concat(TensorVec inputs, Tensor output, int dim); - Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache, - Tensor input_q, Tensor input_k, Tensor input_v, - Tensor position_id, Tensor output_matmul); + TensorVec attentionKVCache(Tensor input_k_cache, Tensor input_v_cache, + Tensor input_q, Tensor input_k, Tensor input_v, + Tensor position_id, Tensor output_matmul, + Tensor output_k_cache, Tensor output_v_cache); TensorVec split(Tensor input, std::optional outputs, int axis, std::variant> numOrRatio); Tensor gather(Tensor data, Tensor indices, Tensor output, int axis); diff --git 
a/include/cuda/cuda_attention_kvcache.h b/include/cuda/cuda_attention_kvcache.h index 880a814f..74d356c9 100644 --- a/include/cuda/cuda_attention_kvcache.h +++ b/include/cuda/cuda_attention_kvcache.h @@ -10,6 +10,7 @@ namespace infini { void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache, float *input_q, float *input_k, float *input_v, int *position_id, float *output_matmul, - const AttentionKVCacheMetadata &compMeta); + const AttentionKVCacheMetadata &compMeta, + float *output_O_temp, float *output_sum_temp); } // namespace infini diff --git a/include/operators/attention_kvcache.h b/include/operators/attention_kvcache.h index 0472b222..b4448511 100644 --- a/include/operators/attention_kvcache.h +++ b/include/operators/attention_kvcache.h @@ -22,18 +22,21 @@ class AttentionKVCacheObj : public OperatorObj { * @param input_v The value input tensor. * @param position_id The positon id of the query, * @param output_matmul The query output tensor. + * @param output_k_cache The output k_cache tensor. + * @param output_v_cache The output v_cache tensor. */ AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache, Tensor input_v_cache, Tensor input_q, Tensor input_k, Tensor input_v, Tensor position_id, - Tensor output_matmul); + Tensor output_matmul, Tensor output_k_cache, + Tensor output_v_cache); OP_CLONE(AttentionKVCacheObj); optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return 6; } - int numOutputs() const override { return 1; } + int numOutputs() const override { return 3; } int getDim() const { return dim; } private: diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 79abb7f4..5a6f62fc 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -660,7 +660,7 @@ class OnnxStub: next((attr.i for attr in node.attribute if attr.name == "axis")), ) elif node.op_type == "AttentionKVCache": - tensors[node.output[0]] = self.handler.attentionKVCache( + tensors[node.output[0]], tensors[node.output[1]], tensors[node.output[2]] = self.handler.attentionKVCache( tensors[node.input[0]], tensors[node.input[1]], tensors[node.input[2]], @@ -668,6 +668,8 @@ class OnnxStub: tensors[node.input[4]], tensors[node.input[5]], tensors.get(node.output[0]), + tensors.get(node.output[1]), + tensors.get(node.output[2]), ) elif node.op_type == "Split": split = ( diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index 415ea947..18eb893f 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -324,24 +324,25 @@ Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) { } } -Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache, - Tensor input_v_cache, Tensor input_q, - Tensor input_k, Tensor input_v, - Tensor position_id, - Tensor output_matmul) { - if (output_matmul) { +TensorVec GraphHandlerObj::attentionKVCache( + Tensor input_k_cache, Tensor input_v_cache, Tensor input_q, Tensor input_k, + Tensor input_v, Tensor position_id, Tensor output_matmul, + Tensor output_k_cache, Tensor output_v_cache) { + if (output_matmul && output_k_cache && output_v_cache) { g->addOpWithOutputs( std::move(input_k_cache), std::move(input_v_cache), std::move(input_q), std::move(input_k), std::move(input_v), - std::move(position_id), output_matmul); - return {output_matmul}; + std::move(position_id), output_matmul, output_k_cache, + output_v_cache); + return {output_matmul, 
output_k_cache, output_v_cache}; } else { return g ->addOp( std::move(input_k_cache), std::move(input_v_cache), std::move(input_q), std::move(input_k), std::move(input_v), - std::move(position_id), output_matmul) - ->getOutput(); + std::move(position_id), output_matmul, output_k_cache, + output_v_cache) + ->getOutputs(); } } diff --git a/src/core/operator.cc b/src/core/operator.cc index 4fd4e6de..9a7cf6e0 100644 --- a/src/core/operator.cc +++ b/src/core/operator.cc @@ -67,8 +67,10 @@ bool OperatorObj::checkValid(GraphObj *graph) { if (graph) { // if graph != nullptr, outputs should be created auto dataTypes = inferDataType(); for (size_t i = 0; i < outputs.size(); i++) { - IT_ASSERT(!outputs[i], "Find empty output while operator creation"); - outputs[i] = graph->addTensor(shapes[i], dataTypes[i]); + if (!outputs[i]) + outputs[i] = graph->addTensor(shapes[i], dataTypes[i]); + else if (shapes[i] != outputs[i]->getDims()) + return false; } } else { // if outputs have been created, check their shapes for (size_t i = 0; i < shapes.size(); ++i) { diff --git a/src/kernels/cuda/attention_kvcache.cc b/src/kernels/cuda/attention_kvcache.cc index 52356d8d..8ecff414 100644 --- a/src/kernels/cuda/attention_kvcache.cc +++ b/src/kernels/cuda/attention_kvcache.cc @@ -21,7 +21,8 @@ class AttentionKVCacheCompute { public: void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q, Tensor input_k, Tensor input_v, Tensor position_id, - Tensor output_matmul) const { + Tensor output_matmul, Tensor output_temp_O, + Tensor output_temp_sum) const { AttentionKVCacheMetadata metadata; initAttentionKVCacheMetadata(metadata, input_v_cache); @@ -32,7 +33,9 @@ class AttentionKVCacheCompute { input_v->getRawDataPtr(), position_id->getRawDataPtr(), output_matmul->getRawDataPtr(), - metadata); + metadata, + output_temp_O->getRawDataPtr(), + output_temp_sum->getRawDataPtr()); } }; @@ -41,10 +44,10 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute, void compute(const Operator &_op, const RuntimeObj *_context) const override { IT_ASSERT(_op->getDType() == DataType::Float32); - do_compute(_op->getInputs()[0], _op->getInputs()[1], - _op->getInputs()[2], _op->getInputs()[3], - _op->getInputs()[4], _op->getInputs()[5], - _op->getOutputs()[0]); + do_compute( + _op->getInputs()[0], _op->getInputs()[1], _op->getInputs()[2], + _op->getInputs()[3], _op->getInputs()[4], _op->getInputs()[5], + _op->getOutputs()[0], _op->getOutputs()[1], _op->getOutputs()[2]); } }; diff --git a/src/kernels/cuda/attention_kvcache.cu b/src/kernels/cuda/attention_kvcache.cu index ece6659f..f169a4b1 100644 --- a/src/kernels/cuda/attention_kvcache.cu +++ b/src/kernels/cuda/attention_kvcache.cu @@ -2,127 +2,168 @@ #include "cuda/cuda_attention_kvcache.h" #define WARP_SIZE 32 #define BLOCKSIZE WARP_SIZE -#define SEQ_UNIT 64 +#define SEQ_UNIT 32 -__global__ void _attention_kvcache_kernel(float* input_k_cache, +// ASSUME SEQ_LEN OF Q IS 1 +__global__ void _attention_kvcache_kernel_128_1(float* input_k_cache, float* input_v_cache, float* input_q, float* input_k, float* input_v, int* position_id, - float* output_matmul, - AttentionKVCacheMetadata compMeta) { + AttentionKVCacheMetadata compMeta, + float* output_O_temp, + float* output_sum_temp) { + int seq_length = position_id[0] + 1; + int stride = (seq_length + SEQ_UNIT - 1) / SEQ_UNIT; + if(blockIdx.y >= stride) + return; + int lane_id = threadIdx.x % WARP_SIZE; int group_id = threadIdx.x / WARP_SIZE; int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id; + int idx_seq 
= blockIdx.y * SEQ_UNIT; if(parallel_idx >= compMeta.dimSize[0] * compMeta.dimSize[1]) return; - float ptr_V[SEQ_UNIT*2]; - float ptr_K[SEQ_UNIT*2]; - float ptr_Q[2]; - float ptr_P[SEQ_UNIT]; + float ptr_V[SEQ_UNIT*4]; // V + float ptr_K[SEQ_UNIT*4]; // K + float ptr_Q[4]; // Q + float ptr_P[SEQ_UNIT] = {0}; - float ptr_O[2]; - float ptr_max[1]; - float ptr_sum[1]; + float ptr_O[4] = {0}; + float ptr_sum[1] = {0}; - float ptr_max_last[1]; - float ptr_sum_last[1]; - float ptr_O_last[2]; + // readin Q + (float4 &)ptr_Q[0] = (float4 &)input_q[(lane_id * 4) + (parallel_idx * 128)]; + int common_idx = (lane_id * 4) + (parallel_idx * compMeta.stride[1]); - (float2 &)ptr_Q[0] = (float2 &)input_q[(lane_id * 2) + (parallel_idx * 64)]; - - int SEQ_LENGTH = position_id[0] + 1; - - int common_idx = (lane_id * 2) + (parallel_idx * compMeta.stride[1]); - - - for (int idx_seq = 0; idx_seq < SEQ_LENGTH; idx_seq += SEQ_UNIT){ - ptr_max_last[0] = ptr_max[0]; - ptr_sum_last[0] = ptr_sum[0]; - (float2 &)ptr_O_last[0] = (float2 &)ptr_O[0]; + // Q*K + #pragma unroll + for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) { + if(idx_SEQ_UNIT + idx_seq < seq_length - 1){ + (float4 &)ptr_K[idx_SEQ_UNIT * 4] + = (float4 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]; + } + else{ + (float4 &)ptr_K[idx_SEQ_UNIT * 4] + = (float4 &) input_k[((lane_id * 4) + parallel_idx * compMeta.stride[2])]; + (float4 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] = + (float4 &)ptr_K[idx_SEQ_UNIT * 4]; + } + #pragma unroll - for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) { - if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){ - (float2 &)ptr_K[idx_SEQ_UNIT * 2] - = (float2 &) input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]; - } - else{ - (float2 &)ptr_K[idx_SEQ_UNIT * 2] - = (float2 &) input_k[((lane_id * 2) + parallel_idx * compMeta.stride[2])]; - (float2 &)input_k_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] = - (float2 &)ptr_K[idx_SEQ_UNIT * 2]; - } - ptr_K[idx_SEQ_UNIT * 2] = ptr_Q[0] * ptr_K[idx_SEQ_UNIT * 2]; - ptr_K[idx_SEQ_UNIT * 2 + 1] = ptr_Q[1] * ptr_K[idx_SEQ_UNIT * 2 + 1]; - + for (int i = 0; i < 4; i ++){ + ptr_K[idx_SEQ_UNIT * 4 + i] = ptr_Q[i] * ptr_K[idx_SEQ_UNIT * 4 + i]; #pragma unroll for (int offset = 16; offset > 0; offset /= 2) { - ptr_K[idx_SEQ_UNIT * 2] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 2], offset); + ptr_K[idx_SEQ_UNIT * 4 + i] += __shfl_down_sync(0xffffffff, ptr_K[idx_SEQ_UNIT * 4 + i], offset); } - ptr_P[idx_SEQ_UNIT] = ptr_K[idx_SEQ_UNIT * 2]; - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2){ - ptr_K[((idx_SEQ_UNIT * 2) + 1)] += __shfl_down_sync(0xffffffff, ptr_K[((idx_SEQ_UNIT * 2) + 1)], offset); - } - ptr_P[idx_SEQ_UNIT] += ptr_K[((idx_SEQ_UNIT * 2) + 1)]; + ptr_P[idx_SEQ_UNIT] += ptr_K[idx_SEQ_UNIT * 4 + i]; } - - #pragma unroll - for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) { - ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0); - ptr_P[idx_SEQ_UNIT] /= 8; - ptr_max[0] = (idx_SEQ_UNIT == 0) ? ptr_P[0] : max(ptr_max[0], ptr_P[idx_SEQ_UNIT]); - } - ptr_max[0] = (idx_seq == 0) ? 
ptr_max[0] : max(ptr_max[0], ptr_max_last[0]); - - ptr_sum[0] = 0; - #pragma unroll - for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) { - ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT] - ptr_max[0]); - ptr_sum[0] += ptr_P[idx_SEQ_UNIT]; - } - ptr_sum[0] = (idx_seq == 0) ? ptr_sum[0] : expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] + ptr_sum[0]; - - ptr_O[0] = 0; - ptr_O[1] = 0; - #pragma unroll - for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < SEQ_LENGTH; idx_SEQ_UNIT ++) { - if(idx_SEQ_UNIT + idx_seq < SEQ_LENGTH - 1){ - (float2 &)ptr_V[idx_SEQ_UNIT * 2] - = (float2 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]; - } - else{ - (float2 &)ptr_V[idx_SEQ_UNIT * 2] - = (float2 &) input_v[((lane_id * 2) + parallel_idx * compMeta.stride[2])]; - (float2 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] = - (float2 &)ptr_V[idx_SEQ_UNIT * 2]; - } - - ptr_P[idx_SEQ_UNIT] /= ptr_sum[0]; - - ptr_O[0] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2)], ptr_O[0]); - ptr_O[1] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 2) + 1], ptr_O[1]); - } - ptr_O[0] = (idx_seq == 0) ? ptr_O[0] : ptr_O[0] + ptr_O_last[0] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0]; - ptr_O[1] = (idx_seq == 0) ? ptr_O[1] : ptr_O[1] + ptr_O_last[1] * expf(ptr_max_last[0] - ptr_max[0]) * ptr_sum_last[0] / ptr_sum[0]; } - (float2 &)output_matmul[(lane_id * 2) + (parallel_idx * compMeta.dimSize[3])] = (float2 &)ptr_O[0]; + + // div sqrt(d) + #pragma unroll + for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) { + ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0); + ptr_P[idx_SEQ_UNIT] /= sqrt(128.0); + } + + // softmax + #pragma unroll + for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) { + ptr_P[idx_SEQ_UNIT] = expf(ptr_P[idx_SEQ_UNIT]); + ptr_sum[0] += ptr_P[idx_SEQ_UNIT]; + } + + // * V + #pragma unroll + for (int idx_SEQ_UNIT = 0; idx_SEQ_UNIT < SEQ_UNIT && idx_SEQ_UNIT + idx_seq < seq_length; idx_SEQ_UNIT ++) { + if(idx_SEQ_UNIT + idx_seq < seq_length - 1){ + (float4 &)ptr_V[idx_SEQ_UNIT * 4] + = (float4 &) input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])]; + } + else{ + (float4 &)ptr_V[idx_SEQ_UNIT * 4] + = (float4 &) input_v[((lane_id * 4) + parallel_idx * compMeta.stride[2])]; + (float4 &)input_v_cache[common_idx + ((idx_SEQ_UNIT + idx_seq) * compMeta.stride[2])] + = (float4 &)ptr_V[idx_SEQ_UNIT * 4]; + } + + #pragma unroll + for (int i = 0; i < 4; i ++) + ptr_O[i] = fmaf(ptr_P[idx_SEQ_UNIT], ptr_V[(idx_SEQ_UNIT * 4 + i)], ptr_O[i]); + } + + #pragma unroll + for (int i = 0; i < 4; i ++) + ptr_O[i] /= ptr_sum[0]; + + (float4 &)output_O_temp[(lane_id * 4) + (blockIdx.y * compMeta.dimSize[3]) + (parallel_idx * compMeta.dimSize[3] * stride)] = (float4 &)ptr_O[0]; + if(threadIdx.x == 0){ + output_sum_temp[blockIdx.y + parallel_idx * stride] = ptr_sum[0]; + } + } +__global__ void _attention_kvcache_kernel_128_2(int* position_id, + float* output_matmul, + AttentionKVCacheMetadata compMeta, + float* output_O_temp, + float* output_sum_temp) { + int lane_id = threadIdx.x % WARP_SIZE; + int group_id = threadIdx.x / WARP_SIZE; + int parallel_idx = blockIdx.x * (blockDim.x / WARP_SIZE) + group_id; + + float ptr_O[4] = {0}; + float ptr_O_sum[4] = {0}; + float ptr_sum = 0; + float ptr_sum_temp; + int 
size = (position_id[0] + SEQ_UNIT) / SEQ_UNIT; + + #pragma unroll + for(int i = 0; i < size; i ++){ + (float4 &)ptr_O[0] + = (float4 &)output_O_temp[(lane_id * 4) + (i * compMeta.dimSize[3]) + parallel_idx * compMeta.dimSize[3] * size]; + ptr_sum_temp = output_sum_temp[i + parallel_idx * size]; + + #pragma unroll + for(int k = 0; k < 4; k ++) + ptr_O_sum[k] += ptr_O[k] * ptr_sum_temp; + ptr_sum += ptr_sum_temp; + } + + #pragma unroll + for(int k = 0; k < 4; k ++) + ptr_O_sum[k] = ptr_O_sum[k] / ptr_sum; + + (float4 &)output_matmul[(lane_id * 4) + (parallel_idx * compMeta.dimSize[3])] = (float4 &)ptr_O_sum[0]; + +} + + namespace infini { -void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache, float *input_q, float *input_k, - float *input_v, int *position_id, float *output_matmul, - const AttentionKVCacheMetadata &compMeta) { - IT_ASSERT(compMeta.dimSize[3] == 64); - dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), 1); +void attention_kvcache_kernel(float *input_k_cache, float *input_v_cache, + float *input_q, float *input_k, + float *input_v, int *position_id, float *output_matmul, + const AttentionKVCacheMetadata &compMeta, + float *output_O_temp, float *output_sum_temp) { + IT_ASSERT(compMeta.dimSize[3] == 128); + + int gridsize_y = (compMeta.dimSize[2] - 1 + SEQ_UNIT) / SEQ_UNIT; + dim3 gridDim(compMeta.dimSize[0]*compMeta.dimSize[1]/(BLOCKSIZE/WARP_SIZE), gridsize_y); dim3 blockDim(BLOCKSIZE, 1); - _attention_kvcache_kernel<<>>( - input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, output_matmul, compMeta); + assert(compMeta.dimSize[3] == 128); + _attention_kvcache_kernel_128_1<<>>( + input_k_cache, input_v_cache, input_q, input_k, input_v, position_id, + compMeta, output_O_temp, output_sum_temp); + _attention_kvcache_kernel_128_2<<>>( + position_id, output_matmul, compMeta, output_O_temp, output_sum_temp); + } } // namespace infini diff --git a/src/operators/attention_kvcache.cc b/src/operators/attention_kvcache.cc index 492a76f7..24c3ba2d 100644 --- a/src/operators/attention_kvcache.cc +++ b/src/operators/attention_kvcache.cc @@ -2,15 +2,14 @@ #include "utils/operator_utils.h" namespace infini { -AttentionKVCacheObj::AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache, - Tensor input_v_cache, Tensor input_q, - Tensor input_k, Tensor input_v, - Tensor position_id, - Tensor output_matmul) +AttentionKVCacheObj::AttentionKVCacheObj( + GraphObj *graph, Tensor input_k_cache, Tensor input_v_cache, Tensor input_q, + Tensor input_k, Tensor input_v, Tensor position_id, Tensor output_matmul, + Tensor output_k_cache, Tensor output_v_cache) : OperatorObj(OpType::AttentionKVCache, TensorVec{input_k_cache, input_v_cache, input_q, input_k, input_v, position_id}, - {output_matmul}) { + TensorVec{output_matmul, output_k_cache, output_v_cache}) { int rank = inputs[0]->getRank(); IT_ASSERT(rank == 4); dim = 2; @@ -23,7 +22,7 @@ AttentionKVCacheObj::inferShape(const TensorVec &inputs) { Shape dims = inputs[0]->getDims(); ShapeElem n = dims.at(dim); dims[dim] = n + 1; - return {{inputs[2]->getDims()}}; + return {{inputs[2]->getDims(), dims, dims}}; } std::string AttentionKVCacheObj::toString() const { diff --git a/test/kernels/cuda/test_cuda_attention.cc b/test/kernels/cuda/test_cuda_attention.cc index 3ccf861d..b95470f4 100644 --- a/test/kernels/cuda/test_cuda_attention.cc +++ b/test/kernels/cuda/test_cuda_attention.cc @@ -14,16 +14,16 @@ TEST(AttentionKVCache, Cuda) { auto cudaRuntime = make_ref(); Graph gCuda = make_ref(cudaRuntime); 
- auto input_k_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32); - auto input_v_cache_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32); - auto input_q_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32); - auto input_k_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32); - auto input_v_d = gCuda->addTensor({1, 1, 1, 64}, DataType::Float32); + auto input_k_cache_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32); + auto input_v_cache_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32); + auto input_q_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32); + auto input_k_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32); + auto input_v_d = gCuda->addTensor({1, 1, 1, 128}, DataType::Float32); auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32); auto op = gCuda->addOp( input_k_cache_d, input_v_cache_d, input_q_d, input_k_d, input_v_d, - position_id_d, nullptr); + position_id_d, nullptr, nullptr, nullptr); gCuda->dataMalloc(); input_q_d->setData(OneGenerator()); @@ -32,11 +32,14 @@ TEST(AttentionKVCache, Cuda) { position_id_d->setData(IncrementalGenerator()); cudaRuntime->run(gCuda); - auto oCpu = gCpu->cloneTensor(op->getOutput()); + auto oCpu = gCpu->cloneTensor(op->getOutputs()[0]); EXPECT_TRUE(oCpu->equalData(vector{ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})); + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})); } } // namespace infini From afed5d3c3d87841e4275bde9a461cb2d414d73c1 Mon Sep 17 00:00:00 2001 From: xiaonans Date: Thu, 25 Jan 2024 09:08:25 +0800 Subject: [PATCH 2/6] use workspace to optimize kvcache attention --- examples/python/llama_kvcache_inference.py | 7 +++---- include/core/graph_handler.h | 7 +++---- include/cuda/cuda_attention_kvcache.h | 1 + include/operators/attention_kvcache.h | 7 ++----- pyinfinitensor/src/pyinfinitensor/onnx.py | 4 +--- src/core/graph_handler.cc | 21 ++++++++++----------- src/core/operator.cc | 6 ++---- src/kernels/cuda/attention_kvcache.cc | 20 +++++++++++--------- src/operators/attention_kvcache.cc | 13 +++++++------ test/kernels/cuda/test_cuda_attention.cc | 2 +- 10 files changed, 41 insertions(+), 47 deletions(-) diff --git a/examples/python/llama_kvcache_inference.py b/examples/python/llama_kvcache_inference.py index b05339b8..e6ba67ff 100644 --- a/examples/python/llama_kvcache_inference.py +++ b/examples/python/llama_kvcache_inference.py @@ -67,16 +67,15 @@ def replace_onnx_with_attention_op(): tmap["/model/layers." + str(i) + "/self_attn/Add_1_output_0"], tmap["/model/layers." + str(i) + "/self_attn/Transpose_2_output_0"]] outputs = [ - tmap["/model/layers." + str(i) + "/self_attn/MatMul_1_output_0"], - tmap[graph.outputs[1+i*2].name], - tmap[graph.outputs[2+i*2].name]] + tmap["/model/layers." 
+ str(i) + "/self_attn/MatMul_1_output_0"]] inputs_added = [graph.inputs[1]] outputs_removed = [] graph.replace_with_attention( inputs, outputs, inputs_added, outputs_removed) - + + graph.outputs = [tmap[graph.outputs[0].name]] graph.cleanup(True).toposort() onnx.save(gs.export_onnx(graph), ONNX_MODEL_PATH, save_as_external_data=True) diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index 75673d14..0e1472bb 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -74,10 +74,9 @@ class GraphHandlerObj { Tensor squeeze(Tensor input, Tensor output, Shape axes); Tensor unsqueeze(Tensor input, Tensor output, Shape axes); Tensor concat(TensorVec inputs, Tensor output, int dim); - TensorVec attentionKVCache(Tensor input_k_cache, Tensor input_v_cache, - Tensor input_q, Tensor input_k, Tensor input_v, - Tensor position_id, Tensor output_matmul, - Tensor output_k_cache, Tensor output_v_cache); + Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache, + Tensor input_q, Tensor input_k, Tensor input_v, + Tensor position_id, Tensor output_matmul); TensorVec split(Tensor input, std::optional outputs, int axis, std::variant> numOrRatio); Tensor gather(Tensor data, Tensor indices, Tensor output, int axis); diff --git a/include/cuda/cuda_attention_kvcache.h b/include/cuda/cuda_attention_kvcache.h index 74d356c9..91c65d21 100644 --- a/include/cuda/cuda_attention_kvcache.h +++ b/include/cuda/cuda_attention_kvcache.h @@ -1,4 +1,5 @@ #pragma once +#include "core/common.h" #include struct AttentionKVCacheMetadata { diff --git a/include/operators/attention_kvcache.h b/include/operators/attention_kvcache.h index b4448511..0472b222 100644 --- a/include/operators/attention_kvcache.h +++ b/include/operators/attention_kvcache.h @@ -22,21 +22,18 @@ class AttentionKVCacheObj : public OperatorObj { * @param input_v The value input tensor. * @param position_id The positon id of the query, * @param output_matmul The query output tensor. - * @param output_k_cache The output k_cache tensor. - * @param output_v_cache The output v_cache tensor. 
*/ AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache, Tensor input_v_cache, Tensor input_q, Tensor input_k, Tensor input_v, Tensor position_id, - Tensor output_matmul, Tensor output_k_cache, - Tensor output_v_cache); + Tensor output_matmul); OP_CLONE(AttentionKVCacheObj); optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return 6; } - int numOutputs() const override { return 3; } + int numOutputs() const override { return 1; } int getDim() const { return dim; } private: diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 5a6f62fc..79abb7f4 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -660,7 +660,7 @@ class OnnxStub: next((attr.i for attr in node.attribute if attr.name == "axis")), ) elif node.op_type == "AttentionKVCache": - tensors[node.output[0]], tensors[node.output[1]], tensors[node.output[2]] = self.handler.attentionKVCache( + tensors[node.output[0]] = self.handler.attentionKVCache( tensors[node.input[0]], tensors[node.input[1]], tensors[node.input[2]], @@ -668,8 +668,6 @@ class OnnxStub: tensors[node.input[4]], tensors[node.input[5]], tensors.get(node.output[0]), - tensors.get(node.output[1]), - tensors.get(node.output[2]), ) elif node.op_type == "Split": split = ( diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index 18eb893f..cd62ed32 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -324,25 +324,24 @@ Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) { } } -TensorVec GraphHandlerObj::attentionKVCache( - Tensor input_k_cache, Tensor input_v_cache, Tensor input_q, Tensor input_k, - Tensor input_v, Tensor position_id, Tensor output_matmul, - Tensor output_k_cache, Tensor output_v_cache) { - if (output_matmul && output_k_cache && output_v_cache) { +Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache, + Tensor input_v_cache, Tensor input_q, + Tensor input_k, Tensor input_v, + Tensor position_id, + Tensor output_matmul) { + if (output_matmul) { g->addOpWithOutputs( std::move(input_k_cache), std::move(input_v_cache), std::move(input_q), std::move(input_k), std::move(input_v), - std::move(position_id), output_matmul, output_k_cache, - output_v_cache); - return {output_matmul, output_k_cache, output_v_cache}; + std::move(position_id), output_matmul); + return output_matmul; } else { return g ->addOp( std::move(input_k_cache), std::move(input_v_cache), std::move(input_q), std::move(input_k), std::move(input_v), - std::move(position_id), output_matmul, output_k_cache, - output_v_cache) - ->getOutputs(); + std::move(position_id), output_matmul) + ->getOutput(); } } diff --git a/src/core/operator.cc b/src/core/operator.cc index 9a7cf6e0..4fd4e6de 100644 --- a/src/core/operator.cc +++ b/src/core/operator.cc @@ -67,10 +67,8 @@ bool OperatorObj::checkValid(GraphObj *graph) { if (graph) { // if graph != nullptr, outputs should be created auto dataTypes = inferDataType(); for (size_t i = 0; i < outputs.size(); i++) { - if (!outputs[i]) - outputs[i] = graph->addTensor(shapes[i], dataTypes[i]); - else if (shapes[i] != outputs[i]->getDims()) - return false; + IT_ASSERT(!outputs[i], "Find empty output while operator creation"); + outputs[i] = graph->addTensor(shapes[i], dataTypes[i]); } } else { // if outputs have been created, check their shapes for (size_t i = 0; i < shapes.size(); ++i) { diff --git 
a/src/kernels/cuda/attention_kvcache.cc b/src/kernels/cuda/attention_kvcache.cc index 8ecff414..d72e7838 100644 --- a/src/kernels/cuda/attention_kvcache.cc +++ b/src/kernels/cuda/attention_kvcache.cc @@ -21,8 +21,7 @@ class AttentionKVCacheCompute { public: void do_compute(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q, Tensor input_k, Tensor input_v, Tensor position_id, - Tensor output_matmul, Tensor output_temp_O, - Tensor output_temp_sum) const { + Tensor output_matmul, CudaPtr p_workspace) const { AttentionKVCacheMetadata metadata; initAttentionKVCacheMetadata(metadata, input_v_cache); @@ -33,9 +32,8 @@ class AttentionKVCacheCompute { input_v->getRawDataPtr(), position_id->getRawDataPtr(), output_matmul->getRawDataPtr(), - metadata, - output_temp_O->getRawDataPtr(), - output_temp_sum->getRawDataPtr()); + metadata, (float *)p_workspace, + (float *)(p_workspace + (1ll << 30))); } }; @@ -44,10 +42,14 @@ class AttentionKVCacheCuda : private AttentionKVCacheCompute, void compute(const Operator &_op, const RuntimeObj *_context) const override { IT_ASSERT(_op->getDType() == DataType::Float32); - do_compute( - _op->getInputs()[0], _op->getInputs()[1], _op->getInputs()[2], - _op->getInputs()[3], _op->getInputs()[4], _op->getInputs()[5], - _op->getOutputs()[0], _op->getOutputs()[1], _op->getOutputs()[2]); + + size_t workspaceSize = 2ll << 30; + auto context = dynamic_cast(_context); + CudaPtr idxWsData = context->getWorkspace(workspaceSize); + do_compute(_op->getInputs()[0], _op->getInputs()[1], + _op->getInputs()[2], _op->getInputs()[3], + _op->getInputs()[4], _op->getInputs()[5], + _op->getOutputs()[0], idxWsData); } }; diff --git a/src/operators/attention_kvcache.cc b/src/operators/attention_kvcache.cc index 24c3ba2d..492a76f7 100644 --- a/src/operators/attention_kvcache.cc +++ b/src/operators/attention_kvcache.cc @@ -2,14 +2,15 @@ #include "utils/operator_utils.h" namespace infini { -AttentionKVCacheObj::AttentionKVCacheObj( - GraphObj *graph, Tensor input_k_cache, Tensor input_v_cache, Tensor input_q, - Tensor input_k, Tensor input_v, Tensor position_id, Tensor output_matmul, - Tensor output_k_cache, Tensor output_v_cache) +AttentionKVCacheObj::AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache, + Tensor input_v_cache, Tensor input_q, + Tensor input_k, Tensor input_v, + Tensor position_id, + Tensor output_matmul) : OperatorObj(OpType::AttentionKVCache, TensorVec{input_k_cache, input_v_cache, input_q, input_k, input_v, position_id}, - TensorVec{output_matmul, output_k_cache, output_v_cache}) { + {output_matmul}) { int rank = inputs[0]->getRank(); IT_ASSERT(rank == 4); dim = 2; @@ -22,7 +23,7 @@ AttentionKVCacheObj::inferShape(const TensorVec &inputs) { Shape dims = inputs[0]->getDims(); ShapeElem n = dims.at(dim); dims[dim] = n + 1; - return {{inputs[2]->getDims(), dims, dims}}; + return {{inputs[2]->getDims()}}; } std::string AttentionKVCacheObj::toString() const { diff --git a/test/kernels/cuda/test_cuda_attention.cc b/test/kernels/cuda/test_cuda_attention.cc index b95470f4..3a9bff45 100644 --- a/test/kernels/cuda/test_cuda_attention.cc +++ b/test/kernels/cuda/test_cuda_attention.cc @@ -23,7 +23,7 @@ TEST(AttentionKVCache, Cuda) { auto op = gCuda->addOp( input_k_cache_d, input_v_cache_d, input_q_d, input_k_d, input_v_d, - position_id_d, nullptr, nullptr, nullptr); + position_id_d, nullptr); gCuda->dataMalloc(); input_q_d->setData(OneGenerator()); From e8d111ef5d0c34eca072959a4dcc21e14d7534fc Mon Sep 17 00:00:00 2001 From: xiaonans Date: Thu, 11 Jan 2024 15:44:07 +0800 
Subject: [PATCH 3/6] add rope and silu support --- include/core/graph_handler.h | 2 + include/core/op_type.h | 2 + include/cuda/cuda_rope.h | 10 +++ include/cuda/cuda_unary.h | 1 + include/operators/rope.h | 21 ++++++ include/operators/unary.h | 1 + pyinfinitensor/src/pyinfinitensor/onnx.py | 11 +++ src/core/graph_handler.cc | 14 +++- src/ffi/ffi_infinitensor.cc | 2 + src/kernels/cuda/rope.cc | 38 ++++++++++ src/kernels/cuda/rope.cu | 91 +++++++++++++++++++++++ src/kernels/cuda/unary.cc | 2 + src/kernels/cuda/unary.cu | 26 +++++++ src/operators/rope.cc | 37 +++++++++ 14 files changed, 257 insertions(+), 1 deletion(-) create mode 100644 include/cuda/cuda_rope.h create mode 100644 include/operators/rope.h create mode 100644 src/kernels/cuda/rope.cc create mode 100644 src/kernels/cuda/rope.cu create mode 100644 src/operators/rope.cc diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index 0e1472bb..36486e36 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -47,6 +47,7 @@ class GraphHandlerObj { Tensor max(Tensor a, Tensor b, Tensor c); Tensor relu(Tensor x, Tensor y); + Tensor silu(Tensor x, Tensor y); Tensor gelu(Tensor x, Tensor y); Tensor sigmoid(Tensor x, Tensor y); Tensor hardSigmoid(Tensor x, Tensor y); @@ -77,6 +78,7 @@ class GraphHandlerObj { Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache, Tensor input_q, Tensor input_k, Tensor input_v, Tensor position_id, Tensor output_matmul); + Tensor RoPE(Tensor pos, Tensor input, Tensor output); TensorVec split(Tensor input, std::optional outputs, int axis, std::variant> numOrRatio); Tensor gather(Tensor data, Tensor indices, Tensor output, int axis); diff --git a/include/core/op_type.h b/include/core/op_type.h index 1652a677..d0d0e92a 100644 --- a/include/core/op_type.h +++ b/include/core/op_type.h @@ -151,10 +151,12 @@ struct OpType { ReduceSum, // Reduce ReduceSumSquare, // Reduce Relu, // Unary + Silu, // Unary Reshape, Resize, ReverseSequence, RoiAlign, + RoPE, // Fusion Round, // Unary STFT, Scan, diff --git a/include/cuda/cuda_rope.h b/include/cuda/cuda_rope.h new file mode 100644 index 00000000..9766af5b --- /dev/null +++ b/include/cuda/cuda_rope.h @@ -0,0 +1,10 @@ +#pragma once + +#include "operators/rope.h" +#include "utils/small_array.h" + +namespace infini { + +void rope_kernel(int dType, int* pos, void *input, void *output, int size, int dim_model, int dim_head, int hidden_stride, int pos_stride); + +}; // namespace infini diff --git a/include/cuda/cuda_unary.h b/include/cuda/cuda_unary.h index 49a589b3..2f7ffbba 100644 --- a/include/cuda/cuda_unary.h +++ b/include/cuda/cuda_unary.h @@ -5,6 +5,7 @@ namespace infini { template void softmax_kernel(T *input, T *output, size_t num); template void relu_kernel(T *input, T *output, size_t num); +template void silu_kernel(T *input, T *output, size_t num); template void sigmoid_kernel(T *input, T *output, size_t num); template void tanh_kernel(T *input, T *output, size_t num); template void abs_kernel(T *input, T *output, size_t num); diff --git a/include/operators/rope.h b/include/operators/rope.h new file mode 100644 index 00000000..2eb312fe --- /dev/null +++ b/include/operators/rope.h @@ -0,0 +1,21 @@ +#pragma once +#include "core/operator.h" + +namespace infini { +class RoPEObj : public OperatorObj { + public: + RoPEObj(GraphObj *graph, Tensor pos, Tensor input, Tensor output); + OP_CLONE(RoPEObj); + optional> inferShape(const TensorVec &inputs) override; + + std::string toString() const override; + int numInputs() 
const override { return 2; } + int numOutputs() const override { return 1; } + DataType getDType() const { return getInputs(1)->getDType(); } + + private: + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; +}; + +} // namespace infini diff --git a/include/operators/unary.h b/include/operators/unary.h index c3e628d4..8da375de 100644 --- a/include/operators/unary.h +++ b/include/operators/unary.h @@ -258,6 +258,7 @@ class LogObj : public OperatorObj { }; DEFINE_UNARY_OBJ(Relu, OpType::Relu) +DEFINE_UNARY_OBJ(Silu, OpType::Silu) DEFINE_UNARY_OBJ(Gelu, OpType::Gelu) DEFINE_UNARY_OBJ(Sigmoid, OpType::Sigmoid) DEFINE_UNARY_OBJ(Tanh, OpType::Tanh) diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 79abb7f4..58993519 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -438,6 +438,11 @@ class OnnxStub: tensors[node.input[0]], tensors.get(node.output[0]), ) + elif node.op_type == "Silu": + tensors[node.output[0]] = self.handler.silu( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) elif node.op_type == "Gelu": tensors[node.output[0]] = self.handler.gelu( tensors[node.input[0]], @@ -669,6 +674,12 @@ class OnnxStub: tensors[node.input[5]], tensors.get(node.output[0]), ) + elif node.op_type == "RoPE": + tensors[node.output[0]]= self.handler.RoPE( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + ) elif node.op_type == "Split": split = ( _parse_data(data[node.input[1]]) diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index cd62ed32..e90fba4c 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -2,6 +2,7 @@ #include "operators/all_gather.h" #include "operators/all_reduce.h" #include "operators/attention_kvcache.h" +#include "operators/rope.h" #include "operators/batch_norm.h" #include "operators/broadcast.h" #include "operators/concat.h" @@ -180,7 +181,8 @@ DEFINE_ELEMENT_WISE_METHOD(max, Maximum) return g->addOp(std::move(x), y)->getOutput(); \ } \ } - + +DEFINE_UNARY_METHOD(silu, Silu) DEFINE_UNARY_METHOD(relu, Relu) DEFINE_UNARY_METHOD(gelu, Gelu) DEFINE_UNARY_METHOD(sigmoid, Sigmoid) @@ -345,6 +347,16 @@ Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache, } } +Tensor GraphHandlerObj::RoPE(Tensor pos, Tensor input, Tensor output) { + if (output) { + g->addOpWithOutputs(std::move(pos), std::move(input), output); + return output; + } else { + return g->addOp(std::move(pos), std::move(input), output) + ->getOutput(); + } +} + TensorVec GraphHandlerObj::split(Tensor input, std::optional outputs, int axis, std::variant> numOrRatio) { diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index b565ad4d..41200933 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -515,6 +515,7 @@ void init_graph_builder(py::module &m) { .def("min", &Handler::min, policy::move) .def("max", &Handler::max, policy::move) .def("relu", &Handler::relu, policy::move) + .def("silu", &Handler::silu, policy::move) .def("gelu", &Handler::gelu, policy::move) .def("sigmoid", &Handler::sigmoid, policy::move) .def("tanh", &Handler::tanh, policy::move) @@ -537,6 +538,7 @@ void init_graph_builder(py::module &m) { .def("unsqueeze", &Handler::unsqueeze, policy::move) .def("concat", &Handler::concat, policy::move) .def("attentionKVCache", &Handler::attentionKVCache, policy::move) + .def("RoPE", &Handler::RoPE, policy::move) .def("split", &Handler::split, 
policy::move) .def("gather", &Handler::gather, policy::move) .def("gatherElements", &Handler::gatherElements, policy::move) diff --git a/src/kernels/cuda/rope.cc b/src/kernels/cuda/rope.cc new file mode 100644 index 00000000..ca95c210 --- /dev/null +++ b/src/kernels/cuda/rope.cc @@ -0,0 +1,38 @@ +#include "operators/rope.h" +#include "cuda/cuda_kernel_wihtout_config.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_rope.h" + +namespace infini { + +class RoPECuda : public CudaKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + + auto pos = op->getInputs(0); + auto input = op->getInputs(1); + auto output = op->getOutput(); + void *const inputData = input->getRawDataPtr(); + void *const outputData = output->getRawDataPtr(); + const auto &inputShape = input->getDims(); + int nDims = input->getDims().size(); + + int size = input->size(); + IT_ASSERT(nDims == 3 && pos->getDims().size() == 2); + IT_ASSERT(inputShape[1] == pos->getDims()[1]); + int dim_model = inputShape[2]; + int dim_head = dim_model / 32; + int hidden_stride = dim_model * inputShape[1]; + int pos_stride = inputShape[1]; + + const int dType = op->getDType().getIndex(); + rope_kernel(dType, pos->getRawDataPtr(), inputData, outputData, size, dim_model, dim_head, hidden_stride, pos_stride); + } +}; + +REGISTER_KERNEL(Device::CUDA, OpType::RoPE, RoPECuda, + "RoPE_CUDA"); + + +} // namespace infini diff --git a/src/kernels/cuda/rope.cu b/src/kernels/cuda/rope.cu new file mode 100644 index 00000000..9b1bec54 --- /dev/null +++ b/src/kernels/cuda/rope.cu @@ -0,0 +1,91 @@ +#include "core/common.h" +#include "cuda/cuda_common.h" +#include "cuda/cuda_utility.h" +#include "utils/small_array.h" + +constexpr unsigned int num_threads() { return 32 * 4; } +constexpr int thread_work_size() { return 4; } +constexpr int block_work_size() { return thread_work_size() * num_threads(); } + +// gridDim (batch, seq_len, dim_model / 1024), blockDim (1024, 1, 1) +template +__global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model, int dim_head, int hidden_stride, int pos_stride) { + int batch_id = blockIdx.x; + int target_pos = pos[batch_id * pos_stride + blockIdx.y]; + int ith = blockIdx.z * blockDim.x + threadIdx.x; + int col = ith % dim_head; + int offset = batch_id * hidden_stride + blockIdx.y * dim_model; + + if (ith >= dim_model) + return; + int half_dim = dim_head / 2; + if (col < half_dim) { + float freq = target_pos * powf(10000, -float(col * 2) / dim_head); + float cos_freq = cos(freq); + float sin_freq = sin(freq); + ((T *)out)[offset + ith] = + ((T *)in)[offset + ith] * T(cos_freq) - ((T *)in)[offset + ith + half_dim] * T(sin_freq); + } else { + float freq = target_pos * powf(10000, -float((col - half_dim) * 2) / dim_head); + float cos_freq = cos(freq); + float sin_freq = sin(freq); + ((T *)out)[offset + ith] = + ((T *)in)[offset + ith] * T(cos_freq) + ((T *)in)[offset + ith - half_dim] * T(sin_freq); + } +} + + +#define CASE(T) \ + _rope_kernel::t><<>>( \ + pos, input, output, size, dim_model, dim_head, hidden_stride, pos_stride); + +#define SWITCH_DTYPE(DTYPE) \ + switch (DTYPE) { \ + case 1: \ + CASE(1) \ + break; \ + case 2: \ + CASE(2) \ + break; \ + case 3: \ + CASE(3) \ + break; \ + case 4: \ + CASE(4) \ + break; \ + case 5: \ + CASE(5) \ + break; \ + case 6: \ + CASE(6) \ + break; \ + case 7: \ + CASE(7) \ + break; \ + case 10: \ + CASE(10) \ + break; \ + case 11: \ + CASE(11) \ + break; \ + case 12: \ + CASE(12) \ + break; \ + 
case 13: \ + CASE(13) \ + break; \ + case 16: \ + CASE(16) \ + break; \ + default: \ + IT_TODO_HALT(); \ + } + +namespace infini { +void rope_kernel(int dType, int * pos, void *input, void *output, int size, int dim_model, int dim_head, int hidden_stride, int pos_stride) { + dim3 blocksize = dim3(1024,1,1); + dim3 gridsize = dim3(1, 1, 4); + SWITCH_DTYPE(dType) +} + +} // namespace infini diff --git a/src/kernels/cuda/unary.cc b/src/kernels/cuda/unary.cc index bb9691a7..acdb7579 100644 --- a/src/kernels/cuda/unary.cc +++ b/src/kernels/cuda/unary.cc @@ -157,6 +157,7 @@ class SoftmaxCudnn : public CudaKernelWithoutConfig { class ReluCudnn : public ActivationCudnn { cudnnActivationMode_t getOpType() const override { + return CUDNN_ACTIVATION_RELU; } }; @@ -182,6 +183,7 @@ REGISTER_KERNEL(Device::CUDA, OpType::Tanh, TanhCudnn, "Tanh_CUDA"); REGISTER_KERNEL(Device::CUDA, OpType::Abs, UnaryCuda, "Abs_CUDA"); REGISTER_KERNEL(Device::CUDA, OpType::Sqrt, UnaryCuda, "Sqrt_CUDA"); REGISTER_KERNEL(Device::CUDA, OpType::Gelu, UnaryCuda, "Gelu_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::Silu, UnaryCuda, "Silu_CUDA"); REGISTER_KERNEL(Device::CUDA, OpType::Neg, UnaryCuda, "Neg_CUDA"); REGISTER_KERNEL(Device::CUDA, OpType::Erf, UnaryCuda, "Erf_CUDA"); diff --git a/src/kernels/cuda/unary.cu b/src/kernels/cuda/unary.cu index afd7f02a..f7e755df 100644 --- a/src/kernels/cuda/unary.cu +++ b/src/kernels/cuda/unary.cu @@ -103,6 +103,17 @@ __global__ void _gelu_kernel(T *input, T *output, size_t n) { output[i] = 0.5 * x * (1 + erf(x / sqrt(2.0f))); } } + +template +__global__ void _silu_kernel(T *input, T *output, size_t n) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + for (int i = index; i < n; i += stride) { + float x = input[i]; + output[i] = x / (1.0 + expf(-x));; + } +} + template __global__ void _erf_kernel(T *input, T *output, size_t n) { size_t index = threadIdx.x + blockIdx.x * blockDim.x; @@ -190,6 +201,14 @@ template void gelu_kernel(T *input, T *output, size_t num) { int gridsize = (num + block_work_size() - 1) / block_work_size(); _gelu_kernel<<>>(input, output, num); } + +template void silu_kernel(T *input, T *output, size_t num) { + + int blocksize = block_work_size(); + int gridsize = (num + block_work_size() - 1) / block_work_size(); + _silu_kernel<<>>(input, output, num); +} + template void erf_kernel(T *input, T *output, size_t num) { int blocksize = block_work_size(); @@ -209,6 +228,7 @@ void unary_kernel(const Operator &_op) { void *const outputData = (op->getOutput()->getRawDataPtr()); size_t num = op->getOutput()->size(); + if (op->getOpType() == OpType::Softmax) { if (_op->getDType() == DataType::Float32) { softmax_kernel((float *)inputData, (float *)outputData, num); @@ -267,6 +287,12 @@ void unary_kernel(const Operator &_op) { } else { IT_TODO_HALT(); } + } else if (op->getOpType() == OpType::Silu) { + if (_op->getDType() == DataType::Float32) { + silu_kernel((float *)inputData, (float *)outputData, num); + } else { + IT_TODO_HALT(); + } } else if (op->getOpType() == OpType::Neg) { if (_op->getDType() == DataType::Float32) { neg_kernel((float *)inputData, (float *)outputData, num); diff --git a/src/operators/rope.cc b/src/operators/rope.cc new file mode 100644 index 00000000..76387bc4 --- /dev/null +++ b/src/operators/rope.cc @@ -0,0 +1,37 @@ +#include "operators/rope.h" + +namespace infini { +RoPEObj::RoPEObj(GraphObj *graph, Tensor pos, Tensor input, Tensor output) + : OperatorObj(OpType::RoPE, {pos, input}, {output}) { + 
IT_ASSERT(checkValid(graph)); +} + +optional> RoPEObj::inferShape(const TensorVec &inputs) { + const auto A = inputs[1]; + auto input_dim = A->getDims(); + auto output_dim = input_dim; + return {{output_dim}}; +} + +std::string RoPEObj::toString() const { + std::ostringstream os; + os << type.toString() << "[" << getGuid() << "]"; + os << "("; + os << vecToString(inputs[0]->getDims()) << ","; + os << "input=" << inputs[0]->getGuid() << ","; + os << "output=" << outputs[0]->getGuid() << ")"; + return os.str(); +} + +vector RoPEObj::getWorkloadVector() const { + vector ret{type.underlying()}; + const Shape shape = outputs[0]->getDims(); + ret.insert(ret.end(), shape.begin(), shape.end()); + return ret; +} + +vector RoPEObj::getOpAttrVector() const { + return {type.underlying()}; +} + +}; // namespace infini From 956ce3745878521b386d0d514356da290a09cdd5 Mon Sep 17 00:00:00 2001 From: xiaonans Date: Tue, 30 Jan 2024 10:40:13 +0800 Subject: [PATCH 4/6] add unittest of silu kernel --- src/kernels/cpu/unary.cc | 8 ++++++++ test/kernels/cuda/test_cuda_unary.cc | 1 + 2 files changed, 9 insertions(+) diff --git a/src/kernels/cpu/unary.cc b/src/kernels/cpu/unary.cc index 024d720a..9e7cead0 100644 --- a/src/kernels/cpu/unary.cc +++ b/src/kernels/cpu/unary.cc @@ -47,6 +47,10 @@ class NativeUnary : public CpuKernelWithoutConfig { return 0.5 * val * (1 + std::erf(val / std::sqrt(2))); } + template static T siluCompute(T val) { + return val / (1 + pow(E_CONSTANT, -val)); + } + template static T erfCompute(T val) { return std::erf(val); } template static T aCosCompute(T val) { return std::acos(val); } @@ -84,6 +88,9 @@ class NativeUnary : public CpuKernelWithoutConfig { case OpType::Gelu: _doCompute = geluCompute; break; + case OpType::Silu: + _doCompute = siluCompute; + break; case OpType::Sigmoid: _doCompute = sigmoidCompute; break; @@ -289,6 +296,7 @@ class Log : public CpuKernelWithoutConfig { REGISTER_KERNEL(Device::CPU, OpType::Relu, NativeUnary, "reluNaive_CPU"); REGISTER_KERNEL(Device::CPU, OpType::Gelu, NativeUnary, "geluNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Silu, NativeUnary, "siluNaive_CPU"); REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, NativeUnary, "sigmoidNaive_CPU"); REGISTER_KERNEL(Device::CPU, OpType::HardSigmoid, NativeUnary, "hardSigmoidNaive_CPU"); diff --git a/test/kernels/cuda/test_cuda_unary.cc b/test/kernels/cuda/test_cuda_unary.cc index fd407dfd..27ce90f1 100644 --- a/test/kernels/cuda/test_cuda_unary.cc +++ b/test/kernels/cuda/test_cuda_unary.cc @@ -70,6 +70,7 @@ void testCast(const std::function &generator, TEST(cuDNN_Unary, run) { testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); From 9a3c0f11f63b336d7a7351d7eccbe016daca4fe5 Mon Sep 17 00:00:00 2001 From: xiaonans Date: Tue, 30 Jan 2024 15:27:04 +0800 Subject: [PATCH 5/6] add test for rotary embedding cuda kernel --- include/core/op_type.h | 2 +- include/cuda/cuda_rope.h | 4 +++- src/core/graph_handler.cc | 2 +- src/kernels/cuda/rope.cc | 9 ++++---- src/operators/rope.cc | 4 +--- test/kernels/cuda/test_cuda_rope.cc | 36 +++++++++++++++++++++++++++++ 6 files changed, 46 insertions(+), 11 deletions(-) create mode 100644 test/kernels/cuda/test_cuda_rope.cc diff --git a/include/core/op_type.h b/include/core/op_type.h index d0d0e92a..dbcfbdb9 100644 --- a/include/core/op_type.h +++ 
b/include/core/op_type.h
@@ -156,7 +156,7 @@ struct OpType {
     Resize,
     ReverseSequence,
     RoiAlign,
-    RoPE, // Fusion
+    RoPE, // Fusion
     Round, // Unary
     STFT,
     Scan,
diff --git a/include/cuda/cuda_rope.h b/include/cuda/cuda_rope.h
index 9766af5b..ca9d5c54 100644
--- a/include/cuda/cuda_rope.h
+++ b/include/cuda/cuda_rope.h
@@ -5,6 +5,8 @@
 
 namespace infini {
 
-void rope_kernel(int dType, int* pos, void *input, void *output, int size, int dim_model, int dim_head, int hidden_stride, int pos_stride);
+void rope_kernel(int dType, int *pos, void *input, void *output, int size,
+                 int dim_model, int dim_head, int hidden_stride,
+                 int pos_stride);
 
 }; // namespace infini
diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc
index 48c31212..0821121d 100644
--- a/src/core/graph_handler.cc
+++ b/src/core/graph_handler.cc
@@ -2,7 +2,6 @@
 #include "operators/all_gather.h"
 #include "operators/all_reduce.h"
 #include "operators/attention_kvcache.h"
-#include "operators/rope.h"
 #include "operators/batch_norm.h"
 #include "operators/broadcast.h"
 #include "operators/concat.h"
@@ -19,6 +18,7 @@
 #include "operators/reduce.h"
 #include "operators/reshape.h"
 #include "operators/resize.h"
+#include "operators/rope.h"
 #include "operators/send.h"
 #include "operators/slice.h"
 #include "operators/softmax.h"
diff --git a/src/kernels/cuda/rope.cc b/src/kernels/cuda/rope.cc
index ca95c210..1ec5cca2 100644
--- a/src/kernels/cuda/rope.cc
+++ b/src/kernels/cuda/rope.cc
@@ -1,7 +1,7 @@
 #include "operators/rope.h"
 #include "cuda/cuda_kernel_wihtout_config.h"
-#include "cuda/cuda_runtime.h"
-#include "cuda/cuda_rope.h"
+#include "cuda/cuda_rope.h"
+#include "cuda/cuda_runtime.h"
 
 namespace infini {
 
@@ -27,12 +27,11 @@ class RoPECuda : public CudaKernelWithoutConfig {
         int pos_stride = inputShape[1];
 
         const int dType = op->getDType().getIndex();
-        rope_kernel(dType, pos->getRawDataPtr<int *>(), inputData, outputData, size, dim_model, dim_head, hidden_stride, pos_stride);
+        rope_kernel(dType, pos->getRawDataPtr<int *>(), inputData, outputData,
+                    size, dim_model, dim_head, hidden_stride, pos_stride);
     }
 };
 
-REGISTER_KERNEL(Device::CUDA, OpType::RoPE, RoPECuda,
-                "RoPE_CUDA");
-
+REGISTER_KERNEL(Device::CUDA, OpType::RoPE, RoPECuda, "RoPE_CUDA");
 
 } // namespace infini
diff --git a/src/operators/rope.cc b/src/operators/rope.cc
index 76387bc4..25dfa202 100644
--- a/src/operators/rope.cc
+++ b/src/operators/rope.cc
@@ -30,8 +30,6 @@ vector<int> RoPEObj::getWorkloadVector() const {
     return ret;
 }
 
-vector<int> RoPEObj::getOpAttrVector() const {
-    return {type.underlying()};
-}
+vector<int> RoPEObj::getOpAttrVector() const { return {type.underlying()}; }
 
 }; // namespace infini
diff --git a/test/kernels/cuda/test_cuda_rope.cc b/test/kernels/cuda/test_cuda_rope.cc
new file mode 100644
index 00000000..8d88bf8e
--- /dev/null
+++ b/test/kernels/cuda/test_cuda_rope.cc
@@ -0,0 +1,36 @@
+#include "core/graph.h"
+#include "core/runtime.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/cuda_utility.h"
+#include "operators/rope.h"
+
+#include "test.h"
+
+namespace infini {
+TEST(RoPE, Cuda) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+
+    Graph gCpu = make_ref<GraphObj>(runtime);
+
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cudaRuntime);
+    auto input = gCuda->addTensor({1, 1, 32}, DataType::Float32);
+    auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32);
+    auto output = gCuda->addTensor({1, 1, 32}, DataType::Float32);
+
+    auto op = gCuda->addOpWithOutputs<RoPEObj>(position_id_d, input, output);
+    gCuda->dataMalloc();
+
+    input->setData(OneGenerator());
+    position_id_d->setData(OneGenerator());
+    cudaRuntime->run(gCuda);
+
+    auto oCpu = gCpu->cloneTensor(op->getOutputs()[0]);
+    EXPECT_TRUE(oCpu->equalData(vector<float>{
+        1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
+        1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
+        1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
+        1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
+        1.381773, 1.381773, 1.381773, 1.381773}));
+}
+} // namespace infini

From ae9f61de5ae1fc26c9257b1e4ca71eca4528ca1d Mon Sep 17 00:00:00 2001
From: xiaonans
Date: Sun, 4 Feb 2024 10:40:25 +0800
Subject: [PATCH 6/6] add comment for rope operator

---
 include/operators/rope.h | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/include/operators/rope.h b/include/operators/rope.h
index 2eb312fe..b21adb24 100644
--- a/include/operators/rope.h
+++ b/include/operators/rope.h
@@ -4,6 +4,14 @@
 namespace infini {
 class RoPEObj : public OperatorObj {
   public:
+    /**
+     * @brief Construct a new RotaryEmbedding object.
+     *
+     * @param graph The computation graph that this operator belongs to.
+     * @param pos The positon id of the query.
+     * @param input The input tensor.
+     * @param output The output tensor.
+     */
     RoPEObj(GraphObj *graph, Tensor pos, Tensor input, Tensor output);
     OP_CLONE(RoPEObj);
     optional<vector<Shape>> inferShape(const TensorVec &inputs) override;