diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index 0e1472bb..36486e36 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -47,6 +47,7 @@ class GraphHandlerObj {
     Tensor max(Tensor a, Tensor b, Tensor c);
 
     Tensor relu(Tensor x, Tensor y);
+    Tensor silu(Tensor x, Tensor y);
     Tensor gelu(Tensor x, Tensor y);
     Tensor sigmoid(Tensor x, Tensor y);
     Tensor hardSigmoid(Tensor x, Tensor y);
@@ -77,6 +78,7 @@ class GraphHandlerObj {
     Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
                             Tensor input_q, Tensor input_k, Tensor input_v,
                             Tensor position_id, Tensor output_matmul);
+    Tensor RoPE(Tensor pos, Tensor input, Tensor output);
     TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
                     std::variant<int, vector<int>> numOrRatio);
     Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);
diff --git a/include/core/op_type.h b/include/core/op_type.h
index 1652a677..dbcfbdb9 100644
--- a/include/core/op_type.h
+++ b/include/core/op_type.h
@@ -151,10 +151,12 @@ struct OpType {
        ReduceSum,        // Reduce
        ReduceSumSquare,  // Reduce
        Relu,             // Unary
+       Silu,             // Unary
        Reshape,
        Resize,
        ReverseSequence,
        RoiAlign,
+       RoPE,             // Fusion
        Round,            // Unary
        STFT,
        Scan,
diff --git a/include/cuda/cuda_rope.h b/include/cuda/cuda_rope.h
new file mode 100644
index 00000000..ca9d5c54
--- /dev/null
+++ b/include/cuda/cuda_rope.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include "operators/rope.h"
+#include "utils/small_array.h"
+
+namespace infini {
+
+void rope_kernel(int dType, int *pos, void *input, void *output, int size,
+                 int dim_model, int dim_head, int hidden_stride,
+                 int pos_stride);
+
+}; // namespace infini
diff --git a/include/cuda/cuda_unary.h b/include/cuda/cuda_unary.h
index 49a589b3..2f7ffbba 100644
--- a/include/cuda/cuda_unary.h
+++ b/include/cuda/cuda_unary.h
@@ -5,6 +5,7 @@ namespace infini {
 template <typename T> void softmax_kernel(T *input, T *output, size_t num);
 template <typename T> void relu_kernel(T *input, T *output, size_t num);
+template <typename T> void silu_kernel(T *input, T *output, size_t num);
 template <typename T> void sigmoid_kernel(T *input, T *output, size_t num);
 template <typename T> void tanh_kernel(T *input, T *output, size_t num);
 template <typename T> void abs_kernel(T *input, T *output, size_t num);
diff --git a/include/operators/rope.h b/include/operators/rope.h
new file mode 100644
index 00000000..b21adb24
--- /dev/null
+++ b/include/operators/rope.h
@@ -0,0 +1,29 @@
+#pragma once
+#include "core/operator.h"
+
+namespace infini {
+class RoPEObj : public OperatorObj {
+  public:
+    /**
+     * @brief Construct a new RotaryEmbedding object.
+     *
+     * @param graph The computation graph that this operator belongs to.
+     * @param pos The position id of the query.
+     * @param input The input tensor.
+     * @param output The output tensor.
+     */
+    RoPEObj(GraphObj *graph, Tensor pos, Tensor input, Tensor output);
+    OP_CLONE(RoPEObj);
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
+
+    std::string toString() const override;
+    int numInputs() const override { return 2; }
+    int numOutputs() const override { return 1; }
+    DataType getDType() const { return getInputs(1)->getDType(); }
+
+  private:
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+};
+
+} // namespace infini
diff --git a/include/operators/unary.h b/include/operators/unary.h
index c3e628d4..8da375de 100644
--- a/include/operators/unary.h
+++ b/include/operators/unary.h
@@ -258,6 +258,7 @@ class LogObj : public OperatorObj {
 };
 
 DEFINE_UNARY_OBJ(Relu, OpType::Relu)
+DEFINE_UNARY_OBJ(Silu, OpType::Silu)
 DEFINE_UNARY_OBJ(Gelu, OpType::Gelu)
 DEFINE_UNARY_OBJ(Sigmoid, OpType::Sigmoid)
 DEFINE_UNARY_OBJ(Tanh, OpType::Tanh)
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index 79abb7f4..58993519 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -438,6 +438,11 @@ class OnnxStub:
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )
+            elif node.op_type == "Silu":
+                tensors[node.output[0]] = self.handler.silu(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                )
             elif node.op_type == "Gelu":
                 tensors[node.output[0]] = self.handler.gelu(
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )
@@ -669,6 +674,12 @@ class OnnxStub:
                     tensors[node.input[5]],
                     tensors.get(node.output[0]),
                 )
+            elif node.op_type == "RoPE":
+                tensors[node.output[0]] = self.handler.RoPE(
+                    tensors[node.input[0]],
+                    tensors[node.input[1]],
+                    tensors.get(node.output[0]),
+                )
             elif node.op_type == "Split":
                 split = (
                     _parse_data(data[node.input[1]])
diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc
index cd62ed32..0821121d 100644
--- a/src/core/graph_handler.cc
+++ b/src/core/graph_handler.cc
@@ -18,6 +18,7 @@
 #include "operators/reduce.h"
 #include "operators/reshape.h"
 #include "operators/resize.h"
+#include "operators/rope.h"
 #include "operators/send.h"
 #include "operators/slice.h"
 #include "operators/softmax.h"
@@ -181,6 +182,7 @@ DEFINE_ELEMENT_WISE_METHOD(max, Maximum)
         }                                                                      \
     }
 
+DEFINE_UNARY_METHOD(silu, Silu)
 DEFINE_UNARY_METHOD(relu, Relu)
 DEFINE_UNARY_METHOD(gelu, Gelu)
 DEFINE_UNARY_METHOD(sigmoid, Sigmoid)
@@ -345,6 +347,16 @@ Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache,
     }
 }
 
+Tensor GraphHandlerObj::RoPE(Tensor pos, Tensor input, Tensor output) {
+    if (output) {
+        g->addOpWithOutputs<RoPEObj>(std::move(pos), std::move(input), output);
+        return output;
+    } else {
+        return g->addOp<RoPEObj>(std::move(pos), std::move(input), output)
+            ->getOutput();
+    }
+}
+
 TensorVec GraphHandlerObj::split(Tensor input, std::optional<TensorVec> outputs,
                                  int axis,
                                  std::variant<int, vector<int>> numOrRatio) {
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index b565ad4d..41200933 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -515,6 +515,7 @@ void init_graph_builder(py::module &m) {
         .def("min", &Handler::min, policy::move)
         .def("max", &Handler::max, policy::move)
         .def("relu", &Handler::relu, policy::move)
+        .def("silu", &Handler::silu, policy::move)
         .def("gelu", &Handler::gelu, policy::move)
         .def("sigmoid", &Handler::sigmoid, policy::move)
         .def("tanh", &Handler::tanh, policy::move)
@@ -537,6 +538,7 @@ void init_graph_builder(py::module &m) {
         .def("unsqueeze", &Handler::unsqueeze, policy::move)
         .def("concat", &Handler::concat, policy::move)
         .def("attentionKVCache", &Handler::attentionKVCache, policy::move)
+        .def("RoPE", &Handler::RoPE, policy::move)
         .def("split", &Handler::split, policy::move)
         .def("gather", &Handler::gather, policy::move)
         .def("gatherElements", &Handler::gatherElements, policy::move)
diff --git a/src/kernels/cpu/unary.cc b/src/kernels/cpu/unary.cc
index 024d720a..9e7cead0 100644
--- a/src/kernels/cpu/unary.cc
+++ b/src/kernels/cpu/unary.cc
@@ -47,6 +47,10 @@ class NativeUnary : public CpuKernelWithoutConfig {
         return 0.5 * val * (1 + std::erf(val / std::sqrt(2)));
     }
 
+    template <typename T> static T siluCompute(T val) {
+        return val / (1 + pow(E_CONSTANT, -val));
+    }
+
     template <typename T> static T erfCompute(T val) { return std::erf(val); }
 
     template <typename T> static T aCosCompute(T val) { return std::acos(val); }
@@ -84,6 +88,9 @@ class NativeUnary : public CpuKernelWithoutConfig {
         case OpType::Gelu:
             _doCompute = geluCompute<T>;
             break;
+        case OpType::Silu:
+            _doCompute = siluCompute<T>;
+            break;
         case OpType::Sigmoid:
             _doCompute = sigmoidCompute<T>;
             break;
@@ -289,6 +296,7 @@ class Log : public CpuKernelWithoutConfig {
 
 REGISTER_KERNEL(Device::CPU, OpType::Relu, NativeUnary, "reluNaive_CPU");
 REGISTER_KERNEL(Device::CPU, OpType::Gelu, NativeUnary, "geluNaive_CPU");
+REGISTER_KERNEL(Device::CPU, OpType::Silu, NativeUnary, "siluNaive_CPU");
 REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, NativeUnary, "sigmoidNaive_CPU");
 REGISTER_KERNEL(Device::CPU, OpType::HardSigmoid, NativeUnary,
                 "hardSigmoidNaive_CPU");
diff --git a/src/kernels/cuda/rope.cc b/src/kernels/cuda/rope.cc
new file mode 100644
index 00000000..1ec5cca2
--- /dev/null
+++ b/src/kernels/cuda/rope.cc
@@ -0,0 +1,37 @@
+#include "operators/rope.h"
+#include "cuda/cuda_kernel_wihtout_config.h"
+#include "cuda/cuda_rope.h"
+#include "cuda/cuda_runtime.h"
+
+namespace infini {
+
+class RoPECuda : public CudaKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<RoPEObj>(_op);
+
+        auto pos = op->getInputs(0);
+        auto input = op->getInputs(1);
+        auto output = op->getOutput();
+        void *const inputData = input->getRawDataPtr<void *>();
+        void *const outputData = output->getRawDataPtr<void *>();
+        const auto &inputShape = input->getDims();
+        int nDims = input->getDims().size();
+
+        int size = input->size();
+        IT_ASSERT(nDims == 3 && pos->getDims().size() == 2);
+        IT_ASSERT(inputShape[1] == pos->getDims()[1]);
+        int dim_model = inputShape[2];
+        int dim_head = dim_model / 32;
+        int hidden_stride = dim_model * inputShape[1];
+        int pos_stride = inputShape[1];
+
+        const int dType = op->getDType().getIndex();
+        rope_kernel(dType, pos->getRawDataPtr<int *>(), inputData, outputData,
+                    size, dim_model, dim_head, hidden_stride, pos_stride);
+    }
+};
+
+REGISTER_KERNEL(Device::CUDA, OpType::RoPE, RoPECuda, "RoPE_CUDA");
+
+} // namespace infini
diff --git a/src/kernels/cuda/rope.cu b/src/kernels/cuda/rope.cu
new file mode 100644
index 00000000..9b1bec54
--- /dev/null
+++ b/src/kernels/cuda/rope.cu
@@ -0,0 +1,91 @@
+#include "core/common.h"
+#include "cuda/cuda_common.h"
+#include "cuda/cuda_utility.h"
+#include "utils/small_array.h"
+
+constexpr unsigned int num_threads() { return 32 * 4; }
+constexpr int thread_work_size() { return 4; }
+constexpr int block_work_size() { return thread_work_size() * num_threads(); }
+
+// gridDim (batch, seq_len, dim_model / 1024), blockDim (1024, 1, 1)
+template <class T>
+__global__ void _rope_kernel(int *pos, void *in, void *out, int size,
+                             int dim_model, int dim_head, int hidden_stride,
+                             int pos_stride) {
+    int batch_id = blockIdx.x;
+    int target_pos = pos[batch_id * pos_stride + blockIdx.y];
+    int ith = blockIdx.z * blockDim.x + threadIdx.x;
+    int col = ith % dim_head;
+    int offset = batch_id * hidden_stride + blockIdx.y * dim_model;
+
+    if (ith >= dim_model)
+        return;
+    int half_dim = dim_head / 2;
+    if (col < half_dim) {
+        float freq = target_pos * powf(10000, -float(col * 2) / dim_head);
+        float cos_freq = cos(freq);
+        float sin_freq = sin(freq);
+        ((T *)out)[offset + ith] =
+            ((T *)in)[offset + ith] * T(cos_freq) -
+            ((T *)in)[offset + ith + half_dim] * T(sin_freq);
+    } else {
+        float freq = target_pos *
+                     powf(10000, -float((col - half_dim) * 2) / dim_head);
+        float cos_freq = cos(freq);
+        float sin_freq = sin(freq);
+        ((T *)out)[offset + ith] =
+            ((T *)in)[offset + ith] * T(cos_freq) +
+            ((T *)in)[offset + ith - half_dim] * T(sin_freq);
+    }
+}
+
+#define CASE(T)                                                                \
+    _rope_kernel<DT_CUDA<T>::t><<<gridsize, blocksize>>>(                      \
+        pos, input, output, size, dim_model, dim_head, hidden_stride,          \
+        pos_stride);
+
+#define SWITCH_DTYPE(DTYPE)                                                    \
+    switch (DTYPE) {                                                           \
+    case 1:                                                                    \
+        CASE(1)                                                                \
+        break;                                                                 \
+    case 2:                                                                    \
+        CASE(2)                                                                \
+        break;                                                                 \
+    case 3:                                                                    \
+        CASE(3)                                                                \
+        break;                                                                 \
+    case 4:                                                                    \
+        CASE(4)                                                                \
+        break;                                                                 \
+    case 5:                                                                    \
+        CASE(5)                                                                \
+        break;                                                                 \
+    case 6:                                                                    \
+        CASE(6)                                                                \
+        break;                                                                 \
+    case 7:                                                                    \
+        CASE(7)                                                                \
+        break;                                                                 \
+    case 10:                                                                   \
+        CASE(10)                                                               \
+        break;                                                                 \
+    case 11:                                                                   \
+        CASE(11)                                                               \
+        break;                                                                 \
+    case 12:                                                                   \
+        CASE(12)                                                               \
+        break;                                                                 \
+    case 13:                                                                   \
+        CASE(13)                                                               \
+        break;                                                                 \
+    case 16:                                                                   \
+        CASE(16)                                                               \
+        break;                                                                 \
+    default:                                                                   \
+        IT_TODO_HALT();                                                        \
+    }
+
+namespace infini {
+void rope_kernel(int dType, int *pos, void *input, void *output, int size,
+                 int dim_model, int dim_head, int hidden_stride,
+                 int pos_stride) {
+    dim3 blocksize = dim3(1024, 1, 1);
+    dim3 gridsize = dim3(1, 1, 4);
+    SWITCH_DTYPE(dType)
+}
+
+} // namespace infini
diff --git a/src/kernels/cuda/unary.cc b/src/kernels/cuda/unary.cc
index bb9691a7..3bbd5bba 100644
--- a/src/kernels/cuda/unary.cc
+++ b/src/kernels/cuda/unary.cc
@@ -182,6 +182,7 @@ REGISTER_KERNEL(Device::CUDA, OpType::Tanh, TanhCudnn, "Tanh_CUDA");
 REGISTER_KERNEL(Device::CUDA, OpType::Abs, UnaryCuda, "Abs_CUDA");
 REGISTER_KERNEL(Device::CUDA, OpType::Sqrt, UnaryCuda, "Sqrt_CUDA");
 REGISTER_KERNEL(Device::CUDA, OpType::Gelu, UnaryCuda, "Gelu_CUDA");
+REGISTER_KERNEL(Device::CUDA, OpType::Silu, UnaryCuda, "Silu_CUDA");
 REGISTER_KERNEL(Device::CUDA, OpType::Neg, UnaryCuda, "Neg_CUDA");
 REGISTER_KERNEL(Device::CUDA, OpType::Erf, UnaryCuda, "Erf_CUDA");
diff --git a/src/kernels/cuda/unary.cu b/src/kernels/cuda/unary.cu
index afd7f02a..98f1ed9f 100644
--- a/src/kernels/cuda/unary.cu
+++ b/src/kernels/cuda/unary.cu
@@ -103,6 +103,17 @@ __global__ void _gelu_kernel(T *input, T *output, size_t n) {
         output[i] = 0.5 * x * (1 + erf(x / sqrt(2.0f)));
     }
 }
+
+template <typename T>
+__global__ void _silu_kernel(T *input, T *output, size_t n) {
+    size_t index = threadIdx.x + blockIdx.x * blockDim.x;
+    size_t stride = blockDim.x * gridDim.x;
+    for (size_t i = index; i < n; i += stride) {
+        float x = input[i];
+        output[i] = x / (1.0 + expf(-x));
+    }
+}
+
 template <typename T>
 __global__ void _erf_kernel(T *input, T *output, size_t n) {
     size_t index = threadIdx.x + blockIdx.x * blockDim.x;
@@ -190,6 +201,14 @@ template <typename T> void gelu_kernel(T *input, T *output, size_t num) {
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _gelu_kernel<T><<<gridsize, blocksize>>>(input, output, num);
 }
+
+template <typename T> void silu_kernel(T *input, T *output, size_t num) {
+
+    int blocksize = block_work_size();
+    int gridsize = (num + block_work_size() - 1) / block_work_size();
+    _silu_kernel<T><<<gridsize, blocksize>>>(input, output, num);
+}
+
 template <typename T> void erf_kernel(T *input, T *output, size_t num) {
 
     int blocksize = block_work_size();
@@ -267,6 +286,12 @@ void unary_kernel(const Operator &_op) {
         } else {
             IT_TODO_HALT();
         }
+    } else if (op->getOpType() == OpType::Silu) {
+        if (_op->getDType() == DataType::Float32) {
+            silu_kernel<float>((float *)inputData, (float *)outputData, num);
+        } else {
+            IT_TODO_HALT();
+        }
     } else if (op->getOpType() == OpType::Neg) {
         if (_op->getDType() == DataType::Float32) {
             neg_kernel<float>((float *)inputData, (float *)outputData, num);
diff --git a/src/operators/rope.cc b/src/operators/rope.cc
new file mode 100644
index 00000000..25dfa202
--- /dev/null
+++ b/src/operators/rope.cc
@@ -0,0 +1,35 @@
+#include "operators/rope.h"
+
+namespace infini {
+RoPEObj::RoPEObj(GraphObj *graph, Tensor pos, Tensor input, Tensor output)
+    : OperatorObj(OpType::RoPE, {pos, input}, {output}) {
+    IT_ASSERT(checkValid(graph));
+}
+
+optional<vector<Shape>> RoPEObj::inferShape(const TensorVec &inputs) {
+    const auto A = inputs[1];
+    auto input_dim = A->getDims();
+    auto output_dim = input_dim;
+    return {{output_dim}};
+}
+
+std::string RoPEObj::toString() const {
+    std::ostringstream os;
+    os << type.toString() << "[" << getGuid() << "]";
+    os << "(";
+    os << vecToString(inputs[0]->getDims()) << ",";
+    os << "input=" << inputs[0]->getGuid() << ",";
+    os << "output=" << outputs[0]->getGuid() << ")";
+    return os.str();
+}
+
+vector<int> RoPEObj::getWorkloadVector() const {
+    vector<int> ret{type.underlying()};
+    const Shape shape = outputs[0]->getDims();
+    ret.insert(ret.end(), shape.begin(), shape.end());
+    return ret;
+}
+
+vector<int> RoPEObj::getOpAttrVector() const { return {type.underlying()}; }
+
+}; // namespace infini
diff --git a/test/kernels/cuda/test_cuda_rope.cc b/test/kernels/cuda/test_cuda_rope.cc
new file mode 100644
index 00000000..8d88bf8e
--- /dev/null
+++ b/test/kernels/cuda/test_cuda_rope.cc
@@ -0,0 +1,36 @@
+#include "core/graph.h"
+#include "core/runtime.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/cuda_utility.h"
+#include "operators/rope.h"
+
+#include "test.h"
+
+namespace infini {
+TEST(RoPE, Cuda) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+
+    Graph gCpu = make_ref<GraphObj>(runtime);
+
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cudaRuntime);
+    auto input = gCuda->addTensor({1, 1, 32}, DataType::Float32);
+    auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32);
+    auto output = gCuda->addTensor({1, 1, 32}, DataType::Float32);
+
+    auto op = gCuda->addOpWithOutputs<RoPEObj>(position_id_d, input, output);
+    gCuda->dataMalloc();
+
+    input->setData(OneGenerator());
+    position_id_d->setData(OneGenerator());
+    cudaRuntime->run(gCuda);
+
+    auto oCpu = gCpu->cloneTensor(op->getOutputs()[0]);
+    EXPECT_TRUE(oCpu->equalData(vector<float>{
+        1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
+        1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
+        1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
+        1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
+        1.381773, 1.381773, 1.381773, 1.381773}));
+}
+} // namespace infini
diff --git a/test/kernels/cuda/test_cuda_unary.cc b/test/kernels/cuda/test_cuda_unary.cc
index fd407dfd..27ce90f1 100644
--- a/test/kernels/cuda/test_cuda_unary.cc
+++ b/test/kernels/cuda/test_cuda_unary.cc
@@ -70,6 +70,7 @@ void testCast(const std::function<void(void *, size_t, DataType)> &generator,
 
 TEST(cuDNN_Unary, run) {
     testUnary<ReluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
+    testUnary<SiluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
     testUnary<AbsObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
     testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
     testUnary<HardSigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
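Note for reviewers (my own addition, not part of the diff): the expected value `1.381773` in `test_cuda_rope.cc` follows directly from the kernel math. With `dim_model = 32`, `rope.cc` hard-codes `dim_head = dim_model / 32 = 1`, so `half_dim = 0`, every lane takes the `else` branch with `freq = target_pos`, and for an all-ones input at position 1 the result is `cos(1) + sin(1) ≈ 1.381773`. Below is a minimal standalone CPU sketch that mirrors `_rope_kernel` under exactly these test assumptions (single batch, single position, 32 hidden elements); it is a reference check, not code from the repository.

```cpp
// Mirrors _rope_kernel for the test_cuda_rope.cc configuration:
// dim_model = 32, dim_head = dim_model / 32 = 1, position id = 1, all-ones input.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int dim_model = 32;            // last dim of the {1, 1, 32} test tensor
    const int dim_head = dim_model / 32; // fixed divisor taken from rope.cc
    const int half_dim = dim_head / 2;   // 0 for this configuration
    const int pos = 1;                   // OneGenerator fills the position id with 1
    std::vector<float> in(dim_model, 1.0f), out(dim_model);

    for (int ith = 0; ith < dim_model; ++ith) {
        int col = ith % dim_head;
        if (col < half_dim) {
            float freq = pos * powf(10000, -float(col * 2) / dim_head);
            out[ith] = in[ith] * cosf(freq) - in[ith + half_dim] * sinf(freq);
        } else {
            float freq =
                pos * powf(10000, -float((col - half_dim) * 2) / dim_head);
            out[ith] = in[ith] * cosf(freq) + in[ith - half_dim] * sinf(freq);
        }
    }
    std::printf("%f\n", out[0]); // prints 1.381773, matching the expected test data
    return 0;
}
```

This also makes the design choice visible: `dim_head` is derived from a fixed divisor of 32 (i.e. 32 attention heads are assumed) rather than read from an operator attribute, so whenever `dim_model / 32` is 1 the rotation degenerates to scaling each element by `cos(freq) + sin(freq)` instead of rotating pairs of columns.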