From 9a3c0f11f63b336d7a7351d7eccbe016daca4fe5 Mon Sep 17 00:00:00 2001 From: xiaonans Date: Tue, 30 Jan 2024 15:27:04 +0800 Subject: [PATCH] add test for rotary embedding cuda kernel --- include/core/op_type.h | 2 +- include/cuda/cuda_rope.h | 4 +++- src/core/graph_handler.cc | 2 +- src/kernels/cuda/rope.cc | 9 ++++---- src/operators/rope.cc | 4 +--- test/kernels/cuda/test_cuda_rope.cc | 36 +++++++++++++++++++++++++++++ 6 files changed, 46 insertions(+), 11 deletions(-) create mode 100644 test/kernels/cuda/test_cuda_rope.cc diff --git a/include/core/op_type.h b/include/core/op_type.h index d0d0e92a..dbcfbdb9 100644 --- a/include/core/op_type.h +++ b/include/core/op_type.h @@ -156,7 +156,7 @@ struct OpType { Resize, ReverseSequence, RoiAlign, - RoPE, // Fusion + RoPE, // Fusion Round, // Unary STFT, Scan, diff --git a/include/cuda/cuda_rope.h b/include/cuda/cuda_rope.h index 9766af5b..ca9d5c54 100644 --- a/include/cuda/cuda_rope.h +++ b/include/cuda/cuda_rope.h @@ -5,6 +5,8 @@ namespace infini { -void rope_kernel(int dType, int* pos, void *input, void *output, int size, int dim_model, int dim_head, int hidden_stride, int pos_stride); +void rope_kernel(int dType, int *pos, void *input, void *output, int size, + int dim_model, int dim_head, int hidden_stride, + int pos_stride); }; // namespace infini diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index 48c31212..0821121d 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -2,7 +2,6 @@ #include "operators/all_gather.h" #include "operators/all_reduce.h" #include "operators/attention_kvcache.h" -#include "operators/rope.h" #include "operators/batch_norm.h" #include "operators/broadcast.h" #include "operators/concat.h" @@ -19,6 +18,7 @@ #include "operators/reduce.h" #include "operators/reshape.h" #include "operators/resize.h" +#include "operators/rope.h" #include "operators/send.h" #include "operators/slice.h" #include "operators/softmax.h" diff --git a/src/kernels/cuda/rope.cc b/src/kernels/cuda/rope.cc index ca95c210..1ec5cca2 100644 --- a/src/kernels/cuda/rope.cc +++ b/src/kernels/cuda/rope.cc @@ -1,7 +1,7 @@ #include "operators/rope.h" #include "cuda/cuda_kernel_wihtout_config.h" -#include "cuda/cuda_runtime.h" #include "cuda/cuda_rope.h" +#include "cuda/cuda_runtime.h" namespace infini { @@ -27,12 +27,11 @@ class RoPECuda : public CudaKernelWithoutConfig { int pos_stride = inputShape[1]; const int dType = op->getDType().getIndex(); - rope_kernel(dType, pos->getRawDataPtr(), inputData, outputData, size, dim_model, dim_head, hidden_stride, pos_stride); + rope_kernel(dType, pos->getRawDataPtr(), inputData, outputData, + size, dim_model, dim_head, hidden_stride, pos_stride); } }; -REGISTER_KERNEL(Device::CUDA, OpType::RoPE, RoPECuda, - "RoPE_CUDA"); - +REGISTER_KERNEL(Device::CUDA, OpType::RoPE, RoPECuda, "RoPE_CUDA"); } // namespace infini diff --git a/src/operators/rope.cc b/src/operators/rope.cc index 76387bc4..25dfa202 100644 --- a/src/operators/rope.cc +++ b/src/operators/rope.cc @@ -30,8 +30,6 @@ vector RoPEObj::getWorkloadVector() const { return ret; } -vector RoPEObj::getOpAttrVector() const { - return {type.underlying()}; -} +vector RoPEObj::getOpAttrVector() const { return {type.underlying()}; } }; // namespace infini diff --git a/test/kernels/cuda/test_cuda_rope.cc b/test/kernels/cuda/test_cuda_rope.cc new file mode 100644 index 00000000..8d88bf8e --- /dev/null +++ b/test/kernels/cuda/test_cuda_rope.cc @@ -0,0 +1,36 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/rope.h" + +#include "test.h" + +namespace infini { +TEST(RoPE, Cuda) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + + Graph gCpu = make_ref(runtime); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + auto input = gCuda->addTensor({1, 1, 32}, DataType::Float32); + auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32); + auto output = gCuda->addTensor({1, 1, 32}, DataType::Float32); + + auto op = gCuda->addOpWithOutputs(position_id_d, input, output); + gCuda->dataMalloc(); + + input->setData(OneGenerator()); + position_id_d->setData(OneGenerator()); + cudaRuntime->run(gCuda); + + auto oCpu = gCpu->cloneTensor(op->getOutputs()[0]); + EXPECT_TRUE(oCpu->equalData(vector{ + 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, + 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, + 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, + 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, + 1.381773, 1.381773, 1.381773, 1.381773})); +} +} // namespace infini