diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index 975a78bf..6c670227 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -47,6 +47,8 @@ class GraphHandlerObj {
     Tensor relu(Tensor x, Tensor y);
     Tensor gelu(Tensor x, Tensor y);
     Tensor sigmoid(Tensor x, Tensor y);
+    Tensor hardSigmoid(Tensor x, Tensor y);
+    Tensor hardSwish(Tensor x, Tensor y);
     Tensor tanh(Tensor x, Tensor y);
     Tensor erf(Tensor x, Tensor y);
     Tensor softmax(Tensor x, Tensor y, int axis);
diff --git a/include/cuda/cuda_unary.h b/include/cuda/cuda_unary.h
index c839abc6..31a39951 100644
--- a/include/cuda/cuda_unary.h
+++ b/include/cuda/cuda_unary.h
@@ -12,6 +12,8 @@ void sqrt_kernel(float *input, float *output, size_t num);
 void neg_kernel(float *input, float *output, size_t num);
 void gelu_kernel(float *input, float *output, size_t num);
 void erf_kernel(float *input, float *output, size_t num);
+void hard_sigmoid_kernel(float *input, float *output, size_t num);
+void hard_swish_kernel(float *input, float *output, size_t num);

 void unary_kernel(const Operator &_op) {
     auto op = as<UnaryObj>(_op);
@@ -25,6 +27,10 @@ void unary_kernel(const Operator &_op) {
         relu_kernel(inputData, outputData, num);
     else if (op->getOpType() == OpType::Sigmoid)
         sigmoid_kernel(inputData, outputData, num);
+    else if (op->getOpType() == OpType::HardSigmoid)
+        hard_sigmoid_kernel(inputData, outputData, num);
+    else if (op->getOpType() == OpType::HardSwish)
+        hard_swish_kernel(inputData, outputData, num);
     else if (op->getOpType() == OpType::Tanh)
         tanh_kernel(inputData, outputData, num);
     else if (op->getOpType() == OpType::Abs)
diff --git a/include/operators/unary.h b/include/operators/unary.h
index 8349993c..0bbe314c 100644
--- a/include/operators/unary.h
+++ b/include/operators/unary.h
@@ -263,6 +263,8 @@ DEFINE_UNARY_OBJ(Sigmoid, OpType::Sigmoid)
 DEFINE_UNARY_OBJ(Tanh, OpType::Tanh)
 // DEFINE_UNARY_OBJ(Softmax, OpType::Softmax)
 DEFINE_UNARY_OBJ(Abs, OpType::Abs)
+DEFINE_UNARY_OBJ(HardSigmoid, OpType::HardSigmoid)
+DEFINE_UNARY_OBJ(HardSwish, OpType::HardSwish)
 DEFINE_UNARY_OBJ(Sin, OpType::Sin)
 DEFINE_UNARY_OBJ(Cos, OpType::Cos)
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index 103af8e4..659a9802 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -395,6 +395,16 @@ class OnnxStub:
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )
+            elif node.op_type == "HardSigmoid":
+                tensors[node.output[0]] = self.handler.hardSigmoid(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                )
+            elif node.op_type == "HardSwish":
+                tensors[node.output[0]] = self.handler.hardSwish(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                )
             elif node.op_type == "Tanh":
                 tensors[node.output[0]] = self.handler.tanh(
                     tensors[node.input[0]],
@@ -931,6 +941,8 @@ class OnnxStub:
                 backend.OpTypeId.Relu,
                 backend.OpTypeId.Gelu,
                 backend.OpTypeId.Sigmoid,
+                backend.OpTypeId.HardSigmoid,
+                backend.OpTypeId.HardSwish,
                 backend.OpTypeId.Tanh,
                 backend.OpTypeId.Softmax,
                 backend.OpTypeId.Abs,
diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py
index 3420fa4f..f80ad220 100644
--- a/pyinfinitensor/tests/test_onnx.py
+++ b/pyinfinitensor/tests/test_onnx.py
@@ -239,6 +239,18 @@ class TestStringMethods(unittest.TestCase):
         y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
         tanh = make_node("Tanh", ["x"], ["y"], name="tanh")
         make_and_import_model(make_graph([tanh], "tanh", [x], [y]))
+
+    def test_hard_sigmoid(self):
+        x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
+        y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
+        hardSigmoid = make_node("HardSigmoid", ["x"], ["y"], name="hardSigmoid")
+        make_and_import_model(make_graph([hardSigmoid], "hardSigmoid", [x], [y]))
+
+    def test_hard_swish(self):
+        x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
+        y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
+        hardSwish = make_node("HardSwish", ["x"], ["y"], name="hardSwish")
+        make_and_import_model(make_graph([hardSwish], "hardSwish", [x], [y]))

     def test_softmax(self):
         x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc
index 6255c8fd..225fae09 100644
--- a/src/core/graph_handler.cc
+++ b/src/core/graph_handler.cc
@@ -158,6 +158,8 @@ DEFINE_UNARY_METHOD(relu, Relu)
 DEFINE_UNARY_METHOD(gelu, Gelu)
 DEFINE_UNARY_METHOD(sigmoid, Sigmoid)
 DEFINE_UNARY_METHOD(tanh, Tanh)
+DEFINE_UNARY_METHOD(hardSigmoid, HardSigmoid)
+DEFINE_UNARY_METHOD(hardSwish, HardSwish)
 DEFINE_UNARY_METHOD(abs, Abs)
 DEFINE_UNARY_METHOD(sqrt, Sqrt)
 DEFINE_UNARY_METHOD(neg, Neg)
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index 27a1ba81..8ac563b6 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -96,6 +96,8 @@ void export_values(py::module &m) {
         .VALUE(OpType, PRelu)
         .VALUE(OpType, Sigmoid)
         .VALUE(OpType, Tanh)
+        .VALUE(OpType, HardSigmoid)
+        .VALUE(OpType, HardSwish)
         .VALUE(OpType, Abs)
         .VALUE(OpType, Resize)
         .VALUE(OpType, Dropout)
@@ -444,6 +446,8 @@ void init_graph_builder(py::module &m) {
         .def("gelu", &Handler::gelu, policy::move)
         .def("sigmoid", &Handler::sigmoid, policy::move)
         .def("tanh", &Handler::tanh, policy::move)
+        .def("hardSigmoid", &Handler::hardSigmoid, policy::move)
+        .def("hardSwish", &Handler::hardSwish, policy::move)
         .def("softmax", &Handler::softmax, policy::move)
         .def("abs", &Handler::abs, policy::move)
         .def("sqrt", &Handler::sqrt, policy::move)
diff --git a/src/kernels/cpu/unary.cc b/src/kernels/cpu/unary.cc
index ec7497c3..ed2f30c7 100644
--- a/src/kernels/cpu/unary.cc
+++ b/src/kernels/cpu/unary.cc
@@ -46,6 +46,17 @@ template <typename T> class NaiveSigmoid : public NativeUnary<T> {
         return 1 / (1 + pow(E_CONSTANT, -val));
     }
 };
+template <typename T> class NaiveHardSigmoid : public NativeUnary<T> {
+    T doCompute(T val) const override {
+        return std::max(T(0), std::min(T(1), T(0.2) * val + T(0.5)));
+    }
+};
+template <typename T> class NaiveHardSwish : public NativeUnary<T> {
+    T doCompute(T val) const override {
+        return val *
+               std::max(T(0), std::min(T(1), val * T(1.0 / 6.0) + T(0.5)));
+    }
+};
 template <typename T> class NaiveTanh : public NativeUnary<T> {
     T doCompute(T val) const override {
         return (pow(E_CONSTANT, val) - pow(E_CONSTANT, -val)) /
@@ -105,6 +116,10 @@ REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::UInt32,
                 NaiveSigmoid<uint32_t>, "sigmoidNaive_CPU_uint32");
 REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::Float32,
                 NaiveSigmoid<float>, "sigmoidNaive_CPU_float32");
+REGISTER_KERNEL(Device::CPU, OpType::HardSigmoid, DataType::Float32,
+                NaiveHardSigmoid<float>, "hardSigmoidNaive_CPU_float32");
+REGISTER_KERNEL(Device::CPU, OpType::HardSwish, DataType::Float32,
+                NaiveHardSwish<float>, "hardSwishNaive_CPU_float32");
 REGISTER_KERNEL(Device::CPU, OpType::Tanh, DataType::UInt32,
                 NaiveTanh<uint32_t>, "tanhNaive_CPU_uint32");
 REGISTER_KERNEL(Device::CPU, OpType::Tanh, DataType::Float32,
                 NaiveTanh<float>, "tanhNaive_CPU_float32");
diff --git a/src/kernels/cuda/unary.cc b/src/kernels/cuda/unary.cc
index 48f6daaa..a27d4ac4 100644
--- a/src/kernels/cuda/unary.cc
+++ b/src/kernels/cuda/unary.cc
@@ -134,6 +134,10 @@ REGISTER_KERNEL(Device::CUDA, OpType::Relu, DataType::Float32, ReluCudnn,
                 "Relu_CUDA_Float32");
 REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, DataType::Float32, SigmoidCudnn,
                 "Sigmoid_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::HardSigmoid, DataType::Float32, UnaryCuda,
+                "Hard_Sigmoid_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::HardSwish, DataType::Float32, UnaryCuda,
+                "Hard_Swish_CUDA_Float32");
 REGISTER_KERNEL(Device::CUDA, OpType::Tanh, DataType::Float32, TanhCudnn,
                 "Tanh_CUDA_Float32");
 REGISTER_KERNEL(Device::CUDA, OpType::Abs, DataType::Float32, UnaryCuda,
diff --git a/src/kernels/cuda/unary.cu b/src/kernels/cuda/unary.cu
index 2267e6eb..22e2e423 100644
--- a/src/kernels/cuda/unary.cu
+++ b/src/kernels/cuda/unary.cu
@@ -41,6 +41,23 @@ __global__ void _sigmoid_kernel(float *input, float *output, size_t n) {
     }
 }

+__global__ void _hard_sigmoid_kernel(float *input, float *output, size_t n) {
+    size_t index = threadIdx.x + blockIdx.x * blockDim.x;
+    size_t stride = blockDim.x * gridDim.x;
+    for (size_t i = index; i < n; i += stride) {
+        output[i] = max(0.0f, min(1.0f, 0.2f * input[i] + 0.5f));
+    }
+}
+
+__global__ void _hard_swish_kernel(float *input, float *output, size_t n) {
+    size_t index = threadIdx.x + blockIdx.x * blockDim.x;
+    size_t stride = blockDim.x * gridDim.x;
+    for (size_t i = index; i < n; i += stride) {
+        output[i] =
+            input[i] * max(0.f, min(1.f, (1.f / 6.f) * input[i] + 0.5f));
+    }
+}
+
 __global__ void _tanh_kernel(float *input, float *output, size_t n) {
     size_t index = threadIdx.x + blockIdx.x * blockDim.x;
     size_t stride = blockDim.x * gridDim.x;
@@ -112,6 +129,18 @@ void sigmoid_kernel(float *input, float *output, size_t num) {
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _sigmoid_kernel<<<gridsize, blocksize>>>(input, output, num);
 }
+void hard_sigmoid_kernel(float *input, float *output, size_t num) {
+
+    int blocksize = block_work_size();
+    int gridsize = (num + block_work_size() - 1) / block_work_size();
+    _hard_sigmoid_kernel<<<gridsize, blocksize>>>(input, output, num);
+}
+void hard_swish_kernel(float *input, float *output, size_t num) {
+
+    int blocksize = block_work_size();
+    int gridsize = (num + block_work_size() - 1) / block_work_size();
+    _hard_swish_kernel<<<gridsize, blocksize>>>(input, output, num);
+}
 void tanh_kernel(float *input, float *output, size_t num) {

     int blocksize = block_work_size();
diff --git a/test/kernels/cuda/test_cuda_unary.cc b/test/kernels/cuda/test_cuda_unary.cc
index 09a2255e..4a2e5e98 100644
--- a/test/kernels/cuda/test_cuda_unary.cc
+++ b/test/kernels/cuda/test_cuda_unary.cc
@@ -45,6 +45,8 @@ TEST(cuDNN_Unary, run) {
     testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3});
     testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3});
     testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3});
+    testUnary<HardSigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
+    testUnary<HardSwishObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
     testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3});
     testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3});
     testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3});
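
Note (not part of the patch): the two activations added by the CPU and CUDA kernels above are plain clamp expressions, HardSigmoid(x) = clip(0.2 * x + 0.5, 0, 1) and HardSwish(x) = x * clip(x / 6 + 0.5, 0, 1). A minimal standalone NumPy sketch of the same formulas, useful as a reference when checking kernel outputs (the helper names here are illustrative only):

# reference_hard_activations.py
import numpy as np

def hard_sigmoid(x):
    # same expression as _hard_sigmoid_kernel / NaiveHardSigmoid above
    return np.clip(0.2 * x + 0.5, 0.0, 1.0)

def hard_swish(x):
    # same expression as _hard_swish_kernel / NaiveHardSwish above
    return x * np.clip(x / 6.0 + 0.5, 0.0, 1.0)

if __name__ == "__main__":
    x = np.linspace(-4.0, 4.0, 9, dtype=np.float32)
    print(hard_sigmoid(x))  # saturates at 0 below -2.5 and at 1 above 2.5
    print(hard_swish(x))    # 0 below -3, identity above 3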