forked from jiuyuan/InfiniTensor
support rmsnorm
This commit is contained in:
parent
17bd98d453
commit
936797b960
|
@ -36,6 +36,7 @@ class GraphHandlerObj {
|
|||
float momentum, float eps, bool training);
|
||||
Tensor layerNormalization(Tensor input, Tensor scale, Tensor output,
|
||||
Tensor bias, float eps, int axis, int stash_type);
|
||||
Tensor rmsNorm(Tensor input, Tensor weight, Tensor output);
|
||||
|
||||
Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
|
||||
int ph, int pw, int sh, int sw, int ceilMode);
|
||||
|
|
|
@ -158,6 +158,7 @@ struct OpType {
|
|||
RoiAlign,
|
||||
RoPE, // Fusion
|
||||
Round, // Unary
|
||||
RMSNorm, // Fusion
|
||||
STFT,
|
||||
Scan,
|
||||
Scatter,
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
#pragma once
|
||||
|
||||
#include "operators/rms_norm.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
void rmsnorm_kernel(int dType, void *input, void *weight, void *output,
|
||||
int num_tokens, int hidden_size);
|
||||
|
||||
}; // namespace infini
|
|
@ -0,0 +1,34 @@
|
|||
#pragma once
|
||||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
/**
|
||||
* @brief Fused RMSNorm Operator
|
||||
*
|
||||
*/
|
||||
class RMSNormObj : public OperatorObj {
|
||||
int dim;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new RMSNorm object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input The input tensor.
|
||||
* @param output The output tensor.
|
||||
*/
|
||||
RMSNormObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output);
|
||||
OP_CLONE(RMSNormObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 2; }
|
||||
int numOutputs() const override { return 1; }
|
||||
int getDim() const { return dim; }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
} // namespace infini
|
|
@ -277,6 +277,12 @@ class OnnxStub:
|
|||
axis,
|
||||
stash_type,
|
||||
)
|
||||
elif node.op_type == "RMSNorm":
|
||||
tensors[node.output[0]] = self.handler.RMSNorm(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "MaxPool":
|
||||
attributes = _parse_attribute(
|
||||
node,
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "operators/reduce.h"
|
||||
#include "operators/reshape.h"
|
||||
#include "operators/resize.h"
|
||||
#include "operators/rms_norm.h"
|
||||
#include "operators/rope.h"
|
||||
#include "operators/send.h"
|
||||
#include "operators/slice.h"
|
||||
|
@ -122,6 +123,16 @@ Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale,
|
|||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::rmsNorm(Tensor input, Tensor weight, Tensor output) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<RMSNormObj>(std::move(input), std::move(weight), output);
|
||||
return output;
|
||||
} else {
|
||||
return g->addOp<RMSNormObj>(std::move(input), std::move(weight), output)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw,
|
||||
int dh, int dw, int ph, int pw, int sh, int sw,
|
||||
int ceilMode) {
|
||||
|
|
|
@ -504,6 +504,7 @@ void init_graph_builder(py::module &m) {
|
|||
.def("matmul", &Handler::matmul, policy::move)
|
||||
.def("batchNormalization", &Handler::batchNormalization, policy::move)
|
||||
.def("layerNormalization", &Handler::layerNormalization, policy::move)
|
||||
.def("RMSNorm", &Handler::rmsNorm, policy::move)
|
||||
.def("maxPool", &Handler::maxPool, policy::move)
|
||||
.def("avgPool", &Handler::avgPool, policy::move)
|
||||
.def("add", &Handler::add, policy::move)
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
#include "operators/rms_norm.h"
|
||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||
#include "cuda/cuda_rmsnorm.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class RMSNormCuda : public CudaKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<RMSNormObj>(_op);
|
||||
|
||||
auto input = op->getInputs(0);
|
||||
auto weight = op->getInputs(1);
|
||||
auto output = op->getOutput();
|
||||
void *const inputData = input->getRawDataPtr<void *>();
|
||||
void *const weightData = weight->getRawDataPtr<void *>();
|
||||
void *const outputData = output->getRawDataPtr<void *>();
|
||||
const auto &inputShape = input->getDims();
|
||||
int nDims = input->getDims().size();
|
||||
|
||||
int hidden_size = inputShape[nDims - 1];
|
||||
int num_tokens = input->size() / hidden_size;
|
||||
IT_ASSERT(hidden_size == (int)weight->size());
|
||||
|
||||
const int dType = op->getDType().getIndex();
|
||||
rmsnorm_kernel(dType, inputData, weightData, outputData, num_tokens,
|
||||
hidden_size);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::RMSNorm, RMSNormCuda, "RMSNorm_CUDA");
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,112 @@
|
|||
#include "core/common.h"
|
||||
#include "cuda/cuda_common.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "utils/small_array.h"
|
||||
|
||||
template<class T>
|
||||
__inline__ __device__ T warpReduceSum(T val) {
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1)
|
||||
val += __shfl_xor_sync(uint32_t(-1), val, mask);
|
||||
return val;
|
||||
}
|
||||
|
||||
/* Calculate the sum of all elements in a block */
|
||||
template<class T>
|
||||
__inline__ __device__ T blockReduceSum(T val) {
|
||||
static __shared__ T shared[32];
|
||||
int lane = threadIdx.x & 0x1f;
|
||||
int wid = threadIdx.x >> 5;
|
||||
|
||||
val = warpReduceSum<T>(val);
|
||||
|
||||
if (lane == 0)
|
||||
shared[wid] = val;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
|
||||
// blockDim.x is not divided by 32
|
||||
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
|
||||
val = warpReduceSum<T>(val);
|
||||
return val;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__global__ void _rmsnorm_kernel(void *in, void *weight, void *out, int num_tokens, int hidden_size) {
|
||||
__shared__ float s_variance;
|
||||
float variance = 0.0f;
|
||||
|
||||
for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){
|
||||
const float x = ((float*) in)[blockIdx.x * hidden_size + idx];
|
||||
variance += x * x;
|
||||
}
|
||||
variance = blockReduceSum<float>(variance);
|
||||
if(threadIdx.x == 0){
|
||||
s_variance = rsqrtf(variance / hidden_size + 0.00001f);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){
|
||||
float x = ((float*) in)[blockIdx.x * hidden_size + idx];
|
||||
((T*)out)[blockIdx.x * hidden_size + idx] = ((T)(x * s_variance)) * ((T*)weight)[idx];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#define CASE(T) \
|
||||
_rmsnorm_kernel<DT_CUDA<T>::t> \
|
||||
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
|
||||
(input, weight, output, num_tokens, hidden_size);
|
||||
|
||||
#define SWITCH_DTYPE(DTYPE) \
|
||||
switch (DTYPE) { \
|
||||
case 1: \
|
||||
CASE(1) \
|
||||
break; \
|
||||
case 2: \
|
||||
CASE(2) \
|
||||
break; \
|
||||
case 3: \
|
||||
CASE(3) \
|
||||
break; \
|
||||
case 4: \
|
||||
CASE(4) \
|
||||
break; \
|
||||
case 5: \
|
||||
CASE(5) \
|
||||
break; \
|
||||
case 6: \
|
||||
CASE(6) \
|
||||
break; \
|
||||
case 7: \
|
||||
CASE(7) \
|
||||
break; \
|
||||
case 10: \
|
||||
CASE(10) \
|
||||
break; \
|
||||
case 11: \
|
||||
CASE(11) \
|
||||
break; \
|
||||
case 12: \
|
||||
CASE(12) \
|
||||
break; \
|
||||
case 13: \
|
||||
CASE(13) \
|
||||
break; \
|
||||
case 16: \
|
||||
CASE(16) \
|
||||
break; \
|
||||
default: \
|
||||
IT_TODO_HALT(); \
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
void rmsnorm_kernel(int dType, void *input, void *weight, void *output,
|
||||
int num_tokens, int hidden_size) {
|
||||
dim3 blocksize = dim3(std::min(hidden_size, 1024));
|
||||
dim3 gridsize = dim3(num_tokens);
|
||||
SWITCH_DTYPE(dType)
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,36 @@
|
|||
#include "operators/rms_norm.h"
|
||||
|
||||
namespace infini {
|
||||
RMSNormObj::RMSNormObj(GraphObj *graph, Tensor input, Tensor weight,
|
||||
Tensor output)
|
||||
: OperatorObj(OpType::RMSNorm, {input, weight}, {output}) {
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> RMSNormObj::inferShape(const TensorVec &inputs) {
|
||||
const auto A = inputs[0];
|
||||
auto input_dim = A->getDims();
|
||||
auto output_dim = input_dim;
|
||||
return {{output_dim}};
|
||||
}
|
||||
|
||||
std::string RMSNormObj::toString() const {
|
||||
std::ostringstream os;
|
||||
os << type.toString() << "[" << getGuid() << "]";
|
||||
os << "(";
|
||||
os << vecToString(inputs[0]->getDims()) << ",";
|
||||
os << "input=" << inputs[0]->getGuid() << ",";
|
||||
os << "output=" << outputs[0]->getGuid() << ")";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
vector<int> RMSNormObj::getWorkloadVector() const {
|
||||
vector<int> ret{type.underlying()};
|
||||
const Shape shape = outputs[0]->getDims();
|
||||
ret.insert(ret.end(), shape.begin(), shape.end());
|
||||
return ret;
|
||||
}
|
||||
|
||||
vector<int> RMSNormObj::getOpAttrVector() const { return {type.underlying()}; }
|
||||
|
||||
}; // namespace infini
|
Loading…
Reference in New Issue