forked from jiuyuan/InfiniTensor
【Hackathon No.108】Add Gelu operator, ffi, kernel for cpu and gpu. (#148)
feat: Add Gelu kernel, operator, ffi.
This commit is contained in:
parent
7600fe688c
commit
7f16fa353e
|
@ -45,6 +45,7 @@ class GraphHandlerObj {
|
|||
Tensor max(Tensor a, Tensor b, Tensor c);
|
||||
|
||||
Tensor relu(Tensor x, Tensor y);
|
||||
Tensor gelu(Tensor x, Tensor y);
|
||||
Tensor sigmoid(Tensor x, Tensor y);
|
||||
Tensor tanh(Tensor x, Tensor y);
|
||||
Tensor erf(Tensor x, Tensor y);
|
||||
|
|
|
@ -73,6 +73,7 @@ struct OpType {
|
|||
GatherElements,
|
||||
GatherND,
|
||||
Gemm,
|
||||
Gelu, // Unary
|
||||
GlobalAveragePool, // GlobalPool
|
||||
GlobalLpPool, // GlobalPool
|
||||
GlobalMaxPool, // GlobalPool
|
||||
|
|
|
@ -10,6 +10,7 @@ void tanh_kernel(float *input, float *output, size_t num);
|
|||
void abs_kernel(float *input, float *output, size_t num);
|
||||
void sqrt_kernel(float *input, float *output, size_t num);
|
||||
void neg_kernel(float *input, float *output, size_t num);
|
||||
void gelu_kernel(float *input, float *output, size_t num);
|
||||
void erf_kernel(float *input, float *output, size_t num);
|
||||
|
||||
void unary_kernel(const Operator &_op) {
|
||||
|
@ -30,6 +31,8 @@ void unary_kernel(const Operator &_op) {
|
|||
abs_kernel(inputData, outputData, num);
|
||||
else if (op->getOpType() == OpType::Sqrt)
|
||||
sqrt_kernel(inputData, outputData, num);
|
||||
else if (op->getOpType() == OpType::Gelu)
|
||||
gelu_kernel(inputData, outputData, num);
|
||||
else if (op->getOpType() == OpType::Neg)
|
||||
neg_kernel(inputData, outputData, num);
|
||||
else if (op->getOpType() == OpType::Erf)
|
||||
|
|
|
@ -258,6 +258,7 @@ class LogObj : public OperatorObj {
|
|||
};
|
||||
|
||||
DEFINE_UNARY_OBJ(Relu, OpType::Relu)
|
||||
DEFINE_UNARY_OBJ(Gelu, OpType::Gelu)
|
||||
DEFINE_UNARY_OBJ(Sigmoid, OpType::Sigmoid)
|
||||
DEFINE_UNARY_OBJ(Tanh, OpType::Tanh)
|
||||
// DEFINE_UNARY_OBJ(Softmax, OpType::Softmax)
|
||||
|
|
|
@ -374,6 +374,11 @@ class OnnxStub:
|
|||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Gelu":
|
||||
tensors[node.output[0]] = self.handler.gelu(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Sigmoid":
|
||||
tensors[node.output[0]] = self.handler.sigmoid(
|
||||
tensors[node.input[0]],
|
||||
|
@ -913,6 +918,7 @@ class OnnxStub:
|
|||
backend.OpTypeId.Div,
|
||||
backend.OpTypeId.Pow,
|
||||
backend.OpTypeId.Relu,
|
||||
backend.OpTypeId.Gelu,
|
||||
backend.OpTypeId.Sigmoid,
|
||||
backend.OpTypeId.Tanh,
|
||||
backend.OpTypeId.Softmax,
|
||||
|
|
|
@ -208,6 +208,14 @@ class TestStringMethods(unittest.TestCase):
|
|||
relu = make_node("Relu", ["x"], ["y"], name="relu")
|
||||
make_and_import_model(make_graph([relu], "relu", [x], [y]))
|
||||
|
||||
'''Gelu operator is not supported by onnx 14.1 currently.'''
|
||||
def test_gelu(self):
|
||||
pass
|
||||
# x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
# y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
# gelu = make_node("Gelu", ["x"], ["y"], name="gelu")
|
||||
# make_and_import_model(make_graph([gelu], "gelu", [x], [y]))
|
||||
|
||||
def test_erf(self):
|
||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
|
|
|
@ -155,6 +155,7 @@ DEFINE_ELEMENT_WISE_METHOD(max, Maximum)
|
|||
}
|
||||
|
||||
DEFINE_UNARY_METHOD(relu, Relu)
|
||||
DEFINE_UNARY_METHOD(gelu, Gelu)
|
||||
DEFINE_UNARY_METHOD(sigmoid, Sigmoid)
|
||||
DEFINE_UNARY_METHOD(tanh, Tanh)
|
||||
DEFINE_UNARY_METHOD(abs, Abs)
|
||||
|
|
|
@ -142,6 +142,7 @@ const char *OpType::toString() const {
|
|||
CASE(ReduceSum);
|
||||
CASE(ReduceSumSquare);
|
||||
CASE(Relu);
|
||||
CASE(Gelu);
|
||||
CASE(Reshape);
|
||||
CASE(Resize);
|
||||
CASE(ReverseSequence);
|
||||
|
@ -234,7 +235,7 @@ bool OpType::isUnary() const {
|
|||
static const std::unordered_set<decltype(type)> set{
|
||||
Abs, Acos, Acosh, Asin, Asinh, Atan, Atanh, Cast, Ceil,
|
||||
Clip, Cos, Cosh, Erf, Exp, Floor, Log, Neg, Not,
|
||||
Relu, Round, Sigmoid, Sin, Sinh, Sqrt, Tan, Tanh,
|
||||
Relu, Gelu, Round, Sigmoid, Sin, Sinh, Sqrt, Tan, Tanh,
|
||||
};
|
||||
|
||||
return set.find(type) != set.end();
|
||||
|
|
|
@ -92,6 +92,7 @@ void export_values(py::module &m) {
|
|||
.VALUE(OpType, BatchNormalization)
|
||||
.VALUE(OpType, Softmax)
|
||||
.VALUE(OpType, Relu)
|
||||
.VALUE(OpType, Gelu)
|
||||
.VALUE(OpType, PRelu)
|
||||
.VALUE(OpType, Sigmoid)
|
||||
.VALUE(OpType, Tanh)
|
||||
|
@ -440,6 +441,7 @@ void init_graph_builder(py::module &m) {
|
|||
.def("min", &Handler::min, policy::move)
|
||||
.def("max", &Handler::max, policy::move)
|
||||
.def("relu", &Handler::relu, policy::move)
|
||||
.def("gelu", &Handler::gelu, policy::move)
|
||||
.def("sigmoid", &Handler::sigmoid, policy::move)
|
||||
.def("tanh", &Handler::tanh, policy::move)
|
||||
.def("softmax", &Handler::softmax, policy::move)
|
||||
|
|
|
@ -60,6 +60,12 @@ template <typename T> class NaiveSqrt : public NativeUnary<T> {
|
|||
T doCompute(T val) const override { return std::sqrt(val); }
|
||||
};
|
||||
|
||||
template <typename T> class NaiveGelu : public NativeUnary<T> {
|
||||
T doCompute(T val) const override {
|
||||
return 0.5 * val * (1 + std::erf(val / std::sqrt(2)));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> class NaiveErf : public NativeUnary<T> {
|
||||
T doCompute(T val) const override { return std::erf(val); }
|
||||
};
|
||||
|
@ -91,6 +97,10 @@ REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::UInt32,
|
|||
NaiveRelu<uint32_t>, "reluNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::Float32, NaiveRelu<float>,
|
||||
"reluNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Gelu, DataType::UInt32, NaiveGelu<float>,
|
||||
"geluNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Gelu, DataType::Float32, NaiveGelu<float>,
|
||||
"geluNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::UInt32,
|
||||
NaiveSigmoid<uint32_t>, "sigmoidNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::Float32,
|
||||
|
|
|
@ -140,6 +140,8 @@ REGISTER_KERNEL(Device::CUDA, OpType::Abs, DataType::Float32, UnaryCuda,
|
|||
"Abs_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Sqrt, DataType::Float32, UnaryCuda,
|
||||
"Sqrt_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Gelu, DataType::Float32, UnaryCuda,
|
||||
"Gelu_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Neg, DataType::Float32, UnaryCuda,
|
||||
"Neg_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Erf, DataType::Float32, UnaryCuda,
|
||||
|
|
|
@ -66,6 +66,15 @@ __global__ void _sqrt_kernel(float *input, float *output, size_t n) {
|
|||
}
|
||||
}
|
||||
|
||||
__global__ void _gelu_kernel(float *input, float *output, size_t n) {
|
||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (int i = index; i < n; i += stride) {
|
||||
float x = input[i];
|
||||
output[i] = 0.5 * x * (1 + erf(x / sqrt(2.0f)));
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void _erf_kernel(float *input, float *output, size_t n) {
|
||||
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
size_t stride = blockDim.x * gridDim.x;
|
||||
|
@ -121,6 +130,12 @@ void sqrt_kernel(float *input, float *output, size_t num) {
|
|||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_sqrt_kernel<<<gridsize, blocksize>>>(input, output, num);
|
||||
}
|
||||
void gelu_kernel(float *input, float *output, size_t num) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_gelu_kernel<<<gridsize, blocksize>>>(input, output, num);
|
||||
}
|
||||
void erf_kernel(float *input, float *output, size_t num) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
|
|
|
@ -52,6 +52,10 @@ TEST(cuDNN_Unary, run) {
|
|||
testUnary<SqrtObj>(IncrementalGenerator(), Shape{13});
|
||||
testUnary<SqrtObj>(IncrementalGenerator(), Shape{4, 3});
|
||||
testUnary<SqrtObj>(IncrementalGenerator(), Shape{2, 3, 4, 5, 6});
|
||||
|
||||
testUnary<GeluObj>(IncrementalGenerator(), Shape{1});
|
||||
testUnary<GeluObj>(IncrementalGenerator(), Shape{1, 2});
|
||||
testUnary<GeluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/unary.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
using ExpectOutput = vector<float>;
|
||||
TEST(Unary, ShapeInference) {
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i0 = g->addTensor({2}, DataType::Float32);
|
||||
auto op = g->addOp<GeluObj>(i0, nullptr);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2}));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace infini
|
Loading…
Reference in New Issue