forked from jiuyuan/InfiniTensor
feat: support unary int8
This commit is contained in:
parent
c63ed4326d
commit
03ed8c4de7
|
@ -74,8 +74,9 @@ class ReduceCudnnBase : public CudaKernelWithoutConfig {
|
|||
cudnnReduceTensorDescriptor_t reduceDesc;
|
||||
checkCudnnError(cudnnCreateReduceTensorDescriptor(&reduceDesc));
|
||||
checkCudnnError(cudnnSetReduceTensorDescriptor(
|
||||
reduceDesc, getReduceOp(), cudnnDataType, CUDNN_NOT_PROPAGATE_NAN,
|
||||
CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES));
|
||||
reduceDesc, getReduceOp(), CUDNN_DATA_FLOAT,
|
||||
CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES,
|
||||
CUDNN_32BIT_INDICES));
|
||||
|
||||
// get workspace
|
||||
size_t workspaceSize = 0;
|
||||
|
|
|
@ -94,6 +94,14 @@ __global__ void _sqrt_kernel(half *input, half *output, size_t n) {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Element-wise square root for int8 data.
 *
 * Each element is converted to float, passed through __fsqrt_rn, and the
 * result is converted back to int8_t (the fractional part is discarded by the
 * float -> integer conversion).
 *
 * Uses a grid-stride loop, so any 1-D launch configuration covers all n
 * elements.
 *
 * @param input  source elements; indices [0, n) are read
 * @param output destination elements; indices [0, n) are written
 * @param n      number of elements to process
 *
 * NOTE(review): for negative inputs the square root is NaN, and converting
 * NaN to an integer type is not defined by C++ -- confirm whether callers
 * guarantee non-negative data or whether an explicit policy is needed here.
 */
__global__ void _sqrt_kernel(int8_t *input, int8_t *output, size_t n) {
    // blockIdx/blockDim/gridDim components are 32-bit unsigned; widen to
    // size_t *before* multiplying so the products cannot wrap for large
    // launch configurations.
    const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
    const size_t index =
        static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    for (size_t i = index; i < n; i += stride) {
        const float root = __fsqrt_rn(static_cast<float>(input[i]));
        // Explicit narrowing: same truncating conversion the implicit
        // assignment performed, but visible to the reader and to -Wconversion.
        output[i] = static_cast<int8_t>(root);
    }
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void _gelu_kernel(T *input, T *output, size_t n) {
|
||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
|
@ -258,6 +266,8 @@ void unary_kernel(const Operator &_op) {
|
|||
sqrt_kernel<float>((float *)inputData, (float *)outputData, num);
|
||||
} else if (_op->getDType() == DataType::Float16) {
|
||||
sqrt_kernel<half>((half *)inputData, (half *)outputData, num);
|
||||
} else if (_op->getDType() == DataType::Int8) {
|
||||
sqrt_kernel<int8_t>((int8_t *)inputData, (int8_t *)outputData, num);
|
||||
} else {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
|
@ -272,6 +282,8 @@ void unary_kernel(const Operator &_op) {
|
|||
neg_kernel<float>((float *)inputData, (float *)outputData, num);
|
||||
} else if (_op->getDType() == DataType::Float16) {
|
||||
neg_kernel<half>((half *)inputData, (half *)outputData, num);
|
||||
} else if (_op->getDType() == DataType::Int8) {
|
||||
neg_kernel<int8_t>((int8_t *)inputData, (int8_t *)outputData, num);
|
||||
} else {
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue