From 785853b0a3cc8cf47fe9650932d832595204ce2a Mon Sep 17 00:00:00 2001 From: PanZezhong1725 <141193946+PanZezhong1725@users.noreply.github.com> Date: Mon, 9 Oct 2023 09:36:55 +0800 Subject: [PATCH] Add erf kernel for cpu and gpu (#147) Co-authored-by: panzezhong@qiyuanlab.com --- include/cuda/cuda_unary.h | 3 +++ src/kernels/cpu/unary.cc | 6 ++++++ src/kernels/cuda/unary.cc | 2 ++ src/kernels/cuda/unary.cu | 14 ++++++++++++++ test/kernels/cuda/test_cuda_unary.cc | 1 + 5 files changed, 26 insertions(+) diff --git a/include/cuda/cuda_unary.h b/include/cuda/cuda_unary.h index 99f73009..0f26c2e3 100644 --- a/include/cuda/cuda_unary.h +++ b/include/cuda/cuda_unary.h @@ -10,6 +10,7 @@ void sigmoid_kernel(float *input, float *output, int num); void tanh_kernel(float *input, float *output, int num); void abs_kernel(float *input, float *output, int num); void sqrt_kernel(float *input, float *output, int num); +void erf_kernel(float *input, float *output, int num); void unary_kernel(const Operator &_op) { auto op = as(_op); @@ -29,6 +30,8 @@ void unary_kernel(const Operator &_op) { abs_kernel(inputData, outputData, num); else if (op->getOpType() == OpType::Sqrt) sqrt_kernel(inputData, outputData, num); + else if (op->getOpType() == OpType::Erf) + erf_kernel(inputData, outputData, num); else IT_TODO_HALT(); } diff --git a/src/kernels/cpu/unary.cc b/src/kernels/cpu/unary.cc index 755e0a93..e559c909 100644 --- a/src/kernels/cpu/unary.cc +++ b/src/kernels/cpu/unary.cc @@ -60,6 +60,10 @@ template class NaiveSqrt : public NativeUnary { T doCompute(T val) const override { return std::sqrt(val); } }; +template class NaiveErf : public NativeUnary { + T doCompute(T val) const override { return std::erf(val); } +}; + template class Clip : public CpuKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *context) const override { @@ -97,6 +101,8 @@ REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::Float32, NaiveAbs, "absNaive_CPU_float32"); REGISTER_KERNEL(Device::CPU, OpType::Sqrt, DataType::Float32, NaiveSqrt, "sqrtNaive_CPU_float32"); +REGISTER_KERNEL(Device::CPU, OpType::Erf, DataType::Float32, NaiveErf, + "erfNaive_CPU_float32"); REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::UInt32, NaiveSoftmax, "softmaxNaive_CPU_uint32"); REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::Float32, diff --git a/src/kernels/cuda/unary.cc b/src/kernels/cuda/unary.cc index abc8b0bc..897e2c77 100644 --- a/src/kernels/cuda/unary.cc +++ b/src/kernels/cuda/unary.cc @@ -140,6 +140,8 @@ REGISTER_KERNEL(Device::CUDA, OpType::Abs, DataType::Float32, UnaryCuda, "Abs_CUDA_Float32"); REGISTER_KERNEL(Device::CUDA, OpType::Sqrt, DataType::Float32, UnaryCuda, "Sqrt_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Erf, DataType::Float32, UnaryCuda, + "Erf_CUDA_Float32"); // REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, UnaryCuda, // "Softmax_CUDA_Float32"); diff --git a/src/kernels/cuda/unary.cu b/src/kernels/cuda/unary.cu index b79bd53f..695762b4 100644 --- a/src/kernels/cuda/unary.cu +++ b/src/kernels/cuda/unary.cu @@ -66,6 +66,14 @@ __global__ void _sqrt_kernel(float *input, float *output, int n) { } } +__global__ void _erf_kernel(float *input, float *output, int n) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + for (int i = index; i < n; i += stride) { + output[i] = erf(input[i]); + } +} + namespace infini { void softmax_kernel(float *input, float *output, int num) { @@ -104,4 +112,10 @@ void sqrt_kernel(float *input, float *output, int num) { int gridsize = (num + block_work_size() - 1) / block_work_size(); _sqrt_kernel<<>>(input, output, num); } +void erf_kernel(float *input, float *output, int num) { + + int blocksize = block_work_size(); + int gridsize = (num + block_work_size() - 1) / block_work_size(); + _erf_kernel<<>>(input, output, num); +} }; // namespace infini diff --git a/test/kernels/cuda/test_cuda_unary.cc b/test/kernels/cuda/test_cuda_unary.cc index 78eb95aa..5d9f24ec 100644 --- a/test/kernels/cuda/test_cuda_unary.cc +++ b/test/kernels/cuda/test_cuda_unary.cc @@ -46,6 +46,7 @@ TEST(cuDNN_Unary, run) { testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); + testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); // more shapes testUnary(IncrementalGenerator(), Shape{13}); testUnary(IncrementalGenerator(), Shape{4, 3});