forked from jiuyuan/InfiniTensor
Accelerate llama (#219)
* [feature] add cudagraph support * modify code to pass the cuda_all_reduce test * modify rope op * support rmsnorm * add fp16 support to silu cuda op * fix bugs in rmsnorm op * uncomment simplify in onnx.py --------- Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent
54a35772fb
commit
a98573990b
|
@ -37,6 +37,7 @@ class GraphHandlerObj {
|
||||||
float momentum, float eps, bool training);
|
float momentum, float eps, bool training);
|
||||||
Tensor layerNormalization(Tensor input, Tensor scale, Tensor output,
|
Tensor layerNormalization(Tensor input, Tensor scale, Tensor output,
|
||||||
Tensor bias, float eps, int axis, int stash_type);
|
Tensor bias, float eps, int axis, int stash_type);
|
||||||
|
Tensor rmsNorm(Tensor input, Tensor weight, Tensor output);
|
||||||
|
|
||||||
Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
|
Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
|
||||||
int ph, int pw, int sh, int sw, int ceilMode);
|
int ph, int pw, int sh, int sw, int ceilMode);
|
||||||
|
|
|
@ -158,6 +158,7 @@ struct OpType {
|
||||||
RoiAlign,
|
RoiAlign,
|
||||||
RoPE, // Fusion
|
RoPE, // Fusion
|
||||||
Round, // Unary
|
Round, // Unary
|
||||||
|
RMSNorm, // Fusion
|
||||||
STFT,
|
STFT,
|
||||||
Scan,
|
Scan,
|
||||||
Scatter,
|
Scatter,
|
||||||
|
|
|
@ -0,0 +1,10 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "operators/rms_norm.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
void rmsnorm_kernel(int dType, void *input, void *weight, void *output,
|
||||||
|
int num_tokens, int hidden_size);
|
||||||
|
|
||||||
|
}; // namespace infini
|
|
@ -0,0 +1,34 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/operator.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
/**
|
||||||
|
* @brief Fused RMSNorm Operator
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
class RMSNormObj : public OperatorObj {
|
||||||
|
int dim;
|
||||||
|
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief Construct a new RMSNorm object.
|
||||||
|
*
|
||||||
|
* @param graph The computation graph that this operator belongs to.
|
||||||
|
* @param input The input tensor.
|
||||||
|
* @param output The output tensor.
|
||||||
|
*/
|
||||||
|
RMSNormObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output);
|
||||||
|
OP_CLONE(RMSNormObj);
|
||||||
|
|
||||||
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||||
|
|
||||||
|
std::string toString() const override;
|
||||||
|
int numInputs() const override { return 2; }
|
||||||
|
int numOutputs() const override { return 1; }
|
||||||
|
int getDim() const { return dim; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
vector<int> getWorkloadVector() const override;
|
||||||
|
vector<int> getOpAttrVector() const override;
|
||||||
|
};
|
||||||
|
} // namespace infini
|
|
@ -285,6 +285,12 @@ class OnnxStub:
|
||||||
axis,
|
axis,
|
||||||
stash_type,
|
stash_type,
|
||||||
)
|
)
|
||||||
|
elif node.op_type == "RMSNorm":
|
||||||
|
tensors[node.output[0]] = self.handler.RMSNorm(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors[node.input[1]],
|
||||||
|
tensors.get(node.output[0]),
|
||||||
|
)
|
||||||
elif node.op_type == "MaxPool":
|
elif node.op_type == "MaxPool":
|
||||||
attributes = _parse_attribute(
|
attributes = _parse_attribute(
|
||||||
node,
|
node,
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#include "operators/reduce.h"
|
#include "operators/reduce.h"
|
||||||
#include "operators/reshape.h"
|
#include "operators/reshape.h"
|
||||||
#include "operators/resize.h"
|
#include "operators/resize.h"
|
||||||
|
#include "operators/rms_norm.h"
|
||||||
#include "operators/rope.h"
|
#include "operators/rope.h"
|
||||||
#include "operators/send.h"
|
#include "operators/send.h"
|
||||||
#include "operators/slice.h"
|
#include "operators/slice.h"
|
||||||
|
@ -124,6 +125,17 @@ Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * @brief Add an RMSNorm operator to the graph.
 *
 * @param input The tensor to normalize.
 * @param weight The weight tensor.
 * @param output Optional pre-existing output tensor; may be null.
 * @return `output` when it was supplied, otherwise the output of the newly
 *         added operator.
 */
Tensor GraphHandlerObj::rmsNorm(Tensor input, Tensor weight, Tensor output) {
    // No output supplied: add the operator with a null output and hand back
    // whatever output tensor the operator ends up with.
    if (!output) {
        auto op = g->addOp<RMSNormObj>(std::move(input), std::move(weight),
                                       output);
        return op->getOutput();
    }

    // Output supplied by the caller: wire the operator to it directly.
    g->addOpWithOutputs<RMSNormObj>(std::move(input), std::move(weight),
                                    output);
    return output;
}
|
||||||
|
|
||||||
Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw,
|
Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw,
|
||||||
int dh, int dw, int ph, int pw, int sh, int sw,
|
int dh, int dw, int ph, int pw, int sh, int sw,
|
||||||
int ceilMode) {
|
int ceilMode) {
|
||||||
|
|
|
@ -506,6 +506,7 @@ void init_graph_builder(py::module &m) {
|
||||||
.def("matmul", &Handler::matmul, policy::move)
|
.def("matmul", &Handler::matmul, policy::move)
|
||||||
.def("batchNormalization", &Handler::batchNormalization, policy::move)
|
.def("batchNormalization", &Handler::batchNormalization, policy::move)
|
||||||
.def("layerNormalization", &Handler::layerNormalization, policy::move)
|
.def("layerNormalization", &Handler::layerNormalization, policy::move)
|
||||||
|
.def("RMSNorm", &Handler::rmsNorm, policy::move)
|
||||||
.def("maxPool", &Handler::maxPool, policy::move)
|
.def("maxPool", &Handler::maxPool, policy::move)
|
||||||
.def("avgPool", &Handler::avgPool, policy::move)
|
.def("avgPool", &Handler::avgPool, policy::move)
|
||||||
.def("add", &Handler::add, policy::move)
|
.def("add", &Handler::add, policy::move)
|
||||||
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
#include "operators/rms_norm.h"
|
||||||
|
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||||
|
#include "cuda/cuda_rmsnorm.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
class RMSNormCuda : public CudaKernelWithoutConfig {
|
||||||
|
void compute(const Operator &_op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
auto op = as<RMSNormObj>(_op);
|
||||||
|
|
||||||
|
auto input = op->getInputs(0);
|
||||||
|
auto weight = op->getInputs(1);
|
||||||
|
auto output = op->getOutput();
|
||||||
|
void *const inputData = input->getRawDataPtr<void *>();
|
||||||
|
void *const weightData = weight->getRawDataPtr<void *>();
|
||||||
|
void *const outputData = output->getRawDataPtr<void *>();
|
||||||
|
const auto &inputShape = input->getDims();
|
||||||
|
int nDims = input->getDims().size();
|
||||||
|
|
||||||
|
int hidden_size = inputShape[nDims - 1];
|
||||||
|
int num_tokens = input->size() / hidden_size;
|
||||||
|
IT_ASSERT(hidden_size == (int)weight->size());
|
||||||
|
|
||||||
|
const int dType = op->getDType().getIndex();
|
||||||
|
rmsnorm_kernel(dType, inputData, weightData, outputData, num_tokens,
|
||||||
|
hidden_size);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::RMSNorm, RMSNormCuda, "RMSNorm_CUDA");
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,112 @@
|
||||||
|
#include "core/common.h"
|
||||||
|
#include "cuda/cuda_common.h"
|
||||||
|
#include "cuda/cuda_utility.h"
|
||||||
|
#include "utils/small_array.h"
|
||||||
|
|
||||||
|
/*
 * Sum `val` across the 32 lanes of a warp using XOR shuffles with strides
 * 16, 8, 4, 2, 1. The shuffle uses the full-warp mask, so every lane of the
 * warp is expected to take part in the call.
 */
template <class T>
__inline__ __device__ T warpReduceSum(T val) {
    constexpr uint32_t kFullWarpMask = 0xffffffffu;
#pragma unroll
    for (int laneXor = 16; laneXor >= 1; laneXor /= 2) {
        val += __shfl_xor_sync(kFullWarpMask, val, laneXor);
    }
    return val;
}
|
||||||
|
|
||||||
|
/*
 * Calculate the sum of `val` over all threads of a block.
 *
 * Each warp first reduces its own values; lane 0 of every warp stores the
 * partial sum in shared memory, and after a barrier the per-warp partials
 * are reduced again with warpReduceSum.
 *
 * NOTE(review): the shared scratch array is static, so calling this twice
 * in one kernel without a barrier in between looks unsafe - confirm before
 * reusing it that way.
 */
template <class T>
__inline__ __device__ T blockReduceSum(T val) {
    // One slot per warp.
    static __shared__ T shared[32];
    const int laneId = threadIdx.x % 32;
    const int warpId = threadIdx.x / 32;

    val = warpReduceSum<T>(val);

    if (laneId == 0) {
        shared[warpId] = val;
    }

    __syncthreads();

    // Number of warps in the block, rounded up so that a trailing partial
    // warp (blockDim.x not divisible by 32) still contributes its slot.
    const unsigned int warpCount = (blockDim.x + 31u) / 32u;
    val = (threadIdx.x < warpCount) ? shared[laneId] : static_cast<T>(0.0f);
    val = warpReduceSum<T>(val);
    return val;
}
|
||||||
|
|
||||||
|
/*
 * RMSNorm over one row per thread block.
 *
 * Block `blockIdx.x` handles row `blockIdx.x` of a [num_tokens, hidden_size]
 * view of `in`. Pass 1 accumulates the sum of squares of the row in float;
 * thread 0 turns it into 1 / sqrt(mean_square + 1e-5) and publishes it via
 * shared memory. Pass 2 writes `(T)(x * inv_rms) * weight[col]` to `out`.
 */
template <class T>
__global__ void _rmsnorm_kernel(void *in, void *weight, void *out, int num_tokens, int hidden_size) {
    __shared__ float s_inv_rms;

    T *const src = static_cast<T *>(in);
    T *const dst = static_cast<T *>(out);
    T *const scale = static_cast<T *>(weight);

    // Pass 1: per-thread partial sum of squares, strided by the block size.
    float sumSq = 0.0f;
    for (int col = threadIdx.x; col < hidden_size; col += blockDim.x) {
        const float v = src[blockIdx.x * hidden_size + col];
        sumSq += v * v;
    }
    sumSq = blockReduceSum<float>(sumSq);

    if (threadIdx.x == 0) {
        s_inv_rms = rsqrtf(sumSq / hidden_size + 0.00001f);
    }
    // Make the normalization factor visible to every thread of the block.
    __syncthreads();

    // Pass 2: normalize and apply the per-column weight.
    for (int col = threadIdx.x; col < hidden_size; col += blockDim.x) {
        float v = src[blockIdx.x * hidden_size + col];
        dst[blockIdx.x * hidden_size + col] =
            static_cast<T>(v * s_inv_rms) * scale[col];
    }
}
|
||||||
|
|
||||||
|
|
||||||
|
// Launch _rmsnorm_kernel instantiated for the CUDA type that DT_CUDA maps
// data type index T to. Relies on `gridsize`, `blocksize`, `input`,
// `weight`, `output`, `num_tokens` and `hidden_size` being in scope at the
// expansion site, and launches on the current CUDA stream.
#define CASE(T) \
    _rmsnorm_kernel<DT_CUDA<T>::t> \
        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( \
            input, weight, output, num_tokens, hidden_size);

// Dispatch on a runtime data type index. Indices 1-7, 10-13 and 16 are
// handled; anything else halts via IT_TODO_HALT().
// NOTE(review): the indices presumably follow the framework's DataType
// numbering - confirm against its definition.
#define SWITCH_DTYPE(DTYPE) \
    switch (DTYPE) { \
    case 1: CASE(1) break; \
    case 2: CASE(2) break; \
    case 3: CASE(3) break; \
    case 4: CASE(4) break; \
    case 5: CASE(5) break; \
    case 6: CASE(6) break; \
    case 7: CASE(7) break; \
    case 10: CASE(10) break; \
    case 11: CASE(11) break; \
    case 12: CASE(12) break; \
    case 13: CASE(13) break; \
    case 16: CASE(16) break; \
    default: IT_TODO_HALT(); \
    }
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
/**
 * @brief Launch the RMSNorm kernel: one thread block per row.
 *
 * @param dType Data type index used by SWITCH_DTYPE to pick the kernel
 *              instantiation.
 * @param input Pointer to num_tokens * hidden_size input elements.
 * @param weight Pointer to hidden_size scale factors.
 * @param output Pointer to num_tokens * hidden_size output elements.
 * @param num_tokens Number of rows (grid size).
 * @param hidden_size Number of elements per row.
 */
void rmsnorm_kernel(int dType, void *input, void *weight, void *output,
                    int num_tokens, int hidden_size) {
    // Nothing to normalize; a zero-sized grid or block is also an invalid
    // launch configuration.
    if (num_tokens <= 0 || hidden_size <= 0) {
        return;
    }

    // The block reduction shuffles with a full-warp mask, so every warp in
    // the block must be complete. Round the thread count up to a multiple
    // of the warp size; the surplus threads skip both strided loops and
    // only contribute zeros to the reduction.
    constexpr int kWarpSize = 32;
    constexpr int kMaxThreadsPerBlock = 1024;
    const int wantedThreads = std::min(hidden_size, kMaxThreadsPerBlock);
    const int roundedThreads =
        (wantedThreads + kWarpSize - 1) / kWarpSize * kWarpSize;

    dim3 blocksize = dim3(roundedThreads);
    dim3 gridsize = dim3(num_tokens);
    SWITCH_DTYPE(dType)
}
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -22,7 +22,7 @@ class RoPECuda : public CudaKernelWithoutConfig {
|
||||||
IT_ASSERT(nDims == 3 && pos->getDims().size() == 2);
|
IT_ASSERT(nDims == 3 && pos->getDims().size() == 2);
|
||||||
IT_ASSERT(inputShape[1] == pos->getDims()[1]);
|
IT_ASSERT(inputShape[1] == pos->getDims()[1]);
|
||||||
int dim_model = inputShape[2];
|
int dim_model = inputShape[2];
|
||||||
int dim_head = dim_model / 32;
|
int dim_head = 128;
|
||||||
int hidden_stride = dim_model * inputShape[1];
|
int hidden_stride = dim_model * inputShape[1];
|
||||||
int pos_stride = inputShape[1];
|
int pos_stride = inputShape[1];
|
||||||
|
|
||||||
|
|
|
@ -3,11 +3,6 @@
|
||||||
#include "cuda/cuda_utility.h"
|
#include "cuda/cuda_utility.h"
|
||||||
#include "utils/small_array.h"
|
#include "utils/small_array.h"
|
||||||
|
|
||||||
constexpr unsigned int num_threads() { return 32 * 4; }
|
|
||||||
constexpr int thread_work_size() { return 4; }
|
|
||||||
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
|
|
||||||
|
|
||||||
// gridDim (batch, seq_len, dim_model / 1024), blockDim (1024, 1, 1)
|
|
||||||
template <class T>
|
template <class T>
|
||||||
__global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model,
|
__global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model,
|
||||||
int dim_head, int hidden_stride, int pos_stride) {
|
int dim_head, int hidden_stride, int pos_stride) {
|
||||||
|
@ -86,8 +81,8 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
|
||||||
namespace infini {
|
namespace infini {
|
||||||
void rope_kernel(int dType, int * pos, void *input, void *output, int size,
|
void rope_kernel(int dType, int * pos, void *input, void *output, int size,
|
||||||
int dim_model, int dim_head, int hidden_stride, int pos_stride) {
|
int dim_model, int dim_head, int hidden_stride, int pos_stride) {
|
||||||
dim3 blocksize = dim3(1024,1,1);
|
dim3 blocksize = dim3(32,1,1);
|
||||||
dim3 gridsize = dim3(1, 1, 4);
|
dim3 gridsize = dim3(1, 1, dim_model/32);
|
||||||
SWITCH_DTYPE(dType)
|
SWITCH_DTYPE(dType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -315,6 +315,8 @@ void unary_kernel(const Operator &_op) {
|
||||||
} else if (op->getOpType() == OpType::Silu) {
|
} else if (op->getOpType() == OpType::Silu) {
|
||||||
if (_op->getDType() == DataType::Float32) {
|
if (_op->getDType() == DataType::Float32) {
|
||||||
silu_kernel<float>((float *)inputData, (float *)outputData, num);
|
silu_kernel<float>((float *)inputData, (float *)outputData, num);
|
||||||
|
} else if (_op->getDType() == DataType::Float16){
|
||||||
|
silu_kernel<half>((half *)inputData, (half *)outputData, num);
|
||||||
} else {
|
} else {
|
||||||
IT_TODO_HALT();
|
IT_TODO_HALT();
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,36 @@
|
||||||
|
#include "operators/rms_norm.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
RMSNormObj::RMSNormObj(GraphObj *graph, Tensor input, Tensor weight,
|
||||||
|
Tensor output)
|
||||||
|
: OperatorObj(OpType::RMSNorm, {input, weight}, {output}) {
|
||||||
|
IT_ASSERT(checkValid(graph));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * @brief The single output has exactly the shape of the first input.
 *
 * @param inputs Operator inputs; only inputs[0] is consulted.
 * @return A one-element list holding the shape of inputs[0].
 */
optional<vector<Shape>> RMSNormObj::inferShape(const TensorVec &inputs) {
    Shape outShape = inputs[0]->getDims();
    return {{outShape}};
}
|
||||||
|
|
||||||
|
/**
 * @brief Describe the operator as
 *        `<type>[<guid>](<input dims>,input=<guid>,output=<guid>)`.
 */
std::string RMSNormObj::toString() const {
    std::ostringstream text;
    text << type.toString() << "[" << getGuid() << "]"
         << "(" << vecToString(inputs[0]->getDims()) << ","
         << "input=" << inputs[0]->getGuid() << ","
         << "output=" << outputs[0]->getGuid() << ")";
    return text.str();
}
|
||||||
|
|
||||||
|
/**
 * @brief Workload key: the operator type followed by the dimensions of the
 *        output tensor.
 */
vector<int> RMSNormObj::getWorkloadVector() const {
    const Shape outDims = outputs[0]->getDims();

    vector<int> workload;
    workload.reserve(1 + outDims.size());
    workload.push_back(type.underlying());
    for (const auto extent : outDims) {
        workload.push_back(extent);
    }
    return workload;
}
|
||||||
|
|
||||||
|
/**
 * @brief Attribute key: only the operator type is included.
 */
vector<int> RMSNormObj::getOpAttrVector() const {
    vector<int> attrs{type.underlying()};
    return attrs;
}
|
||||||
|
|
||||||
|
}; // namespace infini
|
Loading…
Reference in New Issue