add unittest of silu kernel

This commit is contained in:
xiaonans 2024-01-30 10:40:13 +08:00
parent 030e5ca9c1
commit 956ce37458
2 changed files with 9 additions and 0 deletions

View File

@ -47,6 +47,10 @@ class NativeUnary : public CpuKernelWithoutConfig {
return 0.5 * val * (1 + std::erf(val / std::sqrt(2))); return 0.5 * val * (1 + std::erf(val / std::sqrt(2)));
} }
template <typename T> static T siluCompute(T val) {
return val / (1 + pow(E_CONSTANT, -val));
}
template <typename T> static T erfCompute(T val) { return std::erf(val); } template <typename T> static T erfCompute(T val) { return std::erf(val); }
template <typename T> static T aCosCompute(T val) { return std::acos(val); } template <typename T> static T aCosCompute(T val) { return std::acos(val); }
@ -84,6 +88,9 @@ class NativeUnary : public CpuKernelWithoutConfig {
case OpType::Gelu: case OpType::Gelu:
_doCompute = geluCompute<T>; _doCompute = geluCompute<T>;
break; break;
case OpType::Silu:
_doCompute = siluCompute<T>;
break;
case OpType::Sigmoid: case OpType::Sigmoid:
_doCompute = sigmoidCompute<T>; _doCompute = sigmoidCompute<T>;
break; break;
@ -289,6 +296,7 @@ class Log : public CpuKernelWithoutConfig {
REGISTER_KERNEL(Device::CPU, OpType::Relu, NativeUnary, "reluNaive_CPU"); REGISTER_KERNEL(Device::CPU, OpType::Relu, NativeUnary, "reluNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Gelu, NativeUnary, "geluNaive_CPU"); REGISTER_KERNEL(Device::CPU, OpType::Gelu, NativeUnary, "geluNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Silu, NativeUnary, "siluNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, NativeUnary, "sigmoidNaive_CPU"); REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, NativeUnary, "sigmoidNaive_CPU");
REGISTER_KERNEL(Device::CPU, OpType::HardSigmoid, NativeUnary, REGISTER_KERNEL(Device::CPU, OpType::HardSigmoid, NativeUnary,
"hardSigmoidNaive_CPU"); "hardSigmoidNaive_CPU");

View File

@ -70,6 +70,7 @@ void testCast(const std::function<void(void *, size_t, DataType)> &generator,
TEST(cuDNN_Unary, run) { TEST(cuDNN_Unary, run) {
testUnary<ReluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary<ReluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testUnary<SiluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testUnary<AbsObj>(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary<AbsObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});