Add erf kernel for cpu and gpu (#147)

Co-authored-by: panzezhong@qiyuanlab.com <panzezhong@zezhongpan>
This commit is contained in:
PanZezhong1725 2023-10-09 09:36:55 +08:00 committed by GitHub
parent c0ff584e04
commit 785853b0a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 26 additions and 0 deletions

View File

@ -10,6 +10,7 @@ void sigmoid_kernel(float *input, float *output, int num);
void tanh_kernel(float *input, float *output, int num); void tanh_kernel(float *input, float *output, int num);
void abs_kernel(float *input, float *output, int num); void abs_kernel(float *input, float *output, int num);
void sqrt_kernel(float *input, float *output, int num); void sqrt_kernel(float *input, float *output, int num);
void erf_kernel(float *input, float *output, int num);
void unary_kernel(const Operator &_op) { void unary_kernel(const Operator &_op) {
auto op = as<UnaryObj>(_op); auto op = as<UnaryObj>(_op);
@ -29,6 +30,8 @@ void unary_kernel(const Operator &_op) {
abs_kernel(inputData, outputData, num); abs_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::Sqrt) else if (op->getOpType() == OpType::Sqrt)
sqrt_kernel(inputData, outputData, num); sqrt_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::Erf)
erf_kernel(inputData, outputData, num);
else else
IT_TODO_HALT(); IT_TODO_HALT();
} }

View File

@ -60,6 +60,10 @@ template <typename T> class NaiveSqrt : public NativeUnary<T> {
T doCompute(T val) const override { return std::sqrt(val); } T doCompute(T val) const override { return std::sqrt(val); }
}; };
// Elementwise Gauss error function: out = erf(in), delegating the
// per-element math to the C++ standard library.  Iteration over the
// tensor is handled by the NativeUnary<T> base (see sibling kernels
// such as NaiveSqrt above).
template <typename T> class NaiveErf : public NativeUnary<T> {
    T doCompute(T x) const override { return std::erf(x); }
};
template <typename T> class Clip : public CpuKernelWithoutConfig { template <typename T> class Clip : public CpuKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *context) const override { const RuntimeObj *context) const override {
@ -97,6 +101,8 @@ REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::Float32, NaiveAbs<float>,
"absNaive_CPU_float32"); "absNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Sqrt, DataType::Float32, NaiveSqrt<float>, REGISTER_KERNEL(Device::CPU, OpType::Sqrt, DataType::Float32, NaiveSqrt<float>,
"sqrtNaive_CPU_float32"); "sqrtNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Erf, DataType::Float32, NaiveErf<float>,
"erfNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::UInt32, REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::UInt32,
NaiveSoftmax<uint32_t>, "softmaxNaive_CPU_uint32"); NaiveSoftmax<uint32_t>, "softmaxNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::Float32, REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::Float32,

View File

@ -140,6 +140,8 @@ REGISTER_KERNEL(Device::CUDA, OpType::Abs, DataType::Float32, UnaryCuda,
"Abs_CUDA_Float32"); "Abs_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Sqrt, DataType::Float32, UnaryCuda, REGISTER_KERNEL(Device::CUDA, OpType::Sqrt, DataType::Float32, UnaryCuda,
"Sqrt_CUDA_Float32"); "Sqrt_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Erf, DataType::Float32, UnaryCuda,
"Erf_CUDA_Float32");
// REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, UnaryCuda, // REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, UnaryCuda,
// "Softmax_CUDA_Float32"); // "Softmax_CUDA_Float32");

View File

@ -66,6 +66,14 @@ __global__ void _sqrt_kernel(float *input, float *output, int n) {
} }
} }
// Elementwise erf on a flat float buffer of n elements.
// Uses the canonical grid-stride loop so any launch geometry covers
// the whole range: each thread starts at its global id and advances
// by the total number of launched threads.
__global__ void _erf_kernel(float *input, float *output, int n) {
    const int tid = threadIdx.x + blockIdx.x * blockDim.x;
    const int step = blockDim.x * gridDim.x;
    for (int i = tid; i < n; i += step)
        output[i] = erf(input[i]);
}
namespace infini { namespace infini {
void softmax_kernel(float *input, float *output, int num) { void softmax_kernel(float *input, float *output, int num) {
@ -104,4 +112,10 @@ void sqrt_kernel(float *input, float *output, int num) {
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_sqrt_kernel<<<gridsize, blocksize>>>(input, output, num); _sqrt_kernel<<<gridsize, blocksize>>>(input, output, num);
} }
// Host-side launcher for _erf_kernel: computes out[i] = erf(in[i])
// for i in [0, num).  Grid size is the ceiling of num / block size so
// every element is covered (the kernel itself guards i < num).
void erf_kernel(float *input, float *output, int num) {
    const int threads = block_work_size();
    const int blocks = (num + threads - 1) / threads;
    _erf_kernel<<<blocks, threads>>>(input, output, num);
}
}; // namespace infini }; // namespace infini

View File

@ -46,6 +46,7 @@ TEST(cuDNN_Unary, run) {
testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testUnary<SqrtObj>(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary<SqrtObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testUnary<ErfObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
// more shapes // more shapes
testUnary<SqrtObj>(IncrementalGenerator(), Shape{13}); testUnary<SqrtObj>(IncrementalGenerator(), Shape{13});
testUnary<SqrtObj>(IncrementalGenerator(), Shape{4, 3}); testUnary<SqrtObj>(IncrementalGenerator(), Shape{4, 3});