support rmsnorm

This commit is contained in:
xiaonans 2024-02-08 10:40:03 +08:00
parent 17bd98d453
commit 936797b960
10 changed files with 246 additions and 0 deletions

View File

@ -36,6 +36,7 @@ class GraphHandlerObj {
float momentum, float eps, bool training);
Tensor layerNormalization(Tensor input, Tensor scale, Tensor output,
Tensor bias, float eps, int axis, int stash_type);
Tensor rmsNorm(Tensor input, Tensor weight, Tensor output);
Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
int ph, int pw, int sh, int sw, int ceilMode);

View File

@ -158,6 +158,7 @@ struct OpType {
RoiAlign,
RoPE, // Fusion
Round, // Unary
RMSNorm, // Fusion
STFT,
Scan,
Scatter,

View File

@ -0,0 +1,10 @@
#pragma once
#include "operators/rms_norm.h"
namespace infini {
void rmsnorm_kernel(int dType, void *input, void *weight, void *output,
int num_tokens, int hidden_size);
}; // namespace infini

View File

@ -0,0 +1,34 @@
#pragma once
#include "core/operator.h"
namespace infini {
/**
* @brief Fused RMSNorm Operator
*
*/
class RMSNormObj : public OperatorObj {
int dim;
public:
/**
* @brief Construct a new RMSNorm object.
*
* @param graph The computation graph that this operator belongs to.
* @param input The input tensor.
* @param output The output tensor.
*/
RMSNormObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output);
OP_CLONE(RMSNormObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override;
int numInputs() const override { return 2; }
int numOutputs() const override { return 1; }
int getDim() const { return dim; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
} // namespace infini

View File

@ -277,6 +277,12 @@ class OnnxStub:
axis,
stash_type,
)
elif node.op_type == "RMSNorm":
tensors[node.output[0]] = self.handler.RMSNorm(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
)
elif node.op_type == "MaxPool":
attributes = _parse_attribute(
node,

View File

@ -18,6 +18,7 @@
#include "operators/reduce.h"
#include "operators/reshape.h"
#include "operators/resize.h"
#include "operators/rms_norm.h"
#include "operators/rope.h"
#include "operators/send.h"
#include "operators/slice.h"
@ -122,6 +123,16 @@ Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale,
}
}
Tensor GraphHandlerObj::rmsNorm(Tensor input, Tensor weight, Tensor output) {
if (output) {
g->addOpWithOutputs<RMSNormObj>(std::move(input), std::move(weight), output);
return output;
} else {
return g->addOp<RMSNormObj>(std::move(input), std::move(weight), output)
->getOutput();
}
}
Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw,
int dh, int dw, int ph, int pw, int sh, int sw,
int ceilMode) {

View File

@ -504,6 +504,7 @@ void init_graph_builder(py::module &m) {
.def("matmul", &Handler::matmul, policy::move)
.def("batchNormalization", &Handler::batchNormalization, policy::move)
.def("layerNormalization", &Handler::layerNormalization, policy::move)
.def("RMSNorm", &Handler::rmsNorm, policy::move)
.def("maxPool", &Handler::maxPool, policy::move)
.def("avgPool", &Handler::avgPool, policy::move)
.def("add", &Handler::add, policy::move)

View File

@ -0,0 +1,34 @@
#include "operators/rms_norm.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_rmsnorm.h"
#include "cuda/cuda_runtime.h"
namespace infini {
class RMSNormCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<RMSNormObj>(_op);
auto input = op->getInputs(0);
auto weight = op->getInputs(1);
auto output = op->getOutput();
void *const inputData = input->getRawDataPtr<void *>();
void *const weightData = weight->getRawDataPtr<void *>();
void *const outputData = output->getRawDataPtr<void *>();
const auto &inputShape = input->getDims();
int nDims = input->getDims().size();
int hidden_size = inputShape[nDims - 1];
int num_tokens = input->size() / hidden_size;
IT_ASSERT(hidden_size == (int)weight->size());
const int dType = op->getDType().getIndex();
rmsnorm_kernel(dType, inputData, weightData, outputData, num_tokens,
hidden_size);
}
};
REGISTER_KERNEL(Device::CUDA, OpType::RMSNorm, RMSNormCuda, "RMSNorm_CUDA");
} // namespace infini

View File

@ -0,0 +1,112 @@
#include "core/common.h"
#include "cuda/cuda_common.h"
#include "cuda/cuda_utility.h"
#include "utils/small_array.h"
template<class T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(uint32_t(-1), val, mask);
return val;
}
/* Calculate the sum of all elements in a block */
template<class T>
__inline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if (lane == 0)
shared[wid] = val;
__syncthreads();
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
// blockDim.x is not divided by 32
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val);
return val;
}
template <class T>
__global__ void _rmsnorm_kernel(void *in, void *weight, void *out, int num_tokens, int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){
const float x = ((float*) in)[blockIdx.x * hidden_size + idx];
variance += x * x;
}
variance = blockReduceSum<float>(variance);
if(threadIdx.x == 0){
s_variance = rsqrtf(variance / hidden_size + 0.00001f);
}
__syncthreads();
for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){
float x = ((float*) in)[blockIdx.x * hidden_size + idx];
((T*)out)[blockIdx.x * hidden_size + idx] = ((T)(x * s_variance)) * ((T*)weight)[idx];
}
}
#define CASE(T) \
_rmsnorm_kernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
(input, weight, output, num_tokens, hidden_size);
#define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \
case 1: \
CASE(1) \
break; \
case 2: \
CASE(2) \
break; \
case 3: \
CASE(3) \
break; \
case 4: \
CASE(4) \
break; \
case 5: \
CASE(5) \
break; \
case 6: \
CASE(6) \
break; \
case 7: \
CASE(7) \
break; \
case 10: \
CASE(10) \
break; \
case 11: \
CASE(11) \
break; \
case 12: \
CASE(12) \
break; \
case 13: \
CASE(13) \
break; \
case 16: \
CASE(16) \
break; \
default: \
IT_TODO_HALT(); \
}
namespace infini {
void rmsnorm_kernel(int dType, void *input, void *weight, void *output,
int num_tokens, int hidden_size) {
dim3 blocksize = dim3(std::min(hidden_size, 1024));
dim3 gridsize = dim3(num_tokens);
SWITCH_DTYPE(dType)
}
} // namespace infini

36
src/operators/rms_norm.cc Normal file
View File

@ -0,0 +1,36 @@
#include "operators/rms_norm.h"
namespace infini {
RMSNormObj::RMSNormObj(GraphObj *graph, Tensor input, Tensor weight,
Tensor output)
: OperatorObj(OpType::RMSNorm, {input, weight}, {output}) {
IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>> RMSNormObj::inferShape(const TensorVec &inputs) {
const auto A = inputs[0];
auto input_dim = A->getDims();
auto output_dim = input_dim;
return {{output_dim}};
}
std::string RMSNormObj::toString() const {
std::ostringstream os;
os << type.toString() << "[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "input=" << inputs[0]->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
vector<int> RMSNormObj::getWorkloadVector() const {
vector<int> ret{type.underlying()};
const Shape shape = outputs[0]->getDims();
ret.insert(ret.end(), shape.begin(), shape.end());
return ret;
}
vector<int> RMSNormObj::getOpAttrVector() const { return {type.underlying()}; }
}; // namespace infini