From 03ed8c4de772a29760c46d885f328d52822a20e0 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Mon, 18 Dec 2023 15:02:44 +0800 Subject: [PATCH] feat: support unary int8 --- src/kernels/cuda/reduce.cc | 5 +++-- src/kernels/cuda/unary.cu | 12 ++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/kernels/cuda/reduce.cc b/src/kernels/cuda/reduce.cc index 531c09d0..d0c9a549 100644 --- a/src/kernels/cuda/reduce.cc +++ b/src/kernels/cuda/reduce.cc @@ -74,8 +74,9 @@ class ReduceCudnnBase : public CudaKernelWithoutConfig { cudnnReduceTensorDescriptor_t reduceDesc; checkCudnnError(cudnnCreateReduceTensorDescriptor(&reduceDesc)); checkCudnnError(cudnnSetReduceTensorDescriptor( - reduceDesc, getReduceOp(), cudnnDataType, CUDNN_NOT_PROPAGATE_NAN, - CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES)); + reduceDesc, getReduceOp(), CUDNN_DATA_FLOAT, + CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES, + CUDNN_32BIT_INDICES)); // get workspace size_t workspaceSize = 0; diff --git a/src/kernels/cuda/unary.cu b/src/kernels/cuda/unary.cu index afd7f02a..37a90765 100644 --- a/src/kernels/cuda/unary.cu +++ b/src/kernels/cuda/unary.cu @@ -94,6 +94,14 @@ __global__ void _sqrt_kernel(half *input, half *output, size_t n) { } } +__global__ void _sqrt_kernel(int8_t *input, int8_t *output, size_t n) { + size_t index = threadIdx.x + blockIdx.x * blockDim.x; + size_t stride = blockDim.x * gridDim.x; + for (size_t i = index; i < n; i += stride) { + output[i] = __fsqrt_rn(static_cast(input[i])); + } +} + template __global__ void _gelu_kernel(T *input, T *output, size_t n) { int index = threadIdx.x + blockIdx.x * blockDim.x; @@ -258,6 +266,8 @@ void unary_kernel(const Operator &_op) { sqrt_kernel((float *)inputData, (float *)outputData, num); } else if (_op->getDType() == DataType::Float16) { sqrt_kernel((half *)inputData, (half *)outputData, num); + } else if (_op->getDType() == DataType::Int8) { + sqrt_kernel((int8_t *)inputData, (int8_t *)outputData, num); } else { IT_TODO_HALT(); } @@ -272,6 +282,8 @@ void unary_kernel(const Operator &_op) { neg_kernel((float *)inputData, (float *)outputData, num); } else if (_op->getDType() == DataType::Float16) { neg_kernel((half *)inputData, (half *)outputData, num); + } else if (_op->getDType() == DataType::Int8) { + neg_kernel((int8_t *)inputData, (int8_t *)outputData, num); } else { IT_TODO_HALT(); }