forked from jiuyuan/InfiniTensor
Add HardSigmoid and HardSwish (#156)
* Add HardSigmoid and HardSwish * fix format
This commit is contained in:
parent
1151101fb9
commit
ed3034f878
|
@ -47,6 +47,8 @@ class GraphHandlerObj {
|
||||||
Tensor relu(Tensor x, Tensor y);
|
Tensor relu(Tensor x, Tensor y);
|
||||||
Tensor gelu(Tensor x, Tensor y);
|
Tensor gelu(Tensor x, Tensor y);
|
||||||
Tensor sigmoid(Tensor x, Tensor y);
|
Tensor sigmoid(Tensor x, Tensor y);
|
||||||
|
Tensor hardSigmoid(Tensor x, Tensor y);
|
||||||
|
Tensor hardSwish(Tensor x, Tensor y);
|
||||||
Tensor tanh(Tensor x, Tensor y);
|
Tensor tanh(Tensor x, Tensor y);
|
||||||
Tensor erf(Tensor x, Tensor y);
|
Tensor erf(Tensor x, Tensor y);
|
||||||
Tensor softmax(Tensor x, Tensor y, int axis);
|
Tensor softmax(Tensor x, Tensor y, int axis);
|
||||||
|
|
|
@ -12,6 +12,8 @@ void sqrt_kernel(float *input, float *output, size_t num);
|
||||||
void neg_kernel(float *input, float *output, size_t num);
|
void neg_kernel(float *input, float *output, size_t num);
|
||||||
void gelu_kernel(float *input, float *output, size_t num);
|
void gelu_kernel(float *input, float *output, size_t num);
|
||||||
void erf_kernel(float *input, float *output, size_t num);
|
void erf_kernel(float *input, float *output, size_t num);
|
||||||
|
void hard_sigmoid_kernel(float *input, float *output, size_t num);
|
||||||
|
void hard_swish_kernel(float *input, float *output, size_t num);
|
||||||
|
|
||||||
void unary_kernel(const Operator &_op) {
|
void unary_kernel(const Operator &_op) {
|
||||||
auto op = as<UnaryObj>(_op);
|
auto op = as<UnaryObj>(_op);
|
||||||
|
@ -25,6 +27,10 @@ void unary_kernel(const Operator &_op) {
|
||||||
relu_kernel(inputData, outputData, num);
|
relu_kernel(inputData, outputData, num);
|
||||||
else if (op->getOpType() == OpType::Sigmoid)
|
else if (op->getOpType() == OpType::Sigmoid)
|
||||||
sigmoid_kernel(inputData, outputData, num);
|
sigmoid_kernel(inputData, outputData, num);
|
||||||
|
else if (op->getOpType() == OpType::HardSigmoid)
|
||||||
|
hard_sigmoid_kernel(inputData, outputData, num);
|
||||||
|
else if (op->getOpType() == OpType::HardSwish)
|
||||||
|
hard_swish_kernel(inputData, outputData, num);
|
||||||
else if (op->getOpType() == OpType::Tanh)
|
else if (op->getOpType() == OpType::Tanh)
|
||||||
tanh_kernel(inputData, outputData, num);
|
tanh_kernel(inputData, outputData, num);
|
||||||
else if (op->getOpType() == OpType::Abs)
|
else if (op->getOpType() == OpType::Abs)
|
||||||
|
|
|
@ -263,6 +263,8 @@ DEFINE_UNARY_OBJ(Sigmoid, OpType::Sigmoid)
|
||||||
DEFINE_UNARY_OBJ(Tanh, OpType::Tanh)
|
DEFINE_UNARY_OBJ(Tanh, OpType::Tanh)
|
||||||
// DEFINE_UNARY_OBJ(Softmax, OpType::Softmax)
|
// DEFINE_UNARY_OBJ(Softmax, OpType::Softmax)
|
||||||
DEFINE_UNARY_OBJ(Abs, OpType::Abs)
|
DEFINE_UNARY_OBJ(Abs, OpType::Abs)
|
||||||
|
DEFINE_UNARY_OBJ(HardSigmoid, OpType::HardSigmoid)
|
||||||
|
DEFINE_UNARY_OBJ(HardSwish, OpType::HardSwish)
|
||||||
|
|
||||||
DEFINE_UNARY_OBJ(Sin, OpType::Sin)
|
DEFINE_UNARY_OBJ(Sin, OpType::Sin)
|
||||||
DEFINE_UNARY_OBJ(Cos, OpType::Cos)
|
DEFINE_UNARY_OBJ(Cos, OpType::Cos)
|
||||||
|
|
|
@ -395,6 +395,16 @@ class OnnxStub:
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors.get(node.output[0]),
|
tensors.get(node.output[0]),
|
||||||
)
|
)
|
||||||
|
elif node.op_type == "HardSigmoid":
|
||||||
|
tensors[node.output[0]] = self.handler.hardSigmoid(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors.get(node.output[0]),
|
||||||
|
)
|
||||||
|
elif node.op_type == "HardSwish":
|
||||||
|
tensors[node.output[0]] = self.handler.hardSwish(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors.get(node.output[0]),
|
||||||
|
)
|
||||||
elif node.op_type == "Tanh":
|
elif node.op_type == "Tanh":
|
||||||
tensors[node.output[0]] = self.handler.tanh(
|
tensors[node.output[0]] = self.handler.tanh(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
|
@ -931,6 +941,8 @@ class OnnxStub:
|
||||||
backend.OpTypeId.Relu,
|
backend.OpTypeId.Relu,
|
||||||
backend.OpTypeId.Gelu,
|
backend.OpTypeId.Gelu,
|
||||||
backend.OpTypeId.Sigmoid,
|
backend.OpTypeId.Sigmoid,
|
||||||
|
backend.OpTypeId.HardSigmoid,
|
||||||
|
backend.OpTypeId.HardSwish,
|
||||||
backend.OpTypeId.Tanh,
|
backend.OpTypeId.Tanh,
|
||||||
backend.OpTypeId.Softmax,
|
backend.OpTypeId.Softmax,
|
||||||
backend.OpTypeId.Abs,
|
backend.OpTypeId.Abs,
|
||||||
|
|
|
@ -239,6 +239,18 @@ class TestStringMethods(unittest.TestCase):
|
||||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
tanh = make_node("Tanh", ["x"], ["y"], name="tanh")
|
tanh = make_node("Tanh", ["x"], ["y"], name="tanh")
|
||||||
make_and_import_model(make_graph([tanh], "tanh", [x], [y]))
|
make_and_import_model(make_graph([tanh], "tanh", [x], [y]))
|
||||||
|
|
||||||
|
def test_hard_sigmoid(self):
|
||||||
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
hardSigmoid = make_node("HardSigmoid", ["x"], ["y"], name="hardSigmoid")
|
||||||
|
make_and_import_model(make_graph([hardSigmoid], "hardSigmoid", [x], [y]))
|
||||||
|
|
||||||
|
def test_hard_swish(self):
|
||||||
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
hardSwish = make_node("HardSwish", ["x"], ["y"], name="hardSwish")
|
||||||
|
make_and_import_model(make_graph([hardSwish], "hardSwish", [x], [y]))
|
||||||
|
|
||||||
def test_softmax(self):
|
def test_softmax(self):
|
||||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||||
|
|
|
@ -158,6 +158,8 @@ DEFINE_UNARY_METHOD(relu, Relu)
|
||||||
DEFINE_UNARY_METHOD(gelu, Gelu)
|
DEFINE_UNARY_METHOD(gelu, Gelu)
|
||||||
DEFINE_UNARY_METHOD(sigmoid, Sigmoid)
|
DEFINE_UNARY_METHOD(sigmoid, Sigmoid)
|
||||||
DEFINE_UNARY_METHOD(tanh, Tanh)
|
DEFINE_UNARY_METHOD(tanh, Tanh)
|
||||||
|
DEFINE_UNARY_METHOD(hardSigmoid, HardSigmoid)
|
||||||
|
DEFINE_UNARY_METHOD(hardSwish, HardSwish)
|
||||||
DEFINE_UNARY_METHOD(abs, Abs)
|
DEFINE_UNARY_METHOD(abs, Abs)
|
||||||
DEFINE_UNARY_METHOD(sqrt, Sqrt)
|
DEFINE_UNARY_METHOD(sqrt, Sqrt)
|
||||||
DEFINE_UNARY_METHOD(neg, Neg)
|
DEFINE_UNARY_METHOD(neg, Neg)
|
||||||
|
|
|
@ -96,6 +96,8 @@ void export_values(py::module &m) {
|
||||||
.VALUE(OpType, PRelu)
|
.VALUE(OpType, PRelu)
|
||||||
.VALUE(OpType, Sigmoid)
|
.VALUE(OpType, Sigmoid)
|
||||||
.VALUE(OpType, Tanh)
|
.VALUE(OpType, Tanh)
|
||||||
|
.VALUE(OpType, HardSigmoid)
|
||||||
|
.VALUE(OpType, HardSwish)
|
||||||
.VALUE(OpType, Abs)
|
.VALUE(OpType, Abs)
|
||||||
.VALUE(OpType, Resize)
|
.VALUE(OpType, Resize)
|
||||||
.VALUE(OpType, Dropout)
|
.VALUE(OpType, Dropout)
|
||||||
|
@ -444,6 +446,8 @@ void init_graph_builder(py::module &m) {
|
||||||
.def("gelu", &Handler::gelu, policy::move)
|
.def("gelu", &Handler::gelu, policy::move)
|
||||||
.def("sigmoid", &Handler::sigmoid, policy::move)
|
.def("sigmoid", &Handler::sigmoid, policy::move)
|
||||||
.def("tanh", &Handler::tanh, policy::move)
|
.def("tanh", &Handler::tanh, policy::move)
|
||||||
|
.def("hardSigmoid", &Handler::hardSigmoid, policy::move)
|
||||||
|
.def("hardSwish", &Handler::hardSwish, policy::move)
|
||||||
.def("softmax", &Handler::softmax, policy::move)
|
.def("softmax", &Handler::softmax, policy::move)
|
||||||
.def("abs", &Handler::abs, policy::move)
|
.def("abs", &Handler::abs, policy::move)
|
||||||
.def("sqrt", &Handler::sqrt, policy::move)
|
.def("sqrt", &Handler::sqrt, policy::move)
|
||||||
|
|
|
@ -46,6 +46,17 @@ template <typename T> class NaiveSigmoid : public NativeUnary<T> {
|
||||||
return 1 / (1 + pow(E_CONSTANT, -val));
|
return 1 / (1 + pow(E_CONSTANT, -val));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
template <typename T> class NaiveHardSigmoid : public NativeUnary<T> {
|
||||||
|
T doCompute(T val) const override {
|
||||||
|
return std::max(T(0), std::min(T(1), T(0.2) * val + T(0.5)));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
template <typename T> class NaiveHardSwish : public NativeUnary<T> {
|
||||||
|
T doCompute(T val) const override {
|
||||||
|
return val *
|
||||||
|
std::max(T(0), std::min(T(1), val * T(1.0 / 6.0) + T(0.5)));
|
||||||
|
}
|
||||||
|
};
|
||||||
template <typename T> class NaiveTanh : public NativeUnary<T> {
|
template <typename T> class NaiveTanh : public NativeUnary<T> {
|
||||||
T doCompute(T val) const override {
|
T doCompute(T val) const override {
|
||||||
return (pow(E_CONSTANT, val) - pow(E_CONSTANT, -val)) /
|
return (pow(E_CONSTANT, val) - pow(E_CONSTANT, -val)) /
|
||||||
|
@ -105,6 +116,10 @@ REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::UInt32,
|
||||||
NaiveSigmoid<uint32_t>, "sigmoidNaive_CPU_uint32");
|
NaiveSigmoid<uint32_t>, "sigmoidNaive_CPU_uint32");
|
||||||
REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::Float32,
|
REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::Float32,
|
||||||
NaiveSigmoid<float>, "sigmoidNaive_CPU_float32");
|
NaiveSigmoid<float>, "sigmoidNaive_CPU_float32");
|
||||||
|
REGISTER_KERNEL(Device::CPU, OpType::HardSigmoid, DataType::Float32,
|
||||||
|
NaiveHardSigmoid<float>, "hardSigmoidNaive_CPU_float32");
|
||||||
|
REGISTER_KERNEL(Device::CPU, OpType::HardSwish, DataType::Float32,
|
||||||
|
NaiveHardSwish<float>, "hardSwishNaive_CPU_float32");
|
||||||
REGISTER_KERNEL(Device::CPU, OpType::Tanh, DataType::UInt32,
|
REGISTER_KERNEL(Device::CPU, OpType::Tanh, DataType::UInt32,
|
||||||
NaiveTanh<uint32_t>, "tanhNaive_CPU_uint32");
|
NaiveTanh<uint32_t>, "tanhNaive_CPU_uint32");
|
||||||
REGISTER_KERNEL(Device::CPU, OpType::Tanh, DataType::Float32, NaiveTanh<float>,
|
REGISTER_KERNEL(Device::CPU, OpType::Tanh, DataType::Float32, NaiveTanh<float>,
|
||||||
|
|
|
@ -134,6 +134,10 @@ REGISTER_KERNEL(Device::CUDA, OpType::Relu, DataType::Float32, ReluCudnn,
|
||||||
"Relu_CUDA_Float32");
|
"Relu_CUDA_Float32");
|
||||||
REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, DataType::Float32, SigmoidCudnn,
|
REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, DataType::Float32, SigmoidCudnn,
|
||||||
"Sigmoid_CUDA_Float32");
|
"Sigmoid_CUDA_Float32");
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::HardSigmoid, DataType::Float32, UnaryCuda,
|
||||||
|
"Hard_Sigmoid_CUDA_Float32");
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::HardSwish, DataType::Float32, UnaryCuda,
|
||||||
|
"Hard_Swish_CUDA_Float32");
|
||||||
REGISTER_KERNEL(Device::CUDA, OpType::Tanh, DataType::Float32, TanhCudnn,
|
REGISTER_KERNEL(Device::CUDA, OpType::Tanh, DataType::Float32, TanhCudnn,
|
||||||
"Tanh_CUDA_Float32");
|
"Tanh_CUDA_Float32");
|
||||||
REGISTER_KERNEL(Device::CUDA, OpType::Abs, DataType::Float32, UnaryCuda,
|
REGISTER_KERNEL(Device::CUDA, OpType::Abs, DataType::Float32, UnaryCuda,
|
||||||
|
|
|
@ -41,6 +41,23 @@ __global__ void _sigmoid_kernel(float *input, float *output, size_t n) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__global__ void _hard_sigmoid_kernel(float *input, float *output, size_t n) {
|
||||||
|
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
size_t stride = blockDim.x * gridDim.x;
|
||||||
|
for (size_t i = index; i < n; i += stride) {
|
||||||
|
output[i] = max(0.0f, min(1.0f, 0.2f * input[i] + 0.5f));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void _hard_swish_kernel(float *input, float *output, size_t n) {
|
||||||
|
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
size_t stride = blockDim.x * gridDim.x;
|
||||||
|
for (size_t i = index; i < n; i += stride) {
|
||||||
|
output[i] =
|
||||||
|
input[i] * max(0.f, min(1.f, (1.f / 6.f) * input[i] + 0.5f));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
__global__ void _tanh_kernel(float *input, float *output, size_t n) {
|
__global__ void _tanh_kernel(float *input, float *output, size_t n) {
|
||||||
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
|
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
size_t stride = blockDim.x * gridDim.x;
|
size_t stride = blockDim.x * gridDim.x;
|
||||||
|
@ -112,6 +129,18 @@ void sigmoid_kernel(float *input, float *output, size_t num) {
|
||||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||||
_sigmoid_kernel<<<gridsize, blocksize>>>(input, output, num);
|
_sigmoid_kernel<<<gridsize, blocksize>>>(input, output, num);
|
||||||
}
|
}
|
||||||
|
void hard_sigmoid_kernel(float *input, float *output, size_t num) {
|
||||||
|
|
||||||
|
int blocksize = block_work_size();
|
||||||
|
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||||
|
_hard_sigmoid_kernel<<<gridsize, blocksize>>>(input, output, num);
|
||||||
|
}
|
||||||
|
void hard_swish_kernel(float *input, float *output, size_t num) {
|
||||||
|
|
||||||
|
int blocksize = block_work_size();
|
||||||
|
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||||
|
_hard_swish_kernel<<<gridsize, blocksize>>>(input, output, num);
|
||||||
|
}
|
||||||
void tanh_kernel(float *input, float *output, size_t num) {
|
void tanh_kernel(float *input, float *output, size_t num) {
|
||||||
|
|
||||||
int blocksize = block_work_size();
|
int blocksize = block_work_size();
|
||||||
|
|
|
@ -45,6 +45,8 @@ TEST(cuDNN_Unary, run) {
|
||||||
testUnary<AbsObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
testUnary<AbsObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||||
testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||||
testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||||
|
testUnary<HardSigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||||
|
testUnary<HardSwishObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||||
testUnary<SqrtObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
testUnary<SqrtObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||||
testUnary<NegObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
testUnary<NegObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||||
testUnary<ErfObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
testUnary<ErfObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||||
|
|
Loading…
Reference in New Issue