forked from jiuyuan/InfiniTensor
commit
900d8e58e3
|
@ -47,6 +47,7 @@ class GraphHandlerObj {
|
|||
Tensor max(Tensor a, Tensor b, Tensor c);
|
||||
|
||||
Tensor relu(Tensor x, Tensor y);
|
||||
Tensor silu(Tensor x, Tensor y);
|
||||
Tensor gelu(Tensor x, Tensor y);
|
||||
Tensor sigmoid(Tensor x, Tensor y);
|
||||
Tensor hardSigmoid(Tensor x, Tensor y);
|
||||
|
@ -77,6 +78,7 @@ class GraphHandlerObj {
|
|||
Tensor attentionKVCache(Tensor input_k_cache, Tensor input_v_cache,
|
||||
Tensor input_q, Tensor input_k, Tensor input_v,
|
||||
Tensor position_id, Tensor output_matmul);
|
||||
Tensor RoPE(Tensor pos, Tensor input, Tensor output);
|
||||
TensorVec split(Tensor input, std::optional<TensorVec> outputs, int axis,
|
||||
std::variant<int, vector<int>> numOrRatio);
|
||||
Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);
|
||||
|
|
|
@ -151,10 +151,12 @@ struct OpType {
|
|||
ReduceSum, // Reduce
|
||||
ReduceSumSquare, // Reduce
|
||||
Relu, // Unary
|
||||
Silu, // Unary
|
||||
Reshape,
|
||||
Resize,
|
||||
ReverseSequence,
|
||||
RoiAlign,
|
||||
RoPE, // Fusion
|
||||
Round, // Unary
|
||||
STFT,
|
||||
Scan,
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
#pragma once
|
||||
|
||||
#include "operators/rope.h"
|
||||
#include "utils/small_array.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
void rope_kernel(int dType, int *pos, void *input, void *output, int size,
|
||||
int dim_model, int dim_head, int hidden_stride,
|
||||
int pos_stride);
|
||||
|
||||
}; // namespace infini
|
|
@ -5,6 +5,7 @@
|
|||
namespace infini {
|
||||
template <typename T> void softmax_kernel(T *input, T *output, size_t num);
|
||||
template <typename T> void relu_kernel(T *input, T *output, size_t num);
|
||||
template <typename T> void silu_kernel(T *input, T *output, size_t num);
|
||||
template <typename T> void sigmoid_kernel(T *input, T *output, size_t num);
|
||||
template <typename T> void tanh_kernel(T *input, T *output, size_t num);
|
||||
template <typename T> void abs_kernel(T *input, T *output, size_t num);
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
#pragma once
|
||||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
class RoPEObj : public OperatorObj {
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new RotaryEmbedding object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param pos The positon id of the query.
|
||||
* @param input The input tensor.
|
||||
* @param output The output tensor.
|
||||
*/
|
||||
RoPEObj(GraphObj *graph, Tensor pos, Tensor input, Tensor output);
|
||||
OP_CLONE(RoPEObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 2; }
|
||||
int numOutputs() const override { return 1; }
|
||||
DataType getDType() const { return getInputs(1)->getDType(); }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
} // namespace infini
|
|
@ -258,6 +258,7 @@ class LogObj : public OperatorObj {
|
|||
};
|
||||
|
||||
DEFINE_UNARY_OBJ(Relu, OpType::Relu)
|
||||
DEFINE_UNARY_OBJ(Silu, OpType::Silu)
|
||||
DEFINE_UNARY_OBJ(Gelu, OpType::Gelu)
|
||||
DEFINE_UNARY_OBJ(Sigmoid, OpType::Sigmoid)
|
||||
DEFINE_UNARY_OBJ(Tanh, OpType::Tanh)
|
||||
|
|
|
@ -438,6 +438,11 @@ class OnnxStub:
|
|||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Silu":
|
||||
tensors[node.output[0]] = self.handler.silu(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Gelu":
|
||||
tensors[node.output[0]] = self.handler.gelu(
|
||||
tensors[node.input[0]],
|
||||
|
@ -669,6 +674,12 @@ class OnnxStub:
|
|||
tensors[node.input[5]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "RoPE":
|
||||
tensors[node.output[0]]= self.handler.RoPE(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Split":
|
||||
split = (
|
||||
_parse_data(data[node.input[1]])
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "operators/reduce.h"
|
||||
#include "operators/reshape.h"
|
||||
#include "operators/resize.h"
|
||||
#include "operators/rope.h"
|
||||
#include "operators/send.h"
|
||||
#include "operators/slice.h"
|
||||
#include "operators/softmax.h"
|
||||
|
@ -181,6 +182,7 @@ DEFINE_ELEMENT_WISE_METHOD(max, Maximum)
|
|||
} \
|
||||
}
|
||||
|
||||
DEFINE_UNARY_METHOD(silu, Silu)
|
||||
DEFINE_UNARY_METHOD(relu, Relu)
|
||||
DEFINE_UNARY_METHOD(gelu, Gelu)
|
||||
DEFINE_UNARY_METHOD(sigmoid, Sigmoid)
|
||||
|
@ -345,6 +347,16 @@ Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache,
|
|||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::RoPE(Tensor pos, Tensor input, Tensor output) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<RoPEObj>(std::move(pos), std::move(input), output);
|
||||
return output;
|
||||
} else {
|
||||
return g->addOp<RoPEObj>(std::move(pos), std::move(input), output)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
TensorVec GraphHandlerObj::split(Tensor input, std::optional<TensorVec> outputs,
|
||||
int axis,
|
||||
std::variant<int, vector<int>> numOrRatio) {
|
||||
|
|
|
@ -515,6 +515,7 @@ void init_graph_builder(py::module &m) {
|
|||
.def("min", &Handler::min, policy::move)
|
||||
.def("max", &Handler::max, policy::move)
|
||||
.def("relu", &Handler::relu, policy::move)
|
||||
.def("silu", &Handler::silu, policy::move)
|
||||
.def("gelu", &Handler::gelu, policy::move)
|
||||
.def("sigmoid", &Handler::sigmoid, policy::move)
|
||||
.def("tanh", &Handler::tanh, policy::move)
|
||||
|
@ -537,6 +538,7 @@ void init_graph_builder(py::module &m) {
|
|||
.def("unsqueeze", &Handler::unsqueeze, policy::move)
|
||||
.def("concat", &Handler::concat, policy::move)
|
||||
.def("attentionKVCache", &Handler::attentionKVCache, policy::move)
|
||||
.def("RoPE", &Handler::RoPE, policy::move)
|
||||
.def("split", &Handler::split, policy::move)
|
||||
.def("gather", &Handler::gather, policy::move)
|
||||
.def("gatherElements", &Handler::gatherElements, policy::move)
|
||||
|
|
|
@ -47,6 +47,10 @@ class NativeUnary : public CpuKernelWithoutConfig {
|
|||
return 0.5 * val * (1 + std::erf(val / std::sqrt(2)));
|
||||
}
|
||||
|
||||
template <typename T> static T siluCompute(T val) {
|
||||
return val / (1 + pow(E_CONSTANT, -val));
|
||||
}
|
||||
|
||||
template <typename T> static T erfCompute(T val) { return std::erf(val); }
|
||||
|
||||
template <typename T> static T aCosCompute(T val) { return std::acos(val); }
|
||||
|
@ -84,6 +88,9 @@ class NativeUnary : public CpuKernelWithoutConfig {
|
|||
case OpType::Gelu:
|
||||
_doCompute = geluCompute<T>;
|
||||
break;
|
||||
case OpType::Silu:
|
||||
_doCompute = siluCompute<T>;
|
||||
break;
|
||||
case OpType::Sigmoid:
|
||||
_doCompute = sigmoidCompute<T>;
|
||||
break;
|
||||
|
@ -289,6 +296,7 @@ class Log : public CpuKernelWithoutConfig {
|
|||
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Relu, NativeUnary, "reluNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Gelu, NativeUnary, "geluNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Silu, NativeUnary, "siluNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, NativeUnary, "sigmoidNaive_CPU");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::HardSigmoid, NativeUnary,
|
||||
"hardSigmoidNaive_CPU");
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
#include "operators/rope.h"
|
||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||
#include "cuda/cuda_rope.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class RoPECuda : public CudaKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<RoPEObj>(_op);
|
||||
|
||||
auto pos = op->getInputs(0);
|
||||
auto input = op->getInputs(1);
|
||||
auto output = op->getOutput();
|
||||
void *const inputData = input->getRawDataPtr<void *>();
|
||||
void *const outputData = output->getRawDataPtr<void *>();
|
||||
const auto &inputShape = input->getDims();
|
||||
int nDims = input->getDims().size();
|
||||
|
||||
int size = input->size();
|
||||
IT_ASSERT(nDims == 3 && pos->getDims().size() == 2);
|
||||
IT_ASSERT(inputShape[1] == pos->getDims()[1]);
|
||||
int dim_model = inputShape[2];
|
||||
int dim_head = dim_model / 32;
|
||||
int hidden_stride = dim_model * inputShape[1];
|
||||
int pos_stride = inputShape[1];
|
||||
|
||||
const int dType = op->getDType().getIndex();
|
||||
rope_kernel(dType, pos->getRawDataPtr<int *>(), inputData, outputData,
|
||||
size, dim_model, dim_head, hidden_stride, pos_stride);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::RoPE, RoPECuda, "RoPE_CUDA");
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,91 @@
|
|||
#include "core/common.h"
|
||||
#include "cuda/cuda_common.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "utils/small_array.h"
|
||||
|
||||
constexpr unsigned int num_threads() { return 32 * 4; }
|
||||
constexpr int thread_work_size() { return 4; }
|
||||
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
|
||||
|
||||
// gridDim (batch, seq_len, dim_model / 1024), blockDim (1024, 1, 1)
|
||||
template <class T>
|
||||
__global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model, int dim_head, int hidden_stride, int pos_stride) {
|
||||
int batch_id = blockIdx.x;
|
||||
int target_pos = pos[batch_id * pos_stride + blockIdx.y];
|
||||
int ith = blockIdx.z * blockDim.x + threadIdx.x;
|
||||
int col = ith % dim_head;
|
||||
int offset = batch_id * hidden_stride + blockIdx.y * dim_model;
|
||||
|
||||
if (ith >= dim_model)
|
||||
return;
|
||||
int half_dim = dim_head / 2;
|
||||
if (col < half_dim) {
|
||||
float freq = target_pos * powf(10000, -float(col * 2) / dim_head);
|
||||
float cos_freq = cos(freq);
|
||||
float sin_freq = sin(freq);
|
||||
((T *)out)[offset + ith] =
|
||||
((T *)in)[offset + ith] * T(cos_freq) - ((T *)in)[offset + ith + half_dim] * T(sin_freq);
|
||||
} else {
|
||||
float freq = target_pos * powf(10000, -float((col - half_dim) * 2) / dim_head);
|
||||
float cos_freq = cos(freq);
|
||||
float sin_freq = sin(freq);
|
||||
((T *)out)[offset + ith] =
|
||||
((T *)in)[offset + ith] * T(cos_freq) + ((T *)in)[offset + ith - half_dim] * T(sin_freq);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#define CASE(T) \
|
||||
_rope_kernel<DT_CUDA<T>::t><<<gridsize, blocksize>>>( \
|
||||
pos, input, output, size, dim_model, dim_head, hidden_stride, pos_stride);
|
||||
|
||||
#define SWITCH_DTYPE(DTYPE) \
|
||||
switch (DTYPE) { \
|
||||
case 1: \
|
||||
CASE(1) \
|
||||
break; \
|
||||
case 2: \
|
||||
CASE(2) \
|
||||
break; \
|
||||
case 3: \
|
||||
CASE(3) \
|
||||
break; \
|
||||
case 4: \
|
||||
CASE(4) \
|
||||
break; \
|
||||
case 5: \
|
||||
CASE(5) \
|
||||
break; \
|
||||
case 6: \
|
||||
CASE(6) \
|
||||
break; \
|
||||
case 7: \
|
||||
CASE(7) \
|
||||
break; \
|
||||
case 10: \
|
||||
CASE(10) \
|
||||
break; \
|
||||
case 11: \
|
||||
CASE(11) \
|
||||
break; \
|
||||
case 12: \
|
||||
CASE(12) \
|
||||
break; \
|
||||
case 13: \
|
||||
CASE(13) \
|
||||
break; \
|
||||
case 16: \
|
||||
CASE(16) \
|
||||
break; \
|
||||
default: \
|
||||
IT_TODO_HALT(); \
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
void rope_kernel(int dType, int * pos, void *input, void *output, int size, int dim_model, int dim_head, int hidden_stride, int pos_stride) {
|
||||
dim3 blocksize = dim3(1024,1,1);
|
||||
dim3 gridsize = dim3(1, 1, 4);
|
||||
SWITCH_DTYPE(dType)
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -182,6 +182,7 @@ REGISTER_KERNEL(Device::CUDA, OpType::Tanh, TanhCudnn, "Tanh_CUDA");
|
|||
REGISTER_KERNEL(Device::CUDA, OpType::Abs, UnaryCuda, "Abs_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Sqrt, UnaryCuda, "Sqrt_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Gelu, UnaryCuda, "Gelu_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Silu, UnaryCuda, "Silu_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Neg, UnaryCuda, "Neg_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Erf, UnaryCuda, "Erf_CUDA");
|
||||
|
||||
|
|
|
@ -103,6 +103,17 @@ __global__ void _gelu_kernel(T *input, T *output, size_t n) {
|
|||
output[i] = 0.5 * x * (1 + erf(x / sqrt(2.0f)));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void _silu_kernel(T *input, T *output, size_t n) {
|
||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (int i = index; i < n; i += stride) {
|
||||
float x = input[i];
|
||||
output[i] = x / (1.0 + expf(-x));;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void _erf_kernel(T *input, T *output, size_t n) {
|
||||
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
|
@ -190,6 +201,14 @@ template <typename T> void gelu_kernel(T *input, T *output, size_t num) {
|
|||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_gelu_kernel<T><<<gridsize, blocksize>>>(input, output, num);
|
||||
}
|
||||
|
||||
template <typename T> void silu_kernel(T *input, T *output, size_t num) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_silu_kernel<T><<<gridsize, blocksize>>>(input, output, num);
|
||||
}
|
||||
|
||||
template <typename T> void erf_kernel(T *input, T *output, size_t num) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
|
@ -267,6 +286,12 @@ void unary_kernel(const Operator &_op) {
|
|||
} else {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
} else if (op->getOpType() == OpType::Silu) {
|
||||
if (_op->getDType() == DataType::Float32) {
|
||||
silu_kernel<float>((float *)inputData, (float *)outputData, num);
|
||||
} else {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
} else if (op->getOpType() == OpType::Neg) {
|
||||
if (_op->getDType() == DataType::Float32) {
|
||||
neg_kernel<float>((float *)inputData, (float *)outputData, num);
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
#include "operators/rope.h"
|
||||
|
||||
namespace infini {
|
||||
RoPEObj::RoPEObj(GraphObj *graph, Tensor pos, Tensor input, Tensor output)
|
||||
: OperatorObj(OpType::RoPE, {pos, input}, {output}) {
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> RoPEObj::inferShape(const TensorVec &inputs) {
|
||||
const auto A = inputs[1];
|
||||
auto input_dim = A->getDims();
|
||||
auto output_dim = input_dim;
|
||||
return {{output_dim}};
|
||||
}
|
||||
|
||||
std::string RoPEObj::toString() const {
|
||||
std::ostringstream os;
|
||||
os << type.toString() << "[" << getGuid() << "]";
|
||||
os << "(";
|
||||
os << vecToString(inputs[0]->getDims()) << ",";
|
||||
os << "input=" << inputs[0]->getGuid() << ",";
|
||||
os << "output=" << outputs[0]->getGuid() << ")";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
vector<int> RoPEObj::getWorkloadVector() const {
|
||||
vector<int> ret{type.underlying()};
|
||||
const Shape shape = outputs[0]->getDims();
|
||||
ret.insert(ret.end(), shape.begin(), shape.end());
|
||||
return ret;
|
||||
}
|
||||
|
||||
vector<int> RoPEObj::getOpAttrVector() const { return {type.underlying()}; }
|
||||
|
||||
}; // namespace infini
|
|
@ -0,0 +1,36 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/rope.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
TEST(RoPE, Cuda) {
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
auto input = gCuda->addTensor({1, 1, 32}, DataType::Float32);
|
||||
auto position_id_d = gCuda->addTensor({1, 1}, DataType::UInt32);
|
||||
auto output = gCuda->addTensor({1, 1, 32}, DataType::Float32);
|
||||
|
||||
auto op = gCuda->addOpWithOutputs<RoPEObj>(position_id_d, input, output);
|
||||
gCuda->dataMalloc();
|
||||
|
||||
input->setData(OneGenerator());
|
||||
position_id_d->setData(OneGenerator());
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutputs()[0]);
|
||||
EXPECT_TRUE(oCpu->equalData(vector<float>{
|
||||
1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
|
||||
1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
|
||||
1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
|
||||
1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773, 1.381773,
|
||||
1.381773, 1.381773, 1.381773, 1.381773}));
|
||||
}
|
||||
} // namespace infini
|
|
@ -70,6 +70,7 @@ void testCast(const std::function<void(void *, size_t, DataType)> &generator,
|
|||
|
||||
TEST(cuDNN_Unary, run) {
|
||||
testUnary<ReluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<SiluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<AbsObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
|
|
Loading…
Reference in New Issue