diff --git a/src/kernels/cuda/unary.cu b/src/kernels/cuda/unary.cu index 93a3cf6c..f4443564 100644 --- a/src/kernels/cuda/unary.cu +++ b/src/kernels/cuda/unary.cu @@ -315,6 +315,8 @@ void unary_kernel(const Operator &_op) { } else if (op->getOpType() == OpType::Silu) { if (_op->getDType() == DataType::Float32) { silu_kernel((float *)inputData, (float *)outputData, num); + } else if (_op->getDType() == DataType::Float16){ + silu_kernel((half *)inputData, (half *)outputData, num); } else { IT_TODO_HALT(); }