add test for rotary embedding cuda kernel

xiaonans 2024-01-30 15:27:04 +08:00
parent 956ce37458
commit 9a3c0f11f6
6 changed files with 46 additions and 11 deletions

@@ -156,7 +156,7 @@ struct OpType {
     Resize,
     ReverseSequence,
     RoiAlign,
     RoPE, // Fusion
     Round, // Unary
     STFT,
     Scan,

@@ -5,6 +5,8 @@
 namespace infini {
-void rope_kernel(int dType, int* pos, void *input, void *output, int size, int dim_model, int dim_head, int hidden_stride, int pos_stride);
+void rope_kernel(int dType, int *pos, void *input, void *output, int size,
+                 int dim_model, int dim_head, int hidden_stride,
+                 int pos_stride);
 }; // namespace infini

@@ -2,7 +2,6 @@
 #include "operators/all_gather.h"
 #include "operators/all_reduce.h"
 #include "operators/attention_kvcache.h"
-#include "operators/rope.h"
 #include "operators/batch_norm.h"
 #include "operators/broadcast.h"
 #include "operators/concat.h"
@@ -19,6 +18,7 @@
 #include "operators/reduce.h"
 #include "operators/reshape.h"
 #include "operators/resize.h"
+#include "operators/rope.h"
 #include "operators/send.h"
 #include "operators/slice.h"
 #include "operators/softmax.h"

@@ -1,7 +1,7 @@
 #include "operators/rope.h"
 #include "cuda/cuda_kernel_wihtout_config.h"
-#include "cuda/cuda_runtime.h"
 #include "cuda/cuda_rope.h"
+#include "cuda/cuda_runtime.h"

 namespace infini {
@@ -27,12 +27,11 @@ class RoPECuda : public CudaKernelWithoutConfig {
         int pos_stride = inputShape[1];

         const int dType = op->getDType().getIndex();
-        rope_kernel(dType, pos->getRawDataPtr<int *>(), inputData, outputData, size, dim_model, dim_head, hidden_stride, pos_stride);
+        rope_kernel(dType, pos->getRawDataPtr<int *>(), inputData, outputData,
+                    size, dim_model, dim_head, hidden_stride, pos_stride);
     }
 };

-REGISTER_KERNEL(Device::CUDA, OpType::RoPE, RoPECuda,
-                "RoPE_CUDA");
+REGISTER_KERNEL(Device::CUDA, OpType::RoPE, RoPECuda, "RoPE_CUDA");

 } // namespace infini

@@ -30,8 +30,6 @@ vector<int> RoPEObj::getWorkloadVector() const {
     return ret;
 }

-vector<int> RoPEObj::getOpAttrVector() const {
-    return {type.underlying()};
-}
+vector<int> RoPEObj::getOpAttrVector() const { return {type.underlying()}; }

 }; // namespace infini

@@ -0,0 +1,36 @@
+#include "core/graph.h"
+#include "core/runtime.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/cuda_utility.h"
+#include "operators/rope.h"
+
+#include "test.h"
+
+namespace infini {
+TEST(RoPE, Cuda) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    Graph gCpu = make_ref<GraphObj>(runtime);
+
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cudaRuntime);
+
+    auto input = gCuda->addTensor({1, 1, 32}, DataType::Float32);
+    auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32);
+    auto output = gCuda->addTensor({1, 1, 32}, DataType::Float32);
+
+    auto op = gCuda->addOpWithOutputs<RoPEObj>(position_id_d, input, output);
+    gCuda->dataMalloc();
+
+    input->setData(OneGenerator());
+    position_id_d->setData(OneGenerator());
+    cudaRuntime->run(gCuda);
+
+    auto oCpu = gCpu->cloneTensor(op->getOutputs()[0]);
+    EXPECT_TRUE(oCpu->equalData(vector<float>{
+        1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
+        1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
+        1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
+        1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
+        1.381773, 1.381773, 1.381773, 1.381773}));
+}
+} // namespace infini
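
As a note on the reference data: the expected constant 1.381773 in the new test equals cos(1) + sin(1), which is consistent with the all-ones input being rotated by an angle equal to the position id (1) for this shape; that reading of the kernel's angle computation is an assumption, not something stated in the diff. A minimal standalone check of the arithmetic:

#include <cmath>
#include <cstdio>

// Arithmetic sanity check (not part of the commit): with OneGenerator()
// filling both the input tensor and the position id, 1 radian is the only
// angle in play, and cos(1) + sin(1) reproduces the expected test value.
int main() {
    float expected = std::cos(1.0f) + std::sin(1.0f);
    std::printf("%.6f\n", expected); // prints 1.381773
    return 0;
}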