feat:support int8 for gather

This commit is contained in:
OdinaryWord 2023-12-14 13:28:41 +08:00
parent db8c3eec15
commit 2af4c1276b
1 changed files with 5 additions and 1 deletions

View File

@ -21,10 +21,14 @@ class GatherCuda : public CudaKernelWithoutConfig {
if (op->getDType() == DataType::Float32) {
gather_kernel<float>((float *)inputData, (float *)outputData,
metaData, op->getOutput()->size());
} else if (op->getDType() == DataType::Float32) {
} else if (op->getDType() == DataType::Float16) {
gather_kernel<half>((half *)inputData, (half *)outputData, metaData,
op->getOutput()->size());
}
else if (op->getDType() == DataType::Int8) {
gather_kernel<int8_t>((int8_t *)inputData, (int8_t *)outputData, metaData,
op->getOutput()->size());
}
}
};