From a98573990b032012aa2c13efcd2df3becd72fb91 Mon Sep 17 00:00:00 2001 From: xiaonans <51065160+xiaonans@users.noreply.github.com> Date: Mon, 1 Apr 2024 08:46:05 +0800 Subject: [PATCH] Accelerate llama (#219) * [feature] add cudagraph support * modify code to pass the cuda_all_reduce test * modify rope op * support rmsnorm * add fp16 support to silu cuda op * fix bugs in rmsnorm op * uncomment simplify in onnx.py --------- Co-authored-by: Haojie Wang --- include/core/graph_handler.h | 1 + include/core/op_type.h | 5 +- include/cuda/cuda_rmsnorm.h | 10 ++ include/operators/rms_norm.h | 34 +++++++ pyinfinitensor/src/pyinfinitensor/onnx.py | 6 ++ src/core/graph_handler.cc | 12 +++ src/ffi/ffi_infinitensor.cc | 1 + src/kernels/cuda/rms_norm.cc | 34 +++++++ src/kernels/cuda/rms_norm.cu | 112 ++++++++++++++++++++++ src/kernels/cuda/rope.cc | 2 +- src/kernels/cuda/rope.cu | 9 +- src/kernels/cuda/unary.cu | 2 + src/operators/rms_norm.cc | 36 +++++++ 13 files changed, 254 insertions(+), 10 deletions(-) create mode 100644 include/cuda/cuda_rmsnorm.h create mode 100644 include/operators/rms_norm.h create mode 100644 src/kernels/cuda/rms_norm.cc create mode 100644 src/kernels/cuda/rms_norm.cu create mode 100644 src/operators/rms_norm.cc diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index 92a17048..fe3e5759 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -37,6 +37,7 @@ class GraphHandlerObj { float momentum, float eps, bool training); Tensor layerNormalization(Tensor input, Tensor scale, Tensor output, Tensor bias, float eps, int axis, int stash_type); + Tensor rmsNorm(Tensor input, Tensor weight, Tensor output); Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw, int ph, int pw, int sh, int sw, int ceilMode); diff --git a/include/core/op_type.h b/include/core/op_type.h index dbcfbdb9..e624877b 100644 --- a/include/core/op_type.h +++ b/include/core/op_type.h @@ -156,8 +156,9 @@ struct OpType { Resize, ReverseSequence, RoiAlign, - RoPE, // Fusion - Round, // Unary + RoPE, // Fusion + Round, // Unary + RMSNorm, // Fusion STFT, Scan, Scatter, diff --git a/include/cuda/cuda_rmsnorm.h b/include/cuda/cuda_rmsnorm.h new file mode 100644 index 00000000..024cb75a --- /dev/null +++ b/include/cuda/cuda_rmsnorm.h @@ -0,0 +1,10 @@ +#pragma once + +#include "operators/rms_norm.h" + +namespace infini { + +void rmsnorm_kernel(int dType, void *input, void *weight, void *output, + int num_tokens, int hidden_size); + +}; // namespace infini diff --git a/include/operators/rms_norm.h b/include/operators/rms_norm.h new file mode 100644 index 00000000..10c385d2 --- /dev/null +++ b/include/operators/rms_norm.h @@ -0,0 +1,34 @@ +#pragma once +#include "core/operator.h" + +namespace infini { +/** + * @brief Fused RMSNorm Operator + * + */ +class RMSNormObj : public OperatorObj { + int dim; + + public: + /** + * @brief Construct a new RMSNorm object. + * + * @param graph The computation graph that this operator belongs to. + * @param input The input tensor. + * @param output The output tensor. + */ + RMSNormObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output); + OP_CLONE(RMSNormObj); + + optional> inferShape(const TensorVec &inputs) override; + + std::string toString() const override; + int numInputs() const override { return 2; } + int numOutputs() const override { return 1; } + int getDim() const { return dim; } + + private: + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; +}; +} // namespace infini diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 522a4813..f47dcd0a 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -285,6 +285,12 @@ class OnnxStub: axis, stash_type, ) + elif node.op_type == "RMSNorm": + tensors[node.output[0]] = self.handler.RMSNorm( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + ) elif node.op_type == "MaxPool": attributes = _parse_attribute( node, diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index 4e9fa0d3..e6bfffdd 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -18,6 +18,7 @@ #include "operators/reduce.h" #include "operators/reshape.h" #include "operators/resize.h" +#include "operators/rms_norm.h" #include "operators/rope.h" #include "operators/send.h" #include "operators/slice.h" @@ -124,6 +125,17 @@ Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale, } } +Tensor GraphHandlerObj::rmsNorm(Tensor input, Tensor weight, Tensor output) { + if (output) { + g->addOpWithOutputs(std::move(input), std::move(weight), + output); + return output; + } else { + return g->addOp(std::move(input), std::move(weight), output) + ->getOutput(); + } +} + Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw, int ph, int pw, int sh, int sw, int ceilMode) { diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 361a3d50..832d7cc6 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -506,6 +506,7 @@ void init_graph_builder(py::module &m) { .def("matmul", &Handler::matmul, policy::move) .def("batchNormalization", &Handler::batchNormalization, policy::move) .def("layerNormalization", &Handler::layerNormalization, policy::move) + .def("RMSNorm", &Handler::rmsNorm, policy::move) .def("maxPool", &Handler::maxPool, policy::move) .def("avgPool", &Handler::avgPool, policy::move) .def("add", &Handler::add, policy::move) diff --git a/src/kernels/cuda/rms_norm.cc b/src/kernels/cuda/rms_norm.cc new file mode 100644 index 00000000..1bd6dfda --- /dev/null +++ b/src/kernels/cuda/rms_norm.cc @@ -0,0 +1,34 @@ +#include "operators/rms_norm.h" +#include "cuda/cuda_kernel_wihtout_config.h" +#include "cuda/cuda_rmsnorm.h" +#include "cuda/cuda_runtime.h" + +namespace infini { + +class RMSNormCuda : public CudaKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + + auto input = op->getInputs(0); + auto weight = op->getInputs(1); + auto output = op->getOutput(); + void *const inputData = input->getRawDataPtr(); + void *const weightData = weight->getRawDataPtr(); + void *const outputData = output->getRawDataPtr(); + const auto &inputShape = input->getDims(); + int nDims = input->getDims().size(); + + int hidden_size = inputShape[nDims - 1]; + int num_tokens = input->size() / hidden_size; + IT_ASSERT(hidden_size == (int)weight->size()); + + const int dType = op->getDType().getIndex(); + rmsnorm_kernel(dType, inputData, weightData, outputData, num_tokens, + hidden_size); + } +}; + +REGISTER_KERNEL(Device::CUDA, OpType::RMSNorm, RMSNormCuda, "RMSNorm_CUDA"); + +} // namespace infini diff --git a/src/kernels/cuda/rms_norm.cu b/src/kernels/cuda/rms_norm.cu new file mode 100644 index 00000000..530a42ce --- /dev/null +++ b/src/kernels/cuda/rms_norm.cu @@ -0,0 +1,112 @@ +#include "core/common.h" +#include "cuda/cuda_common.h" +#include "cuda/cuda_utility.h" +#include "utils/small_array.h" + +template +__inline__ __device__ T warpReduceSum(T val) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val += __shfl_xor_sync(uint32_t(-1), val, mask); + return val; +} + +/* Calculate the sum of all elements in a block */ +template +__inline__ __device__ T blockReduceSum(T val) { + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + val = warpReduceSum(val); + + if (lane == 0) + shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f); + val = warpReduceSum(val); + return val; +} + +template +__global__ void _rmsnorm_kernel(void *in, void *weight, void *out, int num_tokens, int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + + for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){ + const float x = ((T*) in)[blockIdx.x * hidden_size + idx]; + variance += x * x; + } + variance = blockReduceSum(variance); + if(threadIdx.x == 0){ + s_variance = rsqrtf(variance / hidden_size + 0.00001f); + } + __syncthreads(); + + for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){ + float x = ((T*) in)[blockIdx.x * hidden_size + idx]; + ((T*)out)[blockIdx.x * hidden_size + idx] = ((T)(x * s_variance)) * ((T*)weight)[idx]; + } +} + + +#define CASE(T) \ + _rmsnorm_kernel::t> \ + <<>> \ + (input, weight, output, num_tokens, hidden_size); + +#define SWITCH_DTYPE(DTYPE) \ + switch (DTYPE) { \ + case 1: \ + CASE(1) \ + break; \ + case 2: \ + CASE(2) \ + break; \ + case 3: \ + CASE(3) \ + break; \ + case 4: \ + CASE(4) \ + break; \ + case 5: \ + CASE(5) \ + break; \ + case 6: \ + CASE(6) \ + break; \ + case 7: \ + CASE(7) \ + break; \ + case 10: \ + CASE(10) \ + break; \ + case 11: \ + CASE(11) \ + break; \ + case 12: \ + CASE(12) \ + break; \ + case 13: \ + CASE(13) \ + break; \ + case 16: \ + CASE(16) \ + break; \ + default: \ + IT_TODO_HALT(); \ + } + +namespace infini { +void rmsnorm_kernel(int dType, void *input, void *weight, void *output, + int num_tokens, int hidden_size) { + dim3 blocksize = dim3(std::min(hidden_size, 1024)); + dim3 gridsize = dim3(num_tokens); + SWITCH_DTYPE(dType) +} + +} // namespace infini diff --git a/src/kernels/cuda/rope.cc b/src/kernels/cuda/rope.cc index 1ec5cca2..27fc83f4 100644 --- a/src/kernels/cuda/rope.cc +++ b/src/kernels/cuda/rope.cc @@ -22,7 +22,7 @@ class RoPECuda : public CudaKernelWithoutConfig { IT_ASSERT(nDims == 3 && pos->getDims().size() == 2); IT_ASSERT(inputShape[1] == pos->getDims()[1]); int dim_model = inputShape[2]; - int dim_head = dim_model / 32; + int dim_head = 128; int hidden_stride = dim_model * inputShape[1]; int pos_stride = inputShape[1]; diff --git a/src/kernels/cuda/rope.cu b/src/kernels/cuda/rope.cu index 8d35026f..6e947f5c 100644 --- a/src/kernels/cuda/rope.cu +++ b/src/kernels/cuda/rope.cu @@ -3,11 +3,6 @@ #include "cuda/cuda_utility.h" #include "utils/small_array.h" -constexpr unsigned int num_threads() { return 32 * 4; } -constexpr int thread_work_size() { return 4; } -constexpr int block_work_size() { return thread_work_size() * num_threads(); } - -// gridDim (batch, seq_len, dim_model / 1024), blockDim (1024, 1, 1) template __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model, int dim_head, int hidden_stride, int pos_stride) { @@ -86,8 +81,8 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo namespace infini { void rope_kernel(int dType, int * pos, void *input, void *output, int size, int dim_model, int dim_head, int hidden_stride, int pos_stride) { - dim3 blocksize = dim3(1024,1,1); - dim3 gridsize = dim3(1, 1, 4); + dim3 blocksize = dim3(32,1,1); + dim3 gridsize = dim3(1, 1, dim_model/32); SWITCH_DTYPE(dType) } diff --git a/src/kernels/cuda/unary.cu b/src/kernels/cuda/unary.cu index 93a3cf6c..f4443564 100644 --- a/src/kernels/cuda/unary.cu +++ b/src/kernels/cuda/unary.cu @@ -315,6 +315,8 @@ void unary_kernel(const Operator &_op) { } else if (op->getOpType() == OpType::Silu) { if (_op->getDType() == DataType::Float32) { silu_kernel((float *)inputData, (float *)outputData, num); + } else if (_op->getDType() == DataType::Float16){ + silu_kernel((half *)inputData, (half *)outputData, num); } else { IT_TODO_HALT(); } diff --git a/src/operators/rms_norm.cc b/src/operators/rms_norm.cc new file mode 100644 index 00000000..f71612a6 --- /dev/null +++ b/src/operators/rms_norm.cc @@ -0,0 +1,36 @@ +#include "operators/rms_norm.h" + +namespace infini { +RMSNormObj::RMSNormObj(GraphObj *graph, Tensor input, Tensor weight, + Tensor output) + : OperatorObj(OpType::RMSNorm, {input, weight}, {output}) { + IT_ASSERT(checkValid(graph)); +} + +optional> RMSNormObj::inferShape(const TensorVec &inputs) { + const auto A = inputs[0]; + auto input_dim = A->getDims(); + auto output_dim = input_dim; + return {{output_dim}}; +} + +std::string RMSNormObj::toString() const { + std::ostringstream os; + os << type.toString() << "[" << getGuid() << "]"; + os << "("; + os << vecToString(inputs[0]->getDims()) << ","; + os << "input=" << inputs[0]->getGuid() << ","; + os << "output=" << outputs[0]->getGuid() << ")"; + return os.str(); +} + +vector RMSNormObj::getWorkloadVector() const { + vector ret{type.underlying()}; + const Shape shape = outputs[0]->getDims(); + ret.insert(ret.end(), shape.begin(), shape.end()); + return ret; +} + +vector RMSNormObj::getOpAttrVector() const { return {type.underlying()}; } + +}; // namespace infini