forked from jiuyuan/InfiniTensor
add test for rotary embedding cuda kernel
This commit is contained in:
parent
956ce37458
commit
9a3c0f11f6
|
@ -156,7 +156,7 @@ struct OpType {
|
|||
Resize,
|
||||
ReverseSequence,
|
||||
RoiAlign,
|
||||
RoPE, // Fusion
|
||||
RoPE, // Fusion
|
||||
Round, // Unary
|
||||
STFT,
|
||||
Scan,
|
||||
|
|
|
@ -5,6 +5,8 @@
|
|||
|
||||
namespace infini {
|
||||
|
||||
void rope_kernel(int dType, int* pos, void *input, void *output, int size, int dim_model, int dim_head, int hidden_stride, int pos_stride);
|
||||
void rope_kernel(int dType, int *pos, void *input, void *output, int size,
|
||||
int dim_model, int dim_head, int hidden_stride,
|
||||
int pos_stride);
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
#include "operators/all_gather.h"
|
||||
#include "operators/all_reduce.h"
|
||||
#include "operators/attention_kvcache.h"
|
||||
#include "operators/rope.h"
|
||||
#include "operators/batch_norm.h"
|
||||
#include "operators/broadcast.h"
|
||||
#include "operators/concat.h"
|
||||
|
@ -19,6 +18,7 @@
|
|||
#include "operators/reduce.h"
|
||||
#include "operators/reshape.h"
|
||||
#include "operators/resize.h"
|
||||
#include "operators/rope.h"
|
||||
#include "operators/send.h"
|
||||
#include "operators/slice.h"
|
||||
#include "operators/softmax.h"
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#include "operators/rope.h"
|
||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_rope.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
|
@ -27,12 +27,11 @@ class RoPECuda : public CudaKernelWithoutConfig {
|
|||
int pos_stride = inputShape[1];
|
||||
|
||||
const int dType = op->getDType().getIndex();
|
||||
rope_kernel(dType, pos->getRawDataPtr<int *>(), inputData, outputData, size, dim_model, dim_head, hidden_stride, pos_stride);
|
||||
rope_kernel(dType, pos->getRawDataPtr<int *>(), inputData, outputData,
|
||||
size, dim_model, dim_head, hidden_stride, pos_stride);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::RoPE, RoPECuda,
|
||||
"RoPE_CUDA");
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::RoPE, RoPECuda, "RoPE_CUDA");
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -30,8 +30,6 @@ vector<int> RoPEObj::getWorkloadVector() const {
|
|||
return ret;
|
||||
}
|
||||
|
||||
vector<int> RoPEObj::getOpAttrVector() const {
|
||||
return {type.underlying()};
|
||||
}
|
||||
vector<int> RoPEObj::getOpAttrVector() const { return {type.underlying()}; }
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/rope.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
TEST(RoPE, Cuda) {
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
auto input = gCuda->addTensor({1, 1, 32}, DataType::Float32);
|
||||
auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32);
|
||||
auto output = gCuda->addTensor({1, 1, 32}, DataType::Float32);
|
||||
|
||||
auto op = gCuda->addOpWithOutputs<RoPEObj>(position_id_d, input, output);
|
||||
gCuda->dataMalloc();
|
||||
|
||||
input->setData(OneGenerator());
|
||||
position_id_d->setData(OneGenerator());
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutputs()[0]);
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{
|
||||
1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
|
||||
1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
|
||||
1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
|
||||
1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
|
||||
1.381773, 1.381773, 1.381773, 1.381773}));
|
||||
}
|
||||
} // namespace infini
|
Loading…
Reference in New Issue