feat: support unary int8

This commit is contained in:
kilinchange 2023-12-18 15:02:44 +08:00
parent c63ed4326d
commit 03ed8c4de7
2 changed files with 15 additions and 2 deletions

View File

@ -74,8 +74,9 @@ class ReduceCudnnBase : public CudaKernelWithoutConfig {
cudnnReduceTensorDescriptor_t reduceDesc;
checkCudnnError(cudnnCreateReduceTensorDescriptor(&reduceDesc));
checkCudnnError(cudnnSetReduceTensorDescriptor(
reduceDesc, getReduceOp(), cudnnDataType, CUDNN_NOT_PROPAGATE_NAN,
CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES));
reduceDesc, getReduceOp(), CUDNN_DATA_FLOAT,
CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES,
CUDNN_32BIT_INDICES));
// get workspace
size_t workspaceSize = 0;

View File

@ -94,6 +94,14 @@ __global__ void _sqrt_kernel(half *input, half *output, size_t n) {
}
}
__global__ void _sqrt_kernel(int8_t *input, int8_t *output, size_t n) {
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
size_t stride = blockDim.x * gridDim.x;
for (size_t i = index; i < n; i += stride) {
output[i] = __fsqrt_rn(static_cast<float>(input[i]));
}
}
template <typename T>
__global__ void _gelu_kernel(T *input, T *output, size_t n) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
@ -258,6 +266,8 @@ void unary_kernel(const Operator &_op) {
sqrt_kernel<float>((float *)inputData, (float *)outputData, num);
} else if (_op->getDType() == DataType::Float16) {
sqrt_kernel<half>((half *)inputData, (half *)outputData, num);
} else if (_op->getDType() == DataType::Int8) {
sqrt_kernel<int8_t>((int8_t *)inputData, (int8_t *)outputData, num);
} else {
IT_TODO_HALT();
}
@ -272,6 +282,8 @@ void unary_kernel(const Operator &_op) {
neg_kernel<float>((float *)inputData, (float *)outputData, num);
} else if (_op->getDType() == DataType::Float16) {
neg_kernel<half>((half *)inputData, (half *)outputData, num);
} else if (_op->getDType() == DataType::Int8) {
neg_kernel<int8_t>((int8_t *)inputData, (int8_t *)outputData, num);
} else {
IT_TODO_HALT();
}