diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index 1c79f51d..975a78bf 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -45,6 +45,7 @@ class GraphHandlerObj { Tensor max(Tensor a, Tensor b, Tensor c); Tensor relu(Tensor x, Tensor y); + Tensor gelu(Tensor x, Tensor y); Tensor sigmoid(Tensor x, Tensor y); Tensor tanh(Tensor x, Tensor y); Tensor erf(Tensor x, Tensor y); diff --git a/include/core/op_type.h b/include/core/op_type.h index e0146c5f..82439650 100644 --- a/include/core/op_type.h +++ b/include/core/op_type.h @@ -73,6 +73,7 @@ struct OpType { GatherElements, GatherND, Gemm, + Gelu, // Unary GlobalAveragePool, // GlobalPool GlobalLpPool, // GlobalPool GlobalMaxPool, // GlobalPool diff --git a/include/cuda/cuda_unary.h b/include/cuda/cuda_unary.h index c538682a..c839abc6 100644 --- a/include/cuda/cuda_unary.h +++ b/include/cuda/cuda_unary.h @@ -10,6 +10,7 @@ void tanh_kernel(float *input, float *output, size_t num); void abs_kernel(float *input, float *output, size_t num); void sqrt_kernel(float *input, float *output, size_t num); void neg_kernel(float *input, float *output, size_t num); +void gelu_kernel(float *input, float *output, size_t num); void erf_kernel(float *input, float *output, size_t num); void unary_kernel(const Operator &_op) { @@ -30,6 +31,8 @@ void unary_kernel(const Operator &_op) { abs_kernel(inputData, outputData, num); else if (op->getOpType() == OpType::Sqrt) sqrt_kernel(inputData, outputData, num); + else if (op->getOpType() == OpType::Gelu) + gelu_kernel(inputData, outputData, num); else if (op->getOpType() == OpType::Neg) neg_kernel(inputData, outputData, num); else if (op->getOpType() == OpType::Erf) diff --git a/include/operators/unary.h b/include/operators/unary.h index 8a3d9704..8349993c 100644 --- a/include/operators/unary.h +++ b/include/operators/unary.h @@ -258,6 +258,7 @@ class LogObj : public OperatorObj { }; DEFINE_UNARY_OBJ(Relu, OpType::Relu) +DEFINE_UNARY_OBJ(Gelu, OpType::Gelu) DEFINE_UNARY_OBJ(Sigmoid, OpType::Sigmoid) DEFINE_UNARY_OBJ(Tanh, OpType::Tanh) // DEFINE_UNARY_OBJ(Softmax, OpType::Softmax) diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 96d4778d..af1e1f95 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -374,6 +374,11 @@ class OnnxStub: tensors[node.input[0]], tensors.get(node.output[0]), ) + elif node.op_type == "Gelu": + tensors[node.output[0]] = self.handler.gelu( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) elif node.op_type == "Sigmoid": tensors[node.output[0]] = self.handler.sigmoid( tensors[node.input[0]], @@ -913,6 +918,7 @@ class OnnxStub: backend.OpTypeId.Div, backend.OpTypeId.Pow, backend.OpTypeId.Relu, + backend.OpTypeId.Gelu, backend.OpTypeId.Sigmoid, backend.OpTypeId.Tanh, backend.OpTypeId.Softmax, diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index 2d614b48..3420fa4f 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -208,6 +208,14 @@ class TestStringMethods(unittest.TestCase): relu = make_node("Relu", ["x"], ["y"], name="relu") make_and_import_model(make_graph([relu], "relu", [x], [y])) + '''Gelu operator is not supported by onnx 14.1 currently.''' + def test_gelu(self): + pass + # x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) + # y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) + # gelu = make_node("Gelu", ["x"], ["y"], name="gelu") + # make_and_import_model(make_graph([gelu], "gelu", [x], [y])) + def test_erf(self): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index bbd73e10..6255c8fd 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -155,6 +155,7 @@ DEFINE_ELEMENT_WISE_METHOD(max, Maximum) } DEFINE_UNARY_METHOD(relu, Relu) +DEFINE_UNARY_METHOD(gelu, Gelu) DEFINE_UNARY_METHOD(sigmoid, Sigmoid) DEFINE_UNARY_METHOD(tanh, Tanh) DEFINE_UNARY_METHOD(abs, Abs) diff --git a/src/core/op_type.cc b/src/core/op_type.cc index 38122bf9..5932513f 100644 --- a/src/core/op_type.cc +++ b/src/core/op_type.cc @@ -142,6 +142,7 @@ const char *OpType::toString() const { CASE(ReduceSum); CASE(ReduceSumSquare); CASE(Relu); + CASE(Gelu); CASE(Reshape); CASE(Resize); CASE(ReverseSequence); @@ -232,9 +233,9 @@ const char *OpType::toString() const { bool OpType::isUnary() const { static const std::unordered_set set{ - Abs, Acos, Acosh, Asin, Asinh, Atan, Atanh, Cast, Ceil, - Clip, Cos, Cosh, Erf, Exp, Floor, Log, Neg, Not, - Relu, Round, Sigmoid, Sin, Sinh, Sqrt, Tan, Tanh, + Abs, Acos, Acosh, Asin, Asinh, Atan, Atanh, Cast, Ceil, + Clip, Cos, Cosh, Erf, Exp, Floor, Log, Neg, Not, + Relu, Gelu, Round, Sigmoid, Sin, Sinh, Sqrt, Tan, Tanh, }; return set.find(type) != set.end(); diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index b515164f..27a1ba81 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -92,6 +92,7 @@ void export_values(py::module &m) { .VALUE(OpType, BatchNormalization) .VALUE(OpType, Softmax) .VALUE(OpType, Relu) + .VALUE(OpType, Gelu) .VALUE(OpType, PRelu) .VALUE(OpType, Sigmoid) .VALUE(OpType, Tanh) @@ -440,6 +441,7 @@ void init_graph_builder(py::module &m) { .def("min", &Handler::min, policy::move) .def("max", &Handler::max, policy::move) .def("relu", &Handler::relu, policy::move) + .def("gelu", &Handler::gelu, policy::move) .def("sigmoid", &Handler::sigmoid, policy::move) .def("tanh", &Handler::tanh, policy::move) .def("softmax", &Handler::softmax, policy::move) diff --git a/src/kernels/cpu/unary.cc b/src/kernels/cpu/unary.cc index 15025115..ec7497c3 100644 --- a/src/kernels/cpu/unary.cc +++ b/src/kernels/cpu/unary.cc @@ -60,6 +60,12 @@ template class NaiveSqrt : public NativeUnary { T doCompute(T val) const override { return std::sqrt(val); } }; +template class NaiveGelu : public NativeUnary { + T doCompute(T val) const override { + return 0.5 * val * (1 + std::erf(val / std::sqrt(2))); + } +}; + template class NaiveErf : public NativeUnary { T doCompute(T val) const override { return std::erf(val); } }; @@ -91,6 +97,10 @@ REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::UInt32, NaiveRelu, "reluNaive_CPU_uint32"); REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::Float32, NaiveRelu, "reluNaive_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Gelu, DataType::UInt32, NaiveGelu, + "geluNaive_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Gelu, DataType::Float32, NaiveGelu, + "geluNaive_CPU_float32"); REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::UInt32, NaiveSigmoid, "sigmoidNaive_CPU_uint32"); REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::Float32, diff --git a/src/kernels/cuda/unary.cc b/src/kernels/cuda/unary.cc index cb53bd80..48f6daaa 100644 --- a/src/kernels/cuda/unary.cc +++ b/src/kernels/cuda/unary.cc @@ -140,6 +140,8 @@ REGISTER_KERNEL(Device::CUDA, OpType::Abs, DataType::Float32, UnaryCuda, "Abs_CUDA_Float32"); REGISTER_KERNEL(Device::CUDA, OpType::Sqrt, DataType::Float32, UnaryCuda, "Sqrt_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Gelu, DataType::Float32, UnaryCuda, + "Gelu_CUDA_Float32"); REGISTER_KERNEL(Device::CUDA, OpType::Neg, DataType::Float32, UnaryCuda, "Neg_CUDA_Float32"); REGISTER_KERNEL(Device::CUDA, OpType::Erf, DataType::Float32, UnaryCuda, diff --git a/src/kernels/cuda/unary.cu b/src/kernels/cuda/unary.cu index 061ac63d..2267e6eb 100644 --- a/src/kernels/cuda/unary.cu +++ b/src/kernels/cuda/unary.cu @@ -66,6 +66,15 @@ __global__ void _sqrt_kernel(float *input, float *output, size_t n) { } } +__global__ void _gelu_kernel(float *input, float *output, size_t n) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + for (int i = index; i < n; i += stride) { + float x = input[i]; + output[i] = 0.5 * x * (1 + erf(x / sqrt(2.0f))); + } +} + __global__ void _erf_kernel(float *input, float *output, size_t n) { size_t index = threadIdx.x + blockIdx.x * blockDim.x; size_t stride = blockDim.x * gridDim.x; @@ -121,6 +130,12 @@ void sqrt_kernel(float *input, float *output, size_t num) { int gridsize = (num + block_work_size() - 1) / block_work_size(); _sqrt_kernel<<>>(input, output, num); } +void gelu_kernel(float *input, float *output, size_t num) { + + int blocksize = block_work_size(); + int gridsize = (num + block_work_size() - 1) / block_work_size(); + _gelu_kernel<<>>(input, output, num); +} void erf_kernel(float *input, float *output, size_t num) { int blocksize = block_work_size(); diff --git a/test/kernels/cuda/test_cuda_unary.cc b/test/kernels/cuda/test_cuda_unary.cc index 2f828fab..09a2255e 100644 --- a/test/kernels/cuda/test_cuda_unary.cc +++ b/test/kernels/cuda/test_cuda_unary.cc @@ -52,6 +52,10 @@ TEST(cuDNN_Unary, run) { testUnary(IncrementalGenerator(), Shape{13}); testUnary(IncrementalGenerator(), Shape{4, 3}); testUnary(IncrementalGenerator(), Shape{2, 3, 4, 5, 6}); + + testUnary(IncrementalGenerator(), Shape{1}); + testUnary(IncrementalGenerator(), Shape{1, 2}); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); } } // namespace infini diff --git a/test/operators/test_unary.cc b/test/operators/test_unary.cc new file mode 100644 index 00000000..911d815e --- /dev/null +++ b/test/operators/test_unary.cc @@ -0,0 +1,21 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/unary.h" + +#include "test.h" + +namespace infini { + +using ExpectOutput = vector; +TEST(Unary, ShapeInference) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + { + Graph g = make_ref(runtime); + Tensor i0 = g->addTensor({2}, DataType::Float32); + auto op = g->addOp(i0, nullptr); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{2})); + } +} + +} // namespace infini