forked from jiuyuan/InfiniTensor
Add erf kernel for cpu and gpu (#147)
Co-authored-by: panzezhong@qiyuanlab.com <panzezhong@zezhongpan>
This commit is contained in:
parent
c0ff584e04
commit
785853b0a3
|
@ -10,6 +10,7 @@ void sigmoid_kernel(float *input, float *output, int num);
|
||||||
void tanh_kernel(float *input, float *output, int num);
|
void tanh_kernel(float *input, float *output, int num);
|
||||||
void abs_kernel(float *input, float *output, int num);
|
void abs_kernel(float *input, float *output, int num);
|
||||||
void sqrt_kernel(float *input, float *output, int num);
|
void sqrt_kernel(float *input, float *output, int num);
|
||||||
|
void erf_kernel(float *input, float *output, int num);
|
||||||
|
|
||||||
void unary_kernel(const Operator &_op) {
|
void unary_kernel(const Operator &_op) {
|
||||||
auto op = as<UnaryObj>(_op);
|
auto op = as<UnaryObj>(_op);
|
||||||
|
@ -29,6 +30,8 @@ void unary_kernel(const Operator &_op) {
|
||||||
abs_kernel(inputData, outputData, num);
|
abs_kernel(inputData, outputData, num);
|
||||||
else if (op->getOpType() == OpType::Sqrt)
|
else if (op->getOpType() == OpType::Sqrt)
|
||||||
sqrt_kernel(inputData, outputData, num);
|
sqrt_kernel(inputData, outputData, num);
|
||||||
|
else if (op->getOpType() == OpType::Erf)
|
||||||
|
erf_kernel(inputData, outputData, num);
|
||||||
else
|
else
|
||||||
IT_TODO_HALT();
|
IT_TODO_HALT();
|
||||||
}
|
}
|
||||||
|
|
|
@ -60,6 +60,10 @@ template <typename T> class NaiveSqrt : public NativeUnary<T> {
|
||||||
T doCompute(T val) const override { return std::sqrt(val); }
|
T doCompute(T val) const override { return std::sqrt(val); }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Element-wise error function for the CPU backend. Reuses the
// NativeUnary traversal machinery and supplies std::erf as the
// per-element operation, mirroring NaiveSqrt/NaiveAbs above.
template <typename T> class NaiveErf : public NativeUnary<T> {
    T doCompute(T x) const override { return std::erf(x); }
};
|
||||||
|
|
||||||
template <typename T> class Clip : public CpuKernelWithoutConfig {
|
template <typename T> class Clip : public CpuKernelWithoutConfig {
|
||||||
void compute(const Operator &_op,
|
void compute(const Operator &_op,
|
||||||
const RuntimeObj *context) const override {
|
const RuntimeObj *context) const override {
|
||||||
|
@ -97,6 +101,8 @@ REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::Float32, NaiveAbs<float>,
|
||||||
"absNaive_CPU_float32");
|
"absNaive_CPU_float32");
|
||||||
REGISTER_KERNEL(Device::CPU, OpType::Sqrt, DataType::Float32, NaiveSqrt<float>,
|
REGISTER_KERNEL(Device::CPU, OpType::Sqrt, DataType::Float32, NaiveSqrt<float>,
|
||||||
"sqrtNaive_CPU_float32");
|
"sqrtNaive_CPU_float32");
|
||||||
|
REGISTER_KERNEL(Device::CPU, OpType::Erf, DataType::Float32, NaiveErf<float>,
|
||||||
|
"erfNaive_CPU_float32");
|
||||||
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::UInt32,
|
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::UInt32,
|
||||||
NaiveSoftmax<uint32_t>, "softmaxNaive_CPU_uint32");
|
NaiveSoftmax<uint32_t>, "softmaxNaive_CPU_uint32");
|
||||||
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::Float32,
|
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::Float32,
|
||||||
|
|
|
@ -140,6 +140,8 @@ REGISTER_KERNEL(Device::CUDA, OpType::Abs, DataType::Float32, UnaryCuda,
|
||||||
"Abs_CUDA_Float32");
|
"Abs_CUDA_Float32");
|
||||||
REGISTER_KERNEL(Device::CUDA, OpType::Sqrt, DataType::Float32, UnaryCuda,
|
REGISTER_KERNEL(Device::CUDA, OpType::Sqrt, DataType::Float32, UnaryCuda,
|
||||||
"Sqrt_CUDA_Float32");
|
"Sqrt_CUDA_Float32");
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::Erf, DataType::Float32, UnaryCuda,
|
||||||
|
"Erf_CUDA_Float32");
|
||||||
|
|
||||||
// REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, UnaryCuda,
|
// REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, UnaryCuda,
|
||||||
// "Softmax_CUDA_Float32");
|
// "Softmax_CUDA_Float32");
|
||||||
|
|
|
@ -66,6 +66,14 @@ __global__ void _sqrt_kernel(float *input, float *output, int n) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CUDA kernel: writes erf(input[i]) into output[i] for each of the n
// elements. Uses a grid-stride loop (thread id advances by the total
// thread count) so that any launch configuration covers the array.
__global__ void _erf_kernel(float *input, float *output, int n) {
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    int step = gridDim.x * blockDim.x;
    for (int i = tid; i < n; i += step) {
        output[i] = erf(input[i]);
    }
}
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
void softmax_kernel(float *input, float *output, int num) {
|
void softmax_kernel(float *input, float *output, int num) {
|
||||||
|
|
||||||
|
@ -104,4 +112,10 @@ void sqrt_kernel(float *input, float *output, int num) {
|
||||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||||
_sqrt_kernel<<<gridsize, blocksize>>>(input, output, num);
|
_sqrt_kernel<<<gridsize, blocksize>>>(input, output, num);
|
||||||
}
|
}
|
||||||
|
void erf_kernel(float *input, float *output, int num) {
|
||||||
|
|
||||||
|
int blocksize = block_work_size();
|
||||||
|
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||||
|
_erf_kernel<<<gridsize, blocksize>>>(input, output, num);
|
||||||
|
}
|
||||||
}; // namespace infini
|
}; // namespace infini
|
||||||
|
|
|
@ -46,6 +46,7 @@ TEST(cuDNN_Unary, run) {
|
||||||
testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||||
testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||||
testUnary<SqrtObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
testUnary<SqrtObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||||
|
testUnary<ErfObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||||
// more shapes
|
// more shapes
|
||||||
testUnary<SqrtObj>(IncrementalGenerator(), Shape{13});
|
testUnary<SqrtObj>(IncrementalGenerator(), Shape{13});
|
||||||
testUnary<SqrtObj>(IncrementalGenerator(), Shape{4, 3});
|
testUnary<SqrtObj>(IncrementalGenerator(), Shape{4, 3});
|
||||||
|
|
Loading…
Reference in New Issue