Accelerate llama (#219)

* [feature] add cudagraph support

* modify code to pass the cuda_all_reduce test

* modify rope op

* support rmsnorm

* add fp16 support to silu cuda op

* fix bugs in rmsnorm op

* uncomment simplify in onnx.py

---------

Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
xiaonans 2024-04-01 08:46:05 +08:00 committed by GitHub
parent 54a35772fb
commit a98573990b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 254 additions and 10 deletions

View File

@ -37,6 +37,7 @@ class GraphHandlerObj {
float momentum, float eps, bool training); float momentum, float eps, bool training);
Tensor layerNormalization(Tensor input, Tensor scale, Tensor output, Tensor layerNormalization(Tensor input, Tensor scale, Tensor output,
Tensor bias, float eps, int axis, int stash_type); Tensor bias, float eps, int axis, int stash_type);
Tensor rmsNorm(Tensor input, Tensor weight, Tensor output);
Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw, Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
int ph, int pw, int sh, int sw, int ceilMode); int ph, int pw, int sh, int sw, int ceilMode);

View File

@ -156,8 +156,9 @@ struct OpType {
Resize, Resize,
ReverseSequence, ReverseSequence,
RoiAlign, RoiAlign,
RoPE, // Fusion RoPE, // Fusion
Round, // Unary Round, // Unary
RMSNorm, // Fusion
STFT, STFT,
Scan, Scan,
Scatter, Scatter,

View File

@ -0,0 +1,10 @@
#pragma once
#include "operators/rms_norm.h"
namespace infini {
void rmsnorm_kernel(int dType, void *input, void *weight, void *output,
int num_tokens, int hidden_size);
}; // namespace infini

View File

@ -0,0 +1,34 @@
#pragma once
#include "core/operator.h"
namespace infini {
/**
* @brief Fused RMSNorm Operator
*
*/
class RMSNormObj : public OperatorObj {
int dim;
public:
/**
* @brief Construct a new RMSNorm object.
*
* @param graph The computation graph that this operator belongs to.
* @param input The input tensor.
* @param output The output tensor.
*/
RMSNormObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output);
OP_CLONE(RMSNormObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override;
int numInputs() const override { return 2; }
int numOutputs() const override { return 1; }
int getDim() const { return dim; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
} // namespace infini

View File

@ -285,6 +285,12 @@ class OnnxStub:
axis, axis,
stash_type, stash_type,
) )
elif node.op_type == "RMSNorm":
tensors[node.output[0]] = self.handler.RMSNorm(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
)
elif node.op_type == "MaxPool": elif node.op_type == "MaxPool":
attributes = _parse_attribute( attributes = _parse_attribute(
node, node,

View File

@ -18,6 +18,7 @@
#include "operators/reduce.h" #include "operators/reduce.h"
#include "operators/reshape.h" #include "operators/reshape.h"
#include "operators/resize.h" #include "operators/resize.h"
#include "operators/rms_norm.h"
#include "operators/rope.h" #include "operators/rope.h"
#include "operators/send.h" #include "operators/send.h"
#include "operators/slice.h" #include "operators/slice.h"
@ -124,6 +125,17 @@ Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale,
} }
} }
Tensor GraphHandlerObj::rmsNorm(Tensor input, Tensor weight, Tensor output) {
if (output) {
g->addOpWithOutputs<RMSNormObj>(std::move(input), std::move(weight),
output);
return output;
} else {
return g->addOp<RMSNormObj>(std::move(input), std::move(weight), output)
->getOutput();
}
}
Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw, Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw,
int dh, int dw, int ph, int pw, int sh, int sw, int dh, int dw, int ph, int pw, int sh, int sw,
int ceilMode) { int ceilMode) {

View File

@ -506,6 +506,7 @@ void init_graph_builder(py::module &m) {
.def("matmul", &Handler::matmul, policy::move) .def("matmul", &Handler::matmul, policy::move)
.def("batchNormalization", &Handler::batchNormalization, policy::move) .def("batchNormalization", &Handler::batchNormalization, policy::move)
.def("layerNormalization", &Handler::layerNormalization, policy::move) .def("layerNormalization", &Handler::layerNormalization, policy::move)
.def("RMSNorm", &Handler::rmsNorm, policy::move)
.def("maxPool", &Handler::maxPool, policy::move) .def("maxPool", &Handler::maxPool, policy::move)
.def("avgPool", &Handler::avgPool, policy::move) .def("avgPool", &Handler::avgPool, policy::move)
.def("add", &Handler::add, policy::move) .def("add", &Handler::add, policy::move)

View File

@ -0,0 +1,34 @@
#include "operators/rms_norm.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_rmsnorm.h"
#include "cuda/cuda_runtime.h"
namespace infini {
class RMSNormCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<RMSNormObj>(_op);
auto input = op->getInputs(0);
auto weight = op->getInputs(1);
auto output = op->getOutput();
void *const inputData = input->getRawDataPtr<void *>();
void *const weightData = weight->getRawDataPtr<void *>();
void *const outputData = output->getRawDataPtr<void *>();
const auto &inputShape = input->getDims();
int nDims = input->getDims().size();
int hidden_size = inputShape[nDims - 1];
int num_tokens = input->size() / hidden_size;
IT_ASSERT(hidden_size == (int)weight->size());
const int dType = op->getDType().getIndex();
rmsnorm_kernel(dType, inputData, weightData, outputData, num_tokens,
hidden_size);
}
};
REGISTER_KERNEL(Device::CUDA, OpType::RMSNorm, RMSNormCuda, "RMSNorm_CUDA");
} // namespace infini

View File

@ -0,0 +1,112 @@
#include "core/common.h"
#include "cuda/cuda_common.h"
#include "cuda/cuda_utility.h"
#include "utils/small_array.h"
template<class T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(uint32_t(-1), val, mask);
return val;
}
/* Calculate the sum of all elements in a block */
template<class T>
__inline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if (lane == 0)
shared[wid] = val;
__syncthreads();
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
// blockDim.x is not divided by 32
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val);
return val;
}
template <class T>
__global__ void _rmsnorm_kernel(void *in, void *weight, void *out, int num_tokens, int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){
const float x = ((T*) in)[blockIdx.x * hidden_size + idx];
variance += x * x;
}
variance = blockReduceSum<float>(variance);
if(threadIdx.x == 0){
s_variance = rsqrtf(variance / hidden_size + 0.00001f);
}
__syncthreads();
for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){
float x = ((T*) in)[blockIdx.x * hidden_size + idx];
((T*)out)[blockIdx.x * hidden_size + idx] = ((T)(x * s_variance)) * ((T*)weight)[idx];
}
}
#define CASE(T) \
_rmsnorm_kernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
(input, weight, output, num_tokens, hidden_size);
#define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \
case 1: \
CASE(1) \
break; \
case 2: \
CASE(2) \
break; \
case 3: \
CASE(3) \
break; \
case 4: \
CASE(4) \
break; \
case 5: \
CASE(5) \
break; \
case 6: \
CASE(6) \
break; \
case 7: \
CASE(7) \
break; \
case 10: \
CASE(10) \
break; \
case 11: \
CASE(11) \
break; \
case 12: \
CASE(12) \
break; \
case 13: \
CASE(13) \
break; \
case 16: \
CASE(16) \
break; \
default: \
IT_TODO_HALT(); \
}
namespace infini {
void rmsnorm_kernel(int dType, void *input, void *weight, void *output,
int num_tokens, int hidden_size) {
dim3 blocksize = dim3(std::min(hidden_size, 1024));
dim3 gridsize = dim3(num_tokens);
SWITCH_DTYPE(dType)
}
} // namespace infini

View File

@ -22,7 +22,7 @@ class RoPECuda : public CudaKernelWithoutConfig {
IT_ASSERT(nDims == 3 && pos->getDims().size() == 2); IT_ASSERT(nDims == 3 && pos->getDims().size() == 2);
IT_ASSERT(inputShape[1] == pos->getDims()[1]); IT_ASSERT(inputShape[1] == pos->getDims()[1]);
int dim_model = inputShape[2]; int dim_model = inputShape[2];
int dim_head = dim_model / 32; int dim_head = 128;
int hidden_stride = dim_model * inputShape[1]; int hidden_stride = dim_model * inputShape[1];
int pos_stride = inputShape[1]; int pos_stride = inputShape[1];

View File

@ -3,11 +3,6 @@
#include "cuda/cuda_utility.h" #include "cuda/cuda_utility.h"
#include "utils/small_array.h" #include "utils/small_array.h"
constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
// gridDim (batch, seq_len, dim_model / 1024), blockDim (1024, 1, 1)
template <class T> template <class T>
__global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model, __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model,
int dim_head, int hidden_stride, int pos_stride) { int dim_head, int hidden_stride, int pos_stride) {
@ -86,8 +81,8 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
namespace infini { namespace infini {
void rope_kernel(int dType, int * pos, void *input, void *output, int size, void rope_kernel(int dType, int * pos, void *input, void *output, int size,
int dim_model, int dim_head, int hidden_stride, int pos_stride) { int dim_model, int dim_head, int hidden_stride, int pos_stride) {
dim3 blocksize = dim3(1024,1,1); dim3 blocksize = dim3(32,1,1);
dim3 gridsize = dim3(1, 1, 4); dim3 gridsize = dim3(1, 1, dim_model/32);
SWITCH_DTYPE(dType) SWITCH_DTYPE(dType)
} }

View File

@ -315,6 +315,8 @@ void unary_kernel(const Operator &_op) {
} else if (op->getOpType() == OpType::Silu) { } else if (op->getOpType() == OpType::Silu) {
if (_op->getDType() == DataType::Float32) { if (_op->getDType() == DataType::Float32) {
silu_kernel<float>((float *)inputData, (float *)outputData, num); silu_kernel<float>((float *)inputData, (float *)outputData, num);
} else if (_op->getDType() == DataType::Float16){
silu_kernel<half>((half *)inputData, (half *)outputData, num);
} else { } else {
IT_TODO_HALT(); IT_TODO_HALT();
} }

36
src/operators/rms_norm.cc Normal file
View File

@ -0,0 +1,36 @@
#include "operators/rms_norm.h"
namespace infini {
RMSNormObj::RMSNormObj(GraphObj *graph, Tensor input, Tensor weight,
Tensor output)
: OperatorObj(OpType::RMSNorm, {input, weight}, {output}) {
IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>> RMSNormObj::inferShape(const TensorVec &inputs) {
const auto A = inputs[0];
auto input_dim = A->getDims();
auto output_dim = input_dim;
return {{output_dim}};
}
std::string RMSNormObj::toString() const {
std::ostringstream os;
os << type.toString() << "[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "input=" << inputs[0]->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
vector<int> RMSNormObj::getWorkloadVector() const {
vector<int> ret{type.underlying()};
const Shape shape = outputs[0]->getDims();
ret.insert(ret.end(), shape.begin(), shape.end());
return ret;
}
vector<int> RMSNormObj::getOpAttrVector() const { return {type.underlying()}; }
}; // namespace infini