diff --git a/src/kernels/cpu/unary.cc b/src/kernels/cpu/unary.cc index 024d720a..9e7cead0 100644 --- a/src/kernels/cpu/unary.cc +++ b/src/kernels/cpu/unary.cc @@ -47,6 +47,10 @@ class NativeUnary : public CpuKernelWithoutConfig { return 0.5 * val * (1 + std::erf(val / std::sqrt(2))); } + template static T siluCompute(T val) { + return val / (1 + pow(E_CONSTANT, -val)); + } + template static T erfCompute(T val) { return std::erf(val); } template static T aCosCompute(T val) { return std::acos(val); } @@ -84,6 +88,9 @@ class NativeUnary : public CpuKernelWithoutConfig { case OpType::Gelu: _doCompute = geluCompute; break; + case OpType::Silu: + _doCompute = siluCompute; + break; case OpType::Sigmoid: _doCompute = sigmoidCompute; break; @@ -289,6 +296,7 @@ class Log : public CpuKernelWithoutConfig { REGISTER_KERNEL(Device::CPU, OpType::Relu, NativeUnary, "reluNaive_CPU"); REGISTER_KERNEL(Device::CPU, OpType::Gelu, NativeUnary, "geluNaive_CPU"); +REGISTER_KERNEL(Device::CPU, OpType::Silu, NativeUnary, "siluNaive_CPU"); REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, NativeUnary, "sigmoidNaive_CPU"); REGISTER_KERNEL(Device::CPU, OpType::HardSigmoid, NativeUnary, "hardSigmoidNaive_CPU"); diff --git a/test/kernels/cuda/test_cuda_unary.cc b/test/kernels/cuda/test_cuda_unary.cc index fd407dfd..27ce90f1 100644 --- a/test/kernels/cuda/test_cuda_unary.cc +++ b/test/kernels/cuda/test_cuda_unary.cc @@ -70,6 +70,7 @@ void testCast(const std::function &generator, TEST(cuDNN_Unary, run) { testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3});