diff --git a/include/cuda/cuda_unary.h b/include/cuda/cuda_unary.h
index 31a39951..a2868205 100644
--- a/include/cuda/cuda_unary.h
+++ b/include/cuda/cuda_unary.h
@@ -3,48 +3,18 @@
 #include "operators/unary.h"
 
 namespace infini {
-void softmax_kernel(float *input, float *output, size_t num);
-void relu_kernel(float *input, float *output, size_t num);
-void sigmoid_kernel(float *input, float *output, size_t num);
-void tanh_kernel(float *input, float *output, size_t num);
-void abs_kernel(float *input, float *output, size_t num);
-void sqrt_kernel(float *input, float *output, size_t num);
-void neg_kernel(float *input, float *output, size_t num);
-void gelu_kernel(float *input, float *output, size_t num);
-void erf_kernel(float *input, float *output, size_t num);
-void hard_sigmoid_kernel(float *input, float *output, size_t num);
-void hard_swish_kernel(float *input, float *output, size_t num);
+template <class T> void softmax_kernel(T *input, T *output, size_t num);
+template <class T> void relu_kernel(T *input, T *output, size_t num);
+template <class T> void sigmoid_kernel(T *input, T *output, size_t num);
+template <class T> void tanh_kernel(T *input, T *output, size_t num);
+template <class T> void abs_kernel(T *input, T *output, size_t num);
+template <class T> void sqrt_kernel(T *input, T *output, size_t num);
+template <class T> void neg_kernel(T *input, T *output, size_t num);
+template <class T> void gelu_kernel(T *input, T *output, size_t num);
+template <class T> void erf_kernel(T *input, T *output, size_t num);
+template <class T> void hard_sigmoid_kernel(T *input, T *output, size_t num);
+template <class T> void hard_swish_kernel(T *input, T *output, size_t num);
 
-void unary_kernel(const Operator &_op) {
-    auto op = as<UnaryObj>(_op);
-    float *const inputData = (op->getInputs(0)->getRawDataPtr<float *>());
-    float *const outputData = (op->getOutput()->getRawDataPtr<float *>());
-
-    size_t num = op->getOutput()->size();
-    if (op->getOpType() == OpType::Softmax)
-        softmax_kernel(inputData, outputData, num);
-    else if (op->getOpType() == OpType::Relu)
-        relu_kernel(inputData, outputData, num);
-    else if (op->getOpType() == OpType::Sigmoid)
-        sigmoid_kernel(inputData, outputData, num);
-    else if (op->getOpType() == OpType::HardSigmoid)
-        hard_sigmoid_kernel(inputData, outputData, num);
-    else if (op->getOpType() == OpType::HardSwish)
-        hard_swish_kernel(inputData, outputData, num);
-    else if (op->getOpType() == OpType::Tanh)
-        tanh_kernel(inputData, outputData, num);
-    else if (op->getOpType() == OpType::Abs)
-        abs_kernel(inputData, outputData, num);
-    else if (op->getOpType() == OpType::Sqrt)
-        sqrt_kernel(inputData, outputData, num);
-    else if (op->getOpType() == OpType::Gelu)
-        gelu_kernel(inputData, outputData, num);
-    else if (op->getOpType() == OpType::Neg)
-        neg_kernel(inputData, outputData, num);
-    else if (op->getOpType() == OpType::Erf)
-        erf_kernel(inputData, outputData, num);
-    else
-        IT_TODO_HALT();
-}
+void unary_kernel(const Operator &_op);
 
 }; // namespace infini
diff --git a/src/kernels/cuda/unary.cc b/src/kernels/cuda/unary.cc
index 86cc1ded..7f6919b7 100644
--- a/src/kernels/cuda/unary.cc
+++ b/src/kernels/cuda/unary.cc
@@ -2,6 +2,7 @@
 #include "cuda/cuda_kernel_wihtout_config.h"
 #include "cuda/cuda_runtime.h"
 #include "cuda/cuda_unary.h"
+#include "cuda/cuda_utility.h"
 
 namespace infini {
 
@@ -33,17 +34,17 @@ class ActivationCudnn : public CudaKernelWithoutConfig {
         while (stride.size() < 4)
             stride.push_back(1);
 
+        auto cudnnDataType = cudnnDataTypeConvert(op->getDType());
+
         // get inputs
         checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc));
-        checkCudnnError(cudnnSetTensorNdDescriptor(inputDesc, CUDNN_DATA_FLOAT,
-                                                   dim.size(), dim.data(),
-                                                   stride.data()));
+        checkCudnnError(cudnnSetTensorNdDescriptor(
+            inputDesc, cudnnDataType, dim.size(), dim.data(), stride.data()));
 
         // get outputs
         checkCudnnError(cudnnCreateTensorDescriptor(&outputDesc));
-        checkCudnnError(cudnnSetTensorNdDescriptor(outputDesc, CUDNN_DATA_FLOAT,
-                                                   dim.size(), dim.data(),
-                                                   stride.data()));
+        checkCudnnError(cudnnSetTensorNdDescriptor(
+            outputDesc, cudnnDataType, dim.size(), dim.data(), stride.data()));
 
         // get op descriptor
         cudnnActivationDescriptor_t activationDesc;
@@ -86,16 +87,18 @@ class SoftmaxCudnn : public CudaKernelWithoutConfig {
         memcpy(dim_array + (4 - dim.size()), dim.data(),
                dim.size() * sizeof(int));
 
+        auto cudnnDataType = cudnnDataTypeConvert(op->getDType());
+
         // get inputs
         checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc));
         checkCudnnError(cudnnSetTensor4dDescriptor(
-            inputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, dim_array[0],
+            inputDesc, CUDNN_TENSOR_NCHW, cudnnDataType, dim_array[0],
             dim_array[1], dim_array[2], dim_array[3]));
 
         // get outputs
         checkCudnnError(cudnnCreateTensorDescriptor(&outputDesc));
         checkCudnnError(cudnnSetTensor4dDescriptor(
-            outputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, dim_array[0],
+            outputDesc, CUDNN_TENSOR_NCHW, cudnnDataType, dim_array[0],
             dim_array[1], dim_array[2], dim_array[3]));
 
         auto [alpha, beta] = getAlphBeta();
@@ -142,8 +145,7 @@
 REGISTER_KERNEL(Device::CUDA, OpType::Gelu, UnaryCuda, "Gelu_CUDA");
 REGISTER_KERNEL(Device::CUDA, OpType::Neg, UnaryCuda, "Neg_CUDA");
 REGISTER_KERNEL(Device::CUDA, OpType::Erf, UnaryCuda, "Erf_CUDA");
-// REGISTER_KERNEL(Device::CUDA, OpType::Softmax, UnaryCuda,
-//                 "Softmax_CUDA");
+// REGISTER_KERNEL(Device::CUDA, OpType::Softmax, UnaryCuda, "Softmax_CUDA");
 // REGISTER_KERNEL(Device::CUDA, OpType::Relu, UnaryCuda,
 //                 "Relu_CUDA");
 // REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, UnaryCuda,
diff --git a/src/kernels/cuda/unary.cu b/src/kernels/cuda/unary.cu
index 22e2e423..75c6ffdc 100644
--- a/src/kernels/cuda/unary.cu
+++ b/src/kernels/cuda/unary.cu
@@ -1,6 +1,7 @@
 #include "core/common.h"
 #include "core/constants.h"
 #include "cuda/cuda_common.h"
+#include "cuda/cuda_unary.h"
 #include <math.h>
 
 using infini::E_CONSTANT;
@@ -8,15 +9,16 @@
 constexpr unsigned int num_threads() { return 32 * 4; }
 constexpr int thread_work_size() { return 4; }
 constexpr int block_work_size() { return thread_work_size() * num_threads(); }
 
-__global__ void _softmax_kernel1(float *input, float *output, size_t n) {
+template <class T>
+__global__ void _softmax_kernel1(T *input, T *output, size_t n) {
     float sum = 0.0f;
     for (size_t i = 0; i < n; ++i) {
         sum += pow(E_CONSTANT, input[i]);
     }
     *output = sum;
 }
-
-__global__ void _softmax_kernel2(float *input, float *output, size_t n) {
+template <class T>
+__global__ void _softmax_kernel2(T *input, T *output, size_t n) {
     float sum = *output;
     size_t index = threadIdx.x + blockIdx.x * blockDim.x;
     size_t stride = blockDim.x * gridDim.x;
@@ -24,32 +26,32 @@ __global__ void _softmax_kernel2(float *input, float *output, size_t n) {
         output[i] = pow(E_CONSTANT, input[i]) / sum;
     }
 }
-
-__global__ void _relu_kernel(float *input, float *output, size_t n) {
+template <class T>
+__global__ void _relu_kernel(T *input, T *output, size_t n) {
     size_t index = threadIdx.x + blockIdx.x * blockDim.x;
     size_t stride = blockDim.x * gridDim.x;
     for (size_t i = index; i < n; i += stride) {
         output[i] = max(input[i], float(0));
     }
 }
-
-__global__ void _sigmoid_kernel(float *input, float *output, size_t n) {
+template <class T>
+__global__ void _sigmoid_kernel(T *input, T *output, size_t n) {
     size_t index = threadIdx.x + blockIdx.x * blockDim.x;
     size_t stride = blockDim.x * gridDim.x;
     for (size_t i = index; i < n; i += stride) {
         output[i] = 1 / (1 + pow(E_CONSTANT, -input[i]));
     }
 }
-
-__global__ void _hard_sigmoid_kernel(float *input, float *output, size_t n) {
+template <class T>
+__global__ void _hard_sigmoid_kernel(T *input, T *output, size_t n) {
     size_t index = threadIdx.x + blockIdx.x * blockDim.x;
     size_t stride = blockDim.x * gridDim.x;
     for (size_t i = index; i < n; i += stride) {
         output[i] = max(0.0f, min(1.0f, 0.2f * input[i] + 0.5f));
     }
 }
-
-__global__ void _hard_swish_kernel(float *input, float *output, size_t n) {
+template <class T>
+__global__ void _hard_swish_kernel(T *input, T *output, size_t n) {
     size_t index = threadIdx.x + blockIdx.x * blockDim.x;
     size_t stride = blockDim.x * gridDim.x;
     for (size_t i = index; i < n; i += stride) {
@@ -57,8 +59,8 @@ __global__ void _hard_swish_kernel(float *input, float *output, size_t n) {
             input[i] * max(0.f, min(1.f, (1.f / 6.f) * input[i] + 0.5f));
     }
 }
-
-__global__ void _tanh_kernel(float *input, float *output, size_t n) {
+template <class T>
+__global__ void _tanh_kernel(T *input, T *output, size_t n) {
     size_t index = threadIdx.x + blockIdx.x * blockDim.x;
     size_t stride = blockDim.x * gridDim.x;
     for (size_t i = index; i < n; i += stride) {
@@ -66,8 +68,8 @@ __global__ void _tanh_kernel(float *input, float *output, size_t n) {
             (pow(E_CONSTANT, input[i]) + pow(E_CONSTANT, -input[i]));
     }
 }
-
-__global__ void _abs_kernel(float *input, float *output, size_t n) {
+template <class T>
+__global__ void _abs_kernel(T *input, T *output, size_t n) {
     size_t index = threadIdx.x + blockIdx.x * blockDim.x;
     size_t stride = blockDim.x * gridDim.x;
     for (size_t i = index; i < n; i += stride) {
@@ -83,7 +85,16 @@ __global__ void _sqrt_kernel(float *input, float *output, size_t n) {
     }
 }
 
-__global__ void _gelu_kernel(float *input, float *output, size_t n) {
+__global__ void _sqrt_kernel(half *input, half *output, size_t n) {
+    size_t index = threadIdx.x + blockIdx.x * blockDim.x;
+    size_t stride = blockDim.x * gridDim.x;
+    for (size_t i = index; i < n; i += stride) {
+        output[i] = hsqrt(input[i]);
+    }
+}
+
+template <class T>
+__global__ void _gelu_kernel(T *input, T *output, size_t n) {
     int index = threadIdx.x + blockIdx.x * blockDim.x;
     int stride = blockDim.x * gridDim.x;
     for (int i = index; i < n; i += stride) {
@@ -91,8 +102,8 @@ __global__ void _gelu_kernel(float *input, float *output, size_t n) {
         float x = input[i];
         output[i] = 0.5 * x * (1 + erf(x / sqrt(2.0f)));
     }
 }
-
-__global__ void _erf_kernel(float *input, float *output, size_t n) {
+template <class T>
+__global__ void _erf_kernel(T *input, T *output, size_t n) {
     size_t index = threadIdx.x + blockIdx.x * blockDim.x;
     size_t stride = blockDim.x * gridDim.x;
     for (int i = index; i < n; i += stride) {
@@ -110,71 +121,158 @@ __global__ void _neg_kernel(T *input, T *output, size_t n) {
 }
 
 namespace infini {
-void softmax_kernel(float *input, float *output, size_t num) {
+template <class T> void softmax_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
-    _softmax_kernel1<<<1, 1>>>(input, output, num);
-    _softmax_kernel2<<<gridsize, blocksize>>>(input, output, num);
+    _softmax_kernel1<T><<<1, 1>>>(input, output, num);
+    _softmax_kernel2<T><<<gridsize, blocksize>>>(input, output, num);
 }
-void relu_kernel(float *input, float *output, size_t num) {
+template <class T> void relu_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
-    _relu_kernel<<<gridsize, blocksize>>>(input, output, num);
+    _relu_kernel<T><<<gridsize, blocksize>>>(input, output, num);
 }
-void sigmoid_kernel(float *input, float *output, size_t num) {
+template <class T> void sigmoid_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
-    _sigmoid_kernel<<<gridsize, blocksize>>>(input, output, num);
+    _sigmoid_kernel<T><<<gridsize, blocksize>>>(input, output, num);
 }
-void hard_sigmoid_kernel(float *input, float *output, size_t num) {
+template <class T>
+void hard_sigmoid_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
-    _hard_sigmoid_kernel<<<gridsize, blocksize>>>(input, output, num);
+    _hard_sigmoid_kernel<T><<<gridsize, blocksize>>>(input, output, num);
 }
-void hard_swish_kernel(float *input, float *output, size_t num) {
+template <class T> void hard_swish_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
-    _hard_swish_kernel<<<gridsize, blocksize>>>(input, output, num);
+    _hard_swish_kernel<T><<<gridsize, blocksize>>>(input, output, num);
 }
-void tanh_kernel(float *input, float *output, size_t num) {
+template <class T> void tanh_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
-    _tanh_kernel<<<gridsize, blocksize>>>(input, output, num);
+    _tanh_kernel<T><<<gridsize, blocksize>>>(input, output, num);
 }
-void abs_kernel(float *input, float *output, size_t num) {
+template <class T> void abs_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
-    _abs_kernel<<<gridsize, blocksize>>>(input, output, num);
+    _abs_kernel<T><<<gridsize, blocksize>>>(input, output, num);
 }
-void sqrt_kernel(float *input, float *output, size_t num) {
+template <class T> void sqrt_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
-    _sqrt_kernel<<<gridsize, blocksize>>>(input, output, num);
+    _sqrt_kernel<<<gridsize, blocksize>>>((T *)input, (T *)output, num);
 }
-void gelu_kernel(float *input, float *output, size_t num) {
+
+template <class T> void gelu_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
-    _gelu_kernel<<<gridsize, blocksize>>>(input, output, num);
+    _gelu_kernel<T><<<gridsize, blocksize>>>(input, output, num);
 }
-void erf_kernel(float *input, float *output, size_t num) {
+template <class T> void erf_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
-    _erf_kernel<<<gridsize, blocksize>>>(input, output, num);
+    _erf_kernel<T><<<gridsize, blocksize>>>(input, output, num);
 }
-void neg_kernel(float *input, float *output, size_t num) {
+template <class T> void neg_kernel(T *input, T *output, size_t num) {
     int blocksize = block_work_size();
     int gridsize = (num + block_work_size() - 1) / block_work_size();
-    _neg_kernel<<<gridsize, blocksize>>>(input, output, num);
+    _neg_kernel<T><<<gridsize, blocksize>>>(input, output, num);
 }
+
+void unary_kernel(const Operator &_op) {
+    auto op = as<UnaryObj>(_op);
+    void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
+    void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
+
+    size_t num = op->getOutput()->size();
+    if (op->getOpType() == OpType::Softmax) {
+        if (_op->getDType() == DataType::Float32) {
+            softmax_kernel((float *)inputData, (float *)outputData, num);
+        } else {
+            IT_TODO_HALT();
+        }
+    } else if (op->getOpType() == OpType::Relu) {
+        if (_op->getDType() == DataType::Float32) {
+            relu_kernel((float *)inputData, (float *)outputData, num);
+        } else {
+            IT_TODO_HALT();
+        }
+    } else if (op->getOpType() == OpType::Sigmoid) {
+        if (_op->getDType() == DataType::Float32) {
+            sigmoid_kernel((float *)inputData, (float *)outputData, num);
+        } else {
+            IT_TODO_HALT();
+        }
+    } else if (op->getOpType() == OpType::HardSigmoid) {
+        if (_op->getDType() == DataType::Float32) {
+            hard_sigmoid_kernel((float *)inputData, (float *)outputData,
+                                num);
+        } else {
+            IT_TODO_HALT();
+        }
+    } else if (op->getOpType() == OpType::HardSwish) {
+        if (_op->getDType() == DataType::Float32) {
+            hard_swish_kernel((float *)inputData, (float *)outputData,
+                              num);
+        } else {
+            IT_TODO_HALT();
+        }
+    } else if (op->getOpType() == OpType::Tanh) {
+        if (_op->getDType() == DataType::Float32) {
+            tanh_kernel((float *)inputData, (float *)outputData, num);
+        } else {
+            IT_TODO_HALT();
+        }
+    } else if (op->getOpType() == OpType::Abs) {
+        if (_op->getDType() == DataType::Float32) {
+            abs_kernel((float *)inputData, (float *)outputData, num);
+        } else {
+            IT_TODO_HALT();
+        }
+    } else if (op->getOpType() == OpType::Sqrt) {
+        if (_op->getDType() == DataType::Float32) {
+            sqrt_kernel((float *)inputData, (float *)outputData, num);
+        } else if (_op->getDType() == DataType::Float16) {
+            sqrt_kernel((half *)inputData, (half *)outputData, num);
+        } else {
+            IT_TODO_HALT();
+        }
+    } else if (op->getOpType() == OpType::Gelu) {
+        if (_op->getDType() == DataType::Float32) {
+            gelu_kernel((float *)inputData, (float *)outputData, num);
+        } else {
+            IT_TODO_HALT();
+        }
+    } else if (op->getOpType() == OpType::Neg) {
+        if (_op->getDType() == DataType::Float32) {
+            neg_kernel((float *)inputData, (float *)outputData, num);
+        } else if (_op->getDType() == DataType::Float16) {
+            neg_kernel((half *)inputData, (half *)outputData, num);
+        } else {
+            IT_TODO_HALT();
+        }
+    }
+
+    else if (op->getOpType() == OpType::Erf) {
+        if (_op->getDType() == DataType::Float32) {
+            erf_kernel((float *)inputData, (float *)outputData, num);
+        } else {
+            IT_TODO_HALT();
+        }
+    } else
+        IT_TODO_HALT();
+}
+
 }; // namespace infini
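Review note: the core pattern of this patch is a type-templated device kernel paired with a host-side switch on the runtime dtype that casts untyped buffers before launch. The following is a minimal standalone sketch of that pattern, not part of the patch; the names DType and neg_dispatch are illustrative stand-ins, not InfiniTensor APIs.

// Sketch only: mirrors the unary_kernel() dispatch above under assumed names.
#include <cuda_fp16.h>
#include <cstddef>

template <class T> __global__ void _neg(T *input, T *output, size_t n) {
    size_t index = threadIdx.x + blockIdx.x * blockDim.x;
    size_t stride = blockDim.x * gridDim.x;
    for (size_t i = index; i < n; i += stride) {
        output[i] = -input[i]; // half's unary minus requires sm_53 or newer
    }
}

enum class DType { Float32, Float16 }; // stand-in for infini::DataType

void neg_dispatch(void *input, void *output, size_t num, DType dtype) {
    int blocksize = 128;
    int gridsize = int((num + blocksize - 1) / blocksize);
    // Untyped pointers are cast once the runtime dtype is known; an
    // unsupported dtype would halt, as IT_TODO_HALT() does in the patch.
    if (dtype == DType::Float32)
        _neg<float><<<gridsize, blocksize>>>((float *)input,
                                             (float *)output, num);
    else if (dtype == DType::Float16)
        _neg<half><<<gridsize, blocksize>>>((half *)input,
                                            (half *)output, num);
}

One consequence of moving unary_kernel() out of the header: the template wrappers are defined and called in the same translation unit (unary.cu), so implicit instantiation covers the float and half cases and no explicit instantiations are needed.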