diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index ce455d62..30095891 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -36,6 +36,7 @@ class GraphHandlerObj {
                               float momentum, float eps, bool training);
     Tensor layerNormalization(Tensor input, Tensor scale, Tensor output,
                               Tensor bias, float eps, int axis,
                               int stash_type);
+    Tensor rmsNorm(Tensor input, Tensor weight, Tensor output);
     Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
                    int ph, int pw, int sh, int sw, int ceilMode);
diff --git a/include/core/op_type.h b/include/core/op_type.h
index dbcfbdb9..3b4045c4 100644
--- a/include/core/op_type.h
+++ b/include/core/op_type.h
@@ -158,6 +158,7 @@ struct OpType {
         RoiAlign,
         RoPE,  // Fusion
         Round, // Unary
+        RMSNorm, // Fusion
         STFT,
         Scan,
         Scatter,
diff --git a/include/cuda/cuda_rmsnorm.h b/include/cuda/cuda_rmsnorm.h
new file mode 100644
index 00000000..024cb75a
--- /dev/null
+++ b/include/cuda/cuda_rmsnorm.h
@@ -0,0 +1,10 @@
+#pragma once
+
+#include "operators/rms_norm.h"
+
+namespace infini {
+
+void rmsnorm_kernel(int dType, void *input, void *weight, void *output,
+                    int num_tokens, int hidden_size);
+
+}; // namespace infini
diff --git a/include/operators/rms_norm.h b/include/operators/rms_norm.h
new file mode 100644
index 00000000..10c385d2
--- /dev/null
+++ b/include/operators/rms_norm.h
@@ -0,0 +1,34 @@
+#pragma once
+#include "core/operator.h"
+
+namespace infini {
+/**
+ * @brief Fused RMSNorm operator.
+ *
+ */
+class RMSNormObj : public OperatorObj {
+    int dim;
+
+  public:
+    /**
+     * @brief Construct a new RMSNorm object.
+     *
+     * @param graph The computation graph that this operator belongs to.
+     * @param input The input tensor.
+     * @param output The output tensor.
+     */
+    RMSNormObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output);
+    OP_CLONE(RMSNormObj);
+
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
+
+    std::string toString() const override;
+    int numInputs() const override { return 2; }
+    int numOutputs() const override { return 1; }
+    int getDim() const { return dim; }
+
+  private:
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+};
+} // namespace infini
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index 085ff0c3..768bb5a3 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -277,6 +277,12 @@ class OnnxStub:
                     axis,
                     stash_type,
                 )
+            elif node.op_type == "RMSNorm":
+                tensors[node.output[0]] = self.handler.RMSNorm(
+                    tensors[node.input[0]],
+                    tensors[node.input[1]],
+                    tensors.get(node.output[0]),
+                )
             elif node.op_type == "MaxPool":
                 attributes = _parse_attribute(
                     node,
diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc
index 0821121d..596910f1 100644
--- a/src/core/graph_handler.cc
+++ b/src/core/graph_handler.cc
@@ -18,6 +18,7 @@
 #include "operators/reduce.h"
 #include "operators/reshape.h"
 #include "operators/resize.h"
+#include "operators/rms_norm.h"
 #include "operators/rope.h"
 #include "operators/send.h"
 #include "operators/slice.h"
@@ -122,6 +123,16 @@
     }
 }
 
+Tensor GraphHandlerObj::rmsNorm(Tensor input, Tensor weight, Tensor output) {
+    if (output) {
+        g->addOpWithOutputs<RMSNormObj>(std::move(input), std::move(weight),
+                                        output);
+        return output;
+    } else {
+        return g->addOp<RMSNormObj>(std::move(input), std::move(weight), output)
+            ->getOutput();
+    }
+}
+
 Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw,
                                 int dh, int dw, int ph, int pw, int sh, int sw,
                                 int ceilMode) {
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index 9dc43510..a9cee80a 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -504,6 +504,7 @@ void init_graph_builder(py::module &m) {
         .def("matmul", &Handler::matmul, policy::move)
         .def("batchNormalization", &Handler::batchNormalization, policy::move)
         .def("layerNormalization", &Handler::layerNormalization, policy::move)
+        .def("RMSNorm", &Handler::rmsNorm, policy::move)
         .def("maxPool", &Handler::maxPool, policy::move)
         .def("avgPool", &Handler::avgPool, policy::move)
         .def("add", &Handler::add, policy::move)
diff --git a/src/kernels/cuda/rms_norm.cc b/src/kernels/cuda/rms_norm.cc
new file mode 100644
index 00000000..1bd6dfda
--- /dev/null
+++ b/src/kernels/cuda/rms_norm.cc
@@ -0,0 +1,34 @@
+#include "operators/rms_norm.h"
+#include "cuda/cuda_kernel_wihtout_config.h"
+#include "cuda/cuda_rmsnorm.h"
+#include "cuda/cuda_runtime.h"
+
+namespace infini {
+
+class RMSNormCuda : public CudaKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<RMSNormObj>(_op);
+
+        auto input = op->getInputs(0);
+        auto weight = op->getInputs(1);
+        auto output = op->getOutput();
+        void *const inputData = input->getRawDataPtr<void *>();
+        void *const weightData = weight->getRawDataPtr<void *>();
+        void *const outputData = output->getRawDataPtr<void *>();
+        const auto &inputShape = input->getDims();
+        int nDims = input->getDims().size();
+
+        int hidden_size = inputShape[nDims - 1];
+        int num_tokens = input->size() / hidden_size;
+        IT_ASSERT(hidden_size == (int)weight->size());
+
+        const int dType = op->getDType().getIndex();
+
+        rmsnorm_kernel(dType, inputData, weightData, outputData, num_tokens,
+                       hidden_size);
+    }
+};
+
+REGISTER_KERNEL(Device::CUDA, OpType::RMSNorm, RMSNormCuda, "RMSNorm_CUDA");
+
+} // namespace infini
diff --git a/src/kernels/cuda/rms_norm.cu b/src/kernels/cuda/rms_norm.cu
new file mode 100644
index 00000000..9eca738f
--- /dev/null
+++ b/src/kernels/cuda/rms_norm.cu
@@ -0,0 +1,112 @@
+#include "core/common.h"
+#include "cuda/cuda_common.h"
+#include "cuda/cuda_utility.h"
+#include "utils/small_array.h"
+
+template <class T>
+__inline__ __device__ T warpReduceSum(T val) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1)
+        val += __shfl_xor_sync(uint32_t(-1), val, mask);
+    return val;
+}
+
+/* Calculate the sum of all elements in a block */
+template <class T>
+__inline__ __device__ T blockReduceSum(T val) {
+    static __shared__ T shared[32];
+    int lane = threadIdx.x & 0x1f;
+    int wid = threadIdx.x >> 5;
+
+    val = warpReduceSum<T>(val);
+
+    if (lane == 0)
+        shared[wid] = val;
+
+    __syncthreads();
+
+    // Use blockDim.x / 32. instead of blockDim.x >> 5 so the reduction stays
+    // correct when blockDim.x is not a multiple of 32.
+    val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
+    val = warpReduceSum<T>(val);
+    return val;
+}
+
+template <class T>
+__global__ void _rmsnorm_kernel(void *in, void *weight, void *out,
+                                int num_tokens, int hidden_size) {
+    __shared__ float s_variance;
+    float variance = 0.0f;
+
+    // One block per token: accumulate the sum of squares of this row.
+    for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+        const float x = (float)((T *)in)[blockIdx.x * hidden_size + idx];
+        variance += x * x;
+    }
+    variance = blockReduceSum<float>(variance);
+    if (threadIdx.x == 0) {
+        s_variance = rsqrtf(variance / hidden_size + 0.00001f);
+    }
+    __syncthreads();
+
+    for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+        float x = (float)((T *)in)[blockIdx.x * hidden_size + idx];
+        ((T *)out)[blockIdx.x * hidden_size + idx] =
+            ((T)(x * s_variance)) * ((T *)weight)[idx];
+    }
+}
+
+#define CASE(T)                                                                \
+    _rmsnorm_kernel<DT_CUDA<T>::t>                                             \
+        <<<gridsize, blocksize>>>                                              \
+        (input, weight, output, num_tokens, hidden_size);
+
+#define SWITCH_DTYPE(DTYPE)                                                    \
+    switch (DTYPE) {                                                           \
+    case 1:                                                                    \
+        CASE(1)                                                                \
+        break;                                                                 \
+    case 2:                                                                    \
+        CASE(2)                                                                \
+        break;                                                                 \
+    case 3:                                                                    \
+        CASE(3)                                                                \
+        break;                                                                 \
+    case 4:                                                                    \
+        CASE(4)                                                                \
+        break;                                                                 \
+    case 5:                                                                    \
+        CASE(5)                                                                \
+        break;                                                                 \
+    case 6:                                                                    \
+        CASE(6)                                                                \
+        break;                                                                 \
+    case 7:                                                                    \
+        CASE(7)                                                                \
+        break;                                                                 \
+    case 10:                                                                   \
+        CASE(10)                                                               \
+        break;                                                                 \
+    case 11:                                                                   \
+        CASE(11)                                                               \
+        break;                                                                 \
+    case 12:                                                                   \
+        CASE(12)                                                               \
+        break;                                                                 \
+    case 13:                                                                   \
+        CASE(13)                                                               \
+        break;                                                                 \
+    case 16:                                                                   \
+        CASE(16)                                                               \
+        break;                                                                 \
+    default:                                                                   \
+        IT_TODO_HALT();                                                        \
+    }
+
+namespace infini {
+void rmsnorm_kernel(int dType, void *input, void *weight, void *output,
+                    int num_tokens, int hidden_size) {
+    dim3 blocksize = dim3(std::min(hidden_size, 1024));
+    dim3 gridsize = dim3(num_tokens);
+    SWITCH_DTYPE(dType)
+}
+
+} // namespace infini
diff --git a/src/operators/rms_norm.cc b/src/operators/rms_norm.cc
new file mode 100644
index 00000000..f71612a6
--- /dev/null
+++ b/src/operators/rms_norm.cc
@@ -0,0 +1,36 @@
+#include "operators/rms_norm.h"
+
+namespace infini {
+RMSNormObj::RMSNormObj(GraphObj *graph, Tensor input, Tensor weight,
+                       Tensor output)
+    : OperatorObj(OpType::RMSNorm, {input, weight}, {output}) {
+    IT_ASSERT(checkValid(graph));
+}
+
+optional<vector<Shape>> RMSNormObj::inferShape(const TensorVec &inputs) {
+    const auto A = inputs[0];
+    auto input_dim = A->getDims();
+    auto output_dim = input_dim;
+    return {{output_dim}};
+}
+
+std::string RMSNormObj::toString() const {
+    std::ostringstream os;
+    os << type.toString() << "[" << getGuid() << "]";
+    os << "(";
+    os << vecToString(inputs[0]->getDims()) << ",";
+    os << "input=" << inputs[0]->getGuid() << ",";
+    os << "output=" << outputs[0]->getGuid() << ")";
+    return os.str();
+}
+
+vector<int> RMSNormObj::getWorkloadVector() const {
+    vector<int> ret{type.underlying()};
+    const Shape shape = outputs[0]->getDims();
+    ret.insert(ret.end(), shape.begin(), shape.end());
+    return ret;
+}
+
+vector<int> RMSNormObj::getOpAttrVector() const { return {type.underlying()}; }
+
+}; // namespace infini
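
For reference, the CUDA kernel above normalizes each row of hidden_size elements by the root mean square of that row (with eps = 1e-5) and then scales it elementwise by weight, i.e. y = x / sqrt(mean(x^2) + eps) * weight. Below is a minimal NumPy sketch of the same computation that can serve as a host-side correctness check against the kernel; the function name and the test shapes are illustrative and not part of this patch.

import numpy as np

def rmsnorm_reference(x: np.ndarray, weight: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # Mean of squares over the last (hidden) dimension, one value per token.
    ms = np.mean(x.astype(np.float32) ** 2, axis=-1, keepdims=True)
    # Normalize by the root mean square, then apply the learned scale.
    return (x.astype(np.float32) / np.sqrt(ms + eps) * weight).astype(x.dtype)

# Example: 4 tokens with a hidden size of 8.
x = np.random.randn(4, 8).astype(np.float32)
w = np.random.randn(8).astype(np.float32)
y = rmsnorm_reference(x, w)
assert y.shape == x.shape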