forked from jiuyuan/InfiniTensor
feat:support int8 for gather
This commit is contained in:
parent
db8c3eec15
commit
2af4c1276b
|
@ -21,10 +21,14 @@ class GatherCuda : public CudaKernelWithoutConfig {
|
|||
if (op->getDType() == DataType::Float32) {
|
||||
gather_kernel<float>((float *)inputData, (float *)outputData,
|
||||
metaData, op->getOutput()->size());
|
||||
} else if (op->getDType() == DataType::Float32) {
|
||||
} else if (op->getDType() == DataType::Float16) {
|
||||
gather_kernel<half>((half *)inputData, (half *)outputData, metaData,
|
||||
op->getOutput()->size());
|
||||
}
|
||||
else if (op->getDType() == DataType::Int8) {
|
||||
gather_kernel<int8_t>((int8_t *)inputData, (int8_t *)outputData, metaData,
|
||||
op->getOutput()->size());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in New Issue