add fp16 support to silu cuda op

xiaonans 2024-02-19 11:39:21 +08:00
parent 936797b960
commit 0f1c04d864
1 changed file with 2 additions and 0 deletions


@@ -315,6 +315,8 @@ void unary_kernel(const Operator &_op) {
     } else if (op->getOpType() == OpType::Silu) {
         if (_op->getDType() == DataType::Float32) {
             silu_kernel<float>((float *)inputData, (float *)outputData, num);
+        } else if (_op->getDType() == DataType::Float16) {
+            silu_kernel<half>((half *)inputData, (half *)outputData, num);
         } else {
             IT_TODO_HALT();
         }
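The definition of silu_kernel is not part of this diff; what follows is a minimal sketch of how a templated SiLU kernel could support both instantiations above. The kernel body, launcher, and launch configuration are assumptions, not the repository's actual code. A common choice for fp16 is to upcast to float for the math and downcast on store, which cuda_fp16.h's conversion operators make straightforward:

#include <cuda_fp16.h>

// Hypothetical device kernel for SiLU: y = x * sigmoid(x) = x / (1 + e^-x).
// Math is done in float even for half inputs to limit precision loss;
// the static_casts use __half's float conversion operators.
template <typename T>
__global__ void _silu_kernel(T *input, T *output, size_t num) {
    size_t i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < num) {
        float x = static_cast<float>(input[i]);
        output[i] = static_cast<T>(x / (1.0f + expf(-x)));
    }
}

// Hypothetical host-side launcher matching the call sites in the diff.
template <typename T>
void silu_kernel(T *input, T *output, size_t num) {
    int blockSize = 256;
    int gridSize = (num + blockSize - 1) / blockSize;
    _silu_kernel<T><<<gridSize, blockSize>>>(input, output, num);
}

// Explicit instantiations for the two dtypes dispatched above.
template void silu_kernel<float>(float *, float *, size_t);
template void silu_kernel<half>(half *, half *, size_t);

Under this sketch, adding fp16 support needs no new kernel code at all: the template already covers half, so the dispatch branch added in this commit is the only change required.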