add rope and silu support

xiaonans 2024-01-11 15:44:07 +08:00
parent afed5d3c3d
commit e8d111ef5d
14 changed files with 257 additions and 1 deletion


@@ -47,6 +47,7 @@ class GraphHandlerObj {
     Tensor max(Tensor a, Tensor b, Tensor c);
     Tensor relu(Tensor x, Tensor y);
+    Tensor silu(Tensor x, Tensor y);
     Tensor gelu(Tensor x, Tensor y);
     Tensor sigmoid(Tensor x, Tensor y);
     Tensor hardSigmoid(Tensor x, Tensor y);
@@ -77,6 +78,7 @@ class GraphHandlerObj {
     Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
                             Tensor input_q, Tensor input_k, Tensor input_v,
                             Tensor position_id, Tensor output_matmul);
+    Tensor RoPE(Tensor pos, Tensor input, Tensor output);
     TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
                     std::variant<int, vector<int>> numOrRatio);
     Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);


@@ -151,10 +151,12 @@ struct OpType {
         ReduceSum,       // Reduce
         ReduceSumSquare, // Reduce
         Relu,            // Unary
+        Silu,            // Unary
         Reshape,
         Resize,
         ReverseSequence,
         RoiAlign,
+        RoPE,            // Fusion
         Round,           // Unary
         STFT,
         Scan,

include/cuda/cuda_rope.h (new file, +10)

@@ -0,0 +1,10 @@
+#pragma once
+#include "operators/rope.h"
+#include "utils/small_array.h"
+namespace infini {
+void rope_kernel(int dType, int *pos, void *input, void *output, int size,
+                 int dim_model, int dim_head, int hidden_stride,
+                 int pos_stride);
+}; // namespace infini


@@ -5,6 +5,7 @@
 namespace infini {
 template <typename T> void softmax_kernel(T *input, T *output, size_t num);
 template <typename T> void relu_kernel(T *input, T *output, size_t num);
+template <typename T> void silu_kernel(T *input, T *output, size_t num);
 template <typename T> void sigmoid_kernel(T *input, T *output, size_t num);
 template <typename T> void tanh_kernel(T *input, T *output, size_t num);
 template <typename T> void abs_kernel(T *input, T *output, size_t num);

include/operators/rope.h (new file, +21)

@@ -0,0 +1,21 @@
+#pragma once
+#include "core/operator.h"
+namespace infini {
+class RoPEObj : public OperatorObj {
+  public:
+    RoPEObj(GraphObj *graph, Tensor pos, Tensor input, Tensor output);
+    OP_CLONE(RoPEObj);
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
+    std::string toString() const override;
+    int numInputs() const override { return 2; }
+    int numOutputs() const override { return 1; }
+    DataType getDType() const { return getInputs(1)->getDType(); }
+
+  private:
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+};
+} // namespace infini


@@ -258,6 +258,7 @@ class LogObj : public OperatorObj {
 };
 DEFINE_UNARY_OBJ(Relu, OpType::Relu)
+DEFINE_UNARY_OBJ(Silu, OpType::Silu)
 DEFINE_UNARY_OBJ(Gelu, OpType::Gelu)
 DEFINE_UNARY_OBJ(Sigmoid, OpType::Sigmoid)
 DEFINE_UNARY_OBJ(Tanh, OpType::Tanh)


@@ -438,6 +438,11 @@ class OnnxStub:
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )
+            elif node.op_type == "Silu":
+                tensors[node.output[0]] = self.handler.silu(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                )
             elif node.op_type == "Gelu":
                 tensors[node.output[0]] = self.handler.gelu(
                     tensors[node.input[0]],
@@ -669,6 +674,12 @@ class OnnxStub:
                     tensors[node.input[5]],
                     tensors.get(node.output[0]),
                 )
+            elif node.op_type == "RoPE":
+                tensors[node.output[0]] = self.handler.RoPE(
+                    tensors[node.input[0]],
+                    tensors[node.input[1]],
+                    tensors.get(node.output[0]),
+                )
             elif node.op_type == "Split":
                 split = (
                     _parse_data(data[node.input[1]])


@@ -2,6 +2,7 @@
 #include "operators/all_gather.h"
 #include "operators/all_reduce.h"
 #include "operators/attention_kvcache.h"
+#include "operators/rope.h"
 #include "operators/batch_norm.h"
 #include "operators/broadcast.h"
 #include "operators/concat.h"
@@ -180,7 +181,8 @@ DEFINE_ELEMENT_WISE_METHOD(max, Maximum)
             return g->addOp<obj##Obj>(std::move(x), y)->getOutput();          \
         }                                                                     \
     }
+DEFINE_UNARY_METHOD(silu, Silu)
 DEFINE_UNARY_METHOD(relu, Relu)
 DEFINE_UNARY_METHOD(gelu, Gelu)
 DEFINE_UNARY_METHOD(sigmoid, Sigmoid)
@@ -345,6 +347,16 @@ Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache,
     }
 }

+Tensor GraphHandlerObj::RoPE(Tensor pos, Tensor input, Tensor output) {
+    if (output) {
+        g->addOpWithOutputs<RoPEObj>(std::move(pos), std::move(input), output);
+        return output;
+    } else {
+        return g->addOp<RoPEObj>(std::move(pos), std::move(input), output)
+            ->getOutput();
+    }
+}
+
 TensorVec GraphHandlerObj::split(Tensor input, std::optional<TensorVec> outputs,
                                  int axis,
                                  std::variant<int, vector<int>> numOrRatio) {
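
The handler follows the same pattern as the other fused ops: with an explicit output tensor it uses addOpWithOutputs, otherwise it lets the graph allocate the result via shape inference. A minimal build-side sketch, assuming the existing Runtime/GraphObj/addTensor helpers from the rest of the codebase (the shapes are illustrative, not taken from this commit):

    // Hedged sketch: wiring the new RoPEObj into a graph by hand, mirroring
    // GraphHandlerObj::RoPE above. Shapes are chosen to satisfy the CUDA
    // kernel's checks (3-D input, 2-D int32 positions, matching seq_len).
    Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph g = make_ref<GraphObj>(runtime);
    auto pos = g->addTensor({1, 32}, DataType::Int32);            // (batch, seq_len)
    auto input = g->addTensor({1, 32, 4096}, DataType::Float32);  // (batch, seq_len, dim_model)
    // Passing nullptr lets the graph create the output via RoPEObj::inferShape.
    auto out = g->addOp<RoPEObj>(pos, input, nullptr)->getOutput();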


@@ -515,6 +515,7 @@ void init_graph_builder(py::module &m) {
         .def("min", &Handler::min, policy::move)
         .def("max", &Handler::max, policy::move)
         .def("relu", &Handler::relu, policy::move)
+        .def("silu", &Handler::silu, policy::move)
         .def("gelu", &Handler::gelu, policy::move)
         .def("sigmoid", &Handler::sigmoid, policy::move)
         .def("tanh", &Handler::tanh, policy::move)
@@ -537,6 +538,7 @@ void init_graph_builder(py::module &m) {
         .def("unsqueeze", &Handler::unsqueeze, policy::move)
         .def("concat", &Handler::concat, policy::move)
         .def("attentionKVCache", &Handler::attentionKVCache, policy::move)
+        .def("RoPE", &Handler::RoPE, policy::move)
         .def("split", &Handler::split, policy::move)
         .def("gather", &Handler::gather, policy::move)
         .def("gatherElements", &Handler::gatherElements, policy::move)

src/kernels/cuda/rope.cc (new file, +38)

@@ -0,0 +1,38 @@
+#include "operators/rope.h"
+#include "cuda/cuda_kernel_wihtout_config.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/cuda_rope.h"
+
+namespace infini {
+
+class RoPECuda : public CudaKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<RoPEObj>(_op);
+
+        auto pos = op->getInputs(0);
+        auto input = op->getInputs(1);
+        auto output = op->getOutput();
+        void *const inputData = input->getRawDataPtr<void *>();
+        void *const outputData = output->getRawDataPtr<void *>();
+        const auto &inputShape = input->getDims();
+        int nDims = input->getDims().size();
+
+        int size = input->size();
+        IT_ASSERT(nDims == 3 && pos->getDims().size() == 2);
+        IT_ASSERT(inputShape[1] == pos->getDims()[1]);
+        int dim_model = inputShape[2];
+        int dim_head = dim_model / 32;
+        int hidden_stride = dim_model * inputShape[1];
+        int pos_stride = inputShape[1];
+
+        const int dType = op->getDType().getIndex();
+        rope_kernel(dType, pos->getRawDataPtr<int *>(), inputData, outputData,
+                    size, dim_model, dim_head, hidden_stride, pos_stride);
+    }
+};
+
+REGISTER_KERNEL(Device::CUDA, OpType::RoPE, RoPECuda, "RoPE_CUDA");
+
+} // namespace infini

src/kernels/cuda/rope.cu (new file, +91)

@@ -0,0 +1,91 @@
+#include "core/common.h"
+#include "cuda/cuda_common.h"
+#include "cuda/cuda_utility.h"
+#include "utils/small_array.h"
+
+constexpr unsigned int num_threads() { return 32 * 4; }
+constexpr int thread_work_size() { return 4; }
+constexpr int block_work_size() { return thread_work_size() * num_threads(); }
+
+// gridDim (batch, seq_len, dim_model / 1024), blockDim (1024, 1, 1)
+template <class T>
+__global__ void _rope_kernel(int *pos, void *in, void *out, int size,
+                             int dim_model, int dim_head, int hidden_stride,
+                             int pos_stride) {
+    int batch_id = blockIdx.x;
+    int target_pos = pos[batch_id * pos_stride + blockIdx.y];
+    int ith = blockIdx.z * blockDim.x + threadIdx.x;
+    int col = ith % dim_head;
+    int offset = batch_id * hidden_stride + blockIdx.y * dim_model;
+
+    if (ith >= dim_model)
+        return;
+    int half_dim = dim_head / 2;
+    if (col < half_dim) {
+        float freq = target_pos * powf(10000, -float(col * 2) / dim_head);
+        float cos_freq = cos(freq);
+        float sin_freq = sin(freq);
+        ((T *)out)[offset + ith] =
+            ((T *)in)[offset + ith] * T(cos_freq) -
+            ((T *)in)[offset + ith + half_dim] * T(sin_freq);
+    } else {
+        float freq =
+            target_pos * powf(10000, -float((col - half_dim) * 2) / dim_head);
+        float cos_freq = cos(freq);
+        float sin_freq = sin(freq);
+        ((T *)out)[offset + ith] =
+            ((T *)in)[offset + ith] * T(cos_freq) +
+            ((T *)in)[offset + ith - half_dim] * T(sin_freq);
+    }
+}
+
+#define CASE(T)                                                        \
+    _rope_kernel<DT_CUDA<T>::t><<<gridsize, blocksize>>>(              \
+        pos, input, output, size, dim_model, dim_head, hidden_stride,  \
+        pos_stride);
+
+#define SWITCH_DTYPE(DTYPE)                                            \
+    switch (DTYPE) {                                                   \
+    case 1:                                                            \
+        CASE(1)                                                        \
+        break;                                                         \
+    case 2:                                                            \
+        CASE(2)                                                        \
+        break;                                                         \
+    case 3:                                                            \
+        CASE(3)                                                        \
+        break;                                                         \
+    case 4:                                                            \
+        CASE(4)                                                        \
+        break;                                                         \
+    case 5:                                                            \
+        CASE(5)                                                        \
+        break;                                                         \
+    case 6:                                                            \
+        CASE(6)                                                        \
+        break;                                                         \
+    case 7:                                                            \
+        CASE(7)                                                        \
+        break;                                                         \
+    case 10:                                                           \
+        CASE(10)                                                       \
+        break;                                                         \
+    case 11:                                                           \
+        CASE(11)                                                       \
+        break;                                                         \
+    case 12:                                                           \
+        CASE(12)                                                       \
+        break;                                                         \
+    case 13:                                                           \
+        CASE(13)                                                       \
+        break;                                                         \
+    case 16:                                                           \
+        CASE(16)                                                       \
+        break;                                                         \
+    default:                                                           \
+        IT_TODO_HALT();                                                \
+    }
+
+namespace infini {
+void rope_kernel(int dType, int *pos, void *input, void *output, int size,
+                 int dim_model, int dim_head, int hidden_stride,
+                 int pos_stride) {
+    dim3 blocksize = dim3(1024, 1, 1);
+    dim3 gridsize = dim3(1, 1, 4);
+    SWITCH_DTYPE(dType)
+}
+} // namespace infini
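
For reference, the rotation computed by _rope_kernel for a token at position p and channel index j within a head of size d_h (with d_h = dim_head fixed to dim_model / 32 by the calling code in rope.cc) is the usual rotary-embedding pair update, read directly off the code above:

    \theta_j = p \cdot 10000^{-2j / d_h}, \qquad 0 \le j < d_h / 2
    \mathrm{out}_j          = \mathrm{in}_j \cos\theta_j - \mathrm{in}_{j + d_h/2} \sin\theta_j
    \mathrm{out}_{j + d_h/2} = \mathrm{in}_{j + d_h/2} \cos\theta_j + \mathrm{in}_j \sin\theta_j

where in and out denote the slice of the current (batch, position) row handled by one head.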


@@ -157,6 +157,7 @@ class SoftmaxCudnn : public CudaKernelWithoutConfig {
 class ReluCudnn : public ActivationCudnn {
     cudnnActivationMode_t getOpType() const override {
         return CUDNN_ACTIVATION_RELU;
     }
 };
@@ -182,6 +183,7 @@ REGISTER_KERNEL(Device::CUDA, OpType::Tanh, TanhCudnn, "Tanh_CUDA");
 REGISTER_KERNEL(Device::CUDA, OpType::Abs, UnaryCuda, "Abs_CUDA");
 REGISTER_KERNEL(Device::CUDA, OpType::Sqrt, UnaryCuda, "Sqrt_CUDA");
 REGISTER_KERNEL(Device::CUDA, OpType::Gelu, UnaryCuda, "Gelu_CUDA");
+REGISTER_KERNEL(Device::CUDA, OpType::Silu, UnaryCuda, "Silu_CUDA");
 REGISTER_KERNEL(Device::CUDA, OpType::Neg, UnaryCuda, "Neg_CUDA");
 REGISTER_KERNEL(Device::CUDA, OpType::Erf, UnaryCuda, "Erf_CUDA");


@@ -103,6 +103,17 @@ __global__ void _gelu_kernel(T *input, T *output, size_t n) {
         output[i] = 0.5 * x * (1 + erf(x / sqrt(2.0f)));
     }
 }
+
+template <typename T>
+__global__ void _silu_kernel(T *input, T *output, size_t n) {
+    int index = threadIdx.x + blockIdx.x * blockDim.x;
+    int stride = blockDim.x * gridDim.x;
+    for (int i = index; i < n; i += stride) {
+        float x = input[i];
+        output[i] = x / (1.0 + expf(-x));
+    }
+}
+
 template <typename T>
 __global__ void _erf_kernel(T *input, T *output, size_t n) {
     size_t index = threadIdx.x + blockIdx.x * blockDim.x;
@@ -190,6 +201,14 @@ template <typename T> void gelu_kernel(T *input, T *output, size_t num) {
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _gelu_kernel<T><<<gridsize, blocksize>>>(input, output, num);
 }
+
+template <typename T> void silu_kernel(T *input, T *output, size_t num) {
+    int blocksize = block_work_size();
+    int gridsize = (num + block_work_size() - 1) / block_work_size();
+    _silu_kernel<T><<<gridsize, blocksize>>>(input, output, num);
+}
+
 template <typename T> void erf_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
@@ -209,6 +228,7 @@ void unary_kernel(const Operator &_op) {
     void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
     size_t num = op->getOutput()->size();
     if (op->getOpType() == OpType::Softmax) {
         if (_op->getDType() == DataType::Float32) {
             softmax_kernel<float>((float *)inputData, (float *)outputData, num);
@@ -267,6 +287,12 @@ void unary_kernel(const Operator &_op) {
         } else {
             IT_TODO_HALT();
         }
+    } else if (op->getOpType() == OpType::Silu) {
+        if (_op->getDType() == DataType::Float32) {
+            silu_kernel<float>((float *)inputData, (float *)outputData, num);
+        } else {
+            IT_TODO_HALT();
+        }
     } else if (op->getOpType() == OpType::Neg) {
         if (_op->getDType() == DataType::Float32) {
             neg_kernel<float>((float *)inputData, (float *)outputData, num);
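
The _silu_kernel above and the new Silu dispatch branch apply the standard SiLU (sigmoid-weighted linear unit) activation elementwise:

    \mathrm{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}

which is exactly what the expression x / (1.0 + expf(-x)) in the kernel evaluates.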

src/operators/rope.cc (new file, +37)

@@ -0,0 +1,37 @@
+#include "operators/rope.h"
+
+namespace infini {
+RoPEObj::RoPEObj(GraphObj *graph, Tensor pos, Tensor input, Tensor output)
+    : OperatorObj(OpType::RoPE, {pos, input}, {output}) {
+    IT_ASSERT(checkValid(graph));
+}
+
+optional<vector<Shape>> RoPEObj::inferShape(const TensorVec &inputs) {
+    const auto A = inputs[1];
+    auto input_dim = A->getDims();
+    auto output_dim = input_dim;
+    return {{output_dim}};
+}
+
+std::string RoPEObj::toString() const {
+    std::ostringstream os;
+    os << type.toString() << "[" << getGuid() << "]";
+    os << "(";
+    os << vecToString(inputs[0]->getDims()) << ",";
+    os << "input=" << inputs[0]->getGuid() << ",";
+    os << "output=" << outputs[0]->getGuid() << ")";
+    return os.str();
+}
+
+vector<int> RoPEObj::getWorkloadVector() const {
+    vector<int> ret{type.underlying()};
+    const Shape shape = outputs[0]->getDims();
+    ret.insert(ret.end(), shape.begin(), shape.end());
+    return ret;
+}
+
+vector<int> RoPEObj::getOpAttrVector() const { return {type.underlying()}; }
+}; // namespace infini