- unary support fp16

This commit is contained in:
kilinchange 2023-12-13 17:05:17 +08:00
parent ee4ecd27e2
commit 5af7f1e753
3 changed files with 163 additions and 93 deletions

View File

@ -3,48 +3,18 @@
#include "operators/unary.h"
namespace infini {
void softmax_kernel(float *input, float *output, size_t num);
void relu_kernel(float *input, float *output, size_t num);
void sigmoid_kernel(float *input, float *output, size_t num);
void tanh_kernel(float *input, float *output, size_t num);
void abs_kernel(float *input, float *output, size_t num);
void sqrt_kernel(float *input, float *output, size_t num);
void neg_kernel(float *input, float *output, size_t num);
void gelu_kernel(float *input, float *output, size_t num);
void erf_kernel(float *input, float *output, size_t num);
void hard_sigmoid_kernel(float *input, float *output, size_t num);
void hard_swish_kernel(float *input, float *output, size_t num);
template <typename T> void softmax_kernel(T *input, T *output, size_t num);
template <typename T> void relu_kernel(T *input, T *output, size_t num);
template <typename T> void sigmoid_kernel(T *input, T *output, size_t num);
template <typename T> void tanh_kernel(T *input, T *output, size_t num);
template <typename T> void abs_kernel(T *input, T *output, size_t num);
template <typename T> void sqrt_kernel(T *input, T *output, size_t num);
template <typename T> void neg_kernel(T *input, T *output, size_t num);
template <typename T> void gelu_kernel(T *input, T *output, size_t num);
template <typename T> void erf_kernel(T *input, T *output, size_t num);
template <typename T> void hard_sigmoid_kernel(T *input, T *output, size_t num);
template <typename T> void hard_swish_kernel(T *input, T *output, size_t num);
void unary_kernel(const Operator &_op) {
auto op = as<UnaryObj>(_op);
float *const inputData = (op->getInputs(0)->getRawDataPtr<float *>());
float *const outputData = (op->getOutput()->getRawDataPtr<float *>());
size_t num = op->getOutput()->size();
if (op->getOpType() == OpType::Softmax)
softmax_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::Relu)
relu_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::Sigmoid)
sigmoid_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::HardSigmoid)
hard_sigmoid_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::HardSwish)
hard_swish_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::Tanh)
tanh_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::Abs)
abs_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::Sqrt)
sqrt_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::Gelu)
gelu_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::Neg)
neg_kernel(inputData, outputData, num);
else if (op->getOpType() == OpType::Erf)
erf_kernel(inputData, outputData, num);
else
IT_TODO_HALT();
}
void unary_kernel(const Operator &_op);
}; // namespace infini

View File

@ -2,6 +2,7 @@
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_unary.h"
#include "cuda/cuda_utility.h"
namespace infini {
@ -33,17 +34,17 @@ class ActivationCudnn : public CudaKernelWithoutConfig {
while (stride.size() < 4)
stride.push_back(1);
auto cudnnDataType = cudnnDataTypeConvert(op->getDType());
// get inputs
checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc));
checkCudnnError(cudnnSetTensorNdDescriptor(inputDesc, CUDNN_DATA_FLOAT,
dim.size(), dim.data(),
stride.data()));
checkCudnnError(cudnnSetTensorNdDescriptor(
inputDesc, cudnnDataType, dim.size(), dim.data(), stride.data()));
// get outputs
checkCudnnError(cudnnCreateTensorDescriptor(&outputDesc));
checkCudnnError(cudnnSetTensorNdDescriptor(outputDesc, CUDNN_DATA_FLOAT,
dim.size(), dim.data(),
stride.data()));
checkCudnnError(cudnnSetTensorNdDescriptor(
outputDesc, cudnnDataType, dim.size(), dim.data(), stride.data()));
// get op descriptor
cudnnActivationDescriptor_t activationDesc;
@ -86,16 +87,18 @@ class SoftmaxCudnn : public CudaKernelWithoutConfig {
memcpy(dim_array + (4 - dim.size()), dim.data(),
dim.size() * sizeof(int));
auto cudnnDataType = cudnnDataTypeConvert(op->getDType());
// get inputs
checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(
inputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, dim_array[0],
inputDesc, CUDNN_TENSOR_NCHW, cudnnDataType, dim_array[0],
dim_array[1], dim_array[2], dim_array[3]));
// get outputs
checkCudnnError(cudnnCreateTensorDescriptor(&outputDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(
outputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, dim_array[0],
outputDesc, CUDNN_TENSOR_NCHW, cudnnDataType, dim_array[0],
dim_array[1], dim_array[2], dim_array[3]));
auto [alpha, beta] = getAlphBeta();
@ -142,8 +145,7 @@ REGISTER_KERNEL(Device::CUDA, OpType::Gelu, UnaryCuda, "Gelu_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Neg, UnaryCuda, "Neg_CUDA");
REGISTER_KERNEL(Device::CUDA, OpType::Erf, UnaryCuda, "Erf_CUDA");
// REGISTER_KERNEL(Device::CUDA, OpType::Softmax, UnaryCuda,
// "Softmax_CUDA");
// REGISTER_KERNEL(Device::CUDA, OpType::Softmax, UnaryCuda, "Softmax_CUDA");
// REGISTER_KERNEL(Device::CUDA, OpType::Relu, UnaryCuda,
// "Relu_CUDA");
// REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, UnaryCuda,

View File

@ -1,6 +1,7 @@
#include "core/common.h"
#include "core/constants.h"
#include "cuda/cuda_common.h"
#include "cuda/cuda_unary.h"
#include <math.h>
using infini::E_CONSTANT;
@ -8,15 +9,16 @@ constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
__global__ void _softmax_kernel1(float *input, float *output, size_t n) {
template <typename T>
__global__ void _softmax_kernel1(T *input, T *output, size_t n) {
float sum = 0.0f;
for (size_t i = 0; i < n; ++i) {
sum += pow(E_CONSTANT, input[i]);
}
*output = sum;
}
__global__ void _softmax_kernel2(float *input, float *output, size_t n) {
template <typename T>
__global__ void _softmax_kernel2(T *input, T *output, size_t n) {
float sum = *output;
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
size_t stride = blockDim.x * gridDim.x;
@ -24,32 +26,32 @@ __global__ void _softmax_kernel2(float *input, float *output, size_t n) {
output[i] = pow(E_CONSTANT, input[i]) / sum;
}
}
__global__ void _relu_kernel(float *input, float *output, size_t n) {
template <typename T>
__global__ void _relu_kernel(T *input, T *output, size_t n) {
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
size_t stride = blockDim.x * gridDim.x;
for (size_t i = index; i < n; i += stride) {
output[i] = max(input[i], float(0));
}
}
__global__ void _sigmoid_kernel(float *input, float *output, size_t n) {
template <typename T>
__global__ void _sigmoid_kernel(T *input, T *output, size_t n) {
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
size_t stride = blockDim.x * gridDim.x;
for (size_t i = index; i < n; i += stride) {
output[i] = 1 / (1 + pow(E_CONSTANT, -input[i]));
}
}
__global__ void _hard_sigmoid_kernel(float *input, float *output, size_t n) {
template <typename T>
__global__ void _hard_sigmoid_kernel(T *input, T *output, size_t n) {
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
size_t stride = blockDim.x * gridDim.x;
for (size_t i = index; i < n; i += stride) {
output[i] = max(0.0f, min(1.0f, 0.2f * input[i] + 0.5f));
}
}
__global__ void _hard_swish_kernel(float *input, float *output, size_t n) {
template <typename T>
__global__ void _hard_swish_kernel(T *input, T *output, size_t n) {
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
size_t stride = blockDim.x * gridDim.x;
for (size_t i = index; i < n; i += stride) {
@ -57,8 +59,8 @@ __global__ void _hard_swish_kernel(float *input, float *output, size_t n) {
input[i] * max(0.f, min(1.f, (1.f / 6.f) * input[i] + 0.5f));
}
}
__global__ void _tanh_kernel(float *input, float *output, size_t n) {
template <typename T>
__global__ void _tanh_kernel(T *input, T *output, size_t n) {
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
size_t stride = blockDim.x * gridDim.x;
for (size_t i = index; i < n; i += stride) {
@ -66,8 +68,8 @@ __global__ void _tanh_kernel(float *input, float *output, size_t n) {
(pow(E_CONSTANT, input[i]) + pow(E_CONSTANT, -input[i]));
}
}
__global__ void _abs_kernel(float *input, float *output, size_t n) {
template <typename T>
__global__ void _abs_kernel(T *input, T *output, size_t n) {
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
size_t stride = blockDim.x * gridDim.x;
for (size_t i = index; i < n; i += stride) {
@ -83,7 +85,16 @@ __global__ void _sqrt_kernel(float *input, float *output, size_t n) {
}
}
__global__ void _gelu_kernel(float *input, float *output, size_t n) {
__global__ void _sqrt_kernel(half *input, half *output, size_t n) {
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
size_t stride = blockDim.x * gridDim.x;
for (size_t i = index; i < n; i += stride) {
output[i] = hsqrt(input[i]);
}
}
template <typename T>
__global__ void _gelu_kernel(T *input, T *output, size_t n) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (int i = index; i < n; i += stride) {
@ -91,8 +102,8 @@ __global__ void _gelu_kernel(float *input, float *output, size_t n) {
output[i] = 0.5 * x * (1 + erf(x / sqrt(2.0f)));
}
}
__global__ void _erf_kernel(float *input, float *output, size_t n) {
template <typename T>
__global__ void _erf_kernel(T *input, T *output, size_t n) {
size_t index = threadIdx.x + blockIdx.x * blockDim.x;
size_t stride = blockDim.x * gridDim.x;
for (int i = index; i < n; i += stride) {
@ -110,71 +121,158 @@ __global__ void _neg_kernel(T *input, T *output, size_t n) {
}
namespace infini {
void softmax_kernel(float *input, float *output, size_t num) {
template <typename T> void softmax_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_softmax_kernel1<<<1, 1>>>(input, output, num);
_softmax_kernel2<<<gridsize, blocksize>>>(input, output, num);
_softmax_kernel1<T><<<1, 1>>>(input, output, num);
_softmax_kernel2<T><<<gridsize, blocksize>>>(input, output, num);
}
void relu_kernel(float *input, float *output, size_t num) {
template <typename T> void relu_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_relu_kernel<<<gridsize, blocksize>>>(input, output, num);
_relu_kernel<T><<<gridsize, blocksize>>>(input, output, num);
}
void sigmoid_kernel(float *input, float *output, size_t num) {
template <typename T> void sigmoid_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_sigmoid_kernel<<<gridsize, blocksize>>>(input, output, num);
_sigmoid_kernel<T><<<gridsize, blocksize>>>(input, output, num);
}
void hard_sigmoid_kernel(float *input, float *output, size_t num) {
template <typename T>
void hard_sigmoid_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_hard_sigmoid_kernel<<<gridsize, blocksize>>>(input, output, num);
_hard_sigmoid_kernel<T><<<gridsize, blocksize>>>(input, output, num);
}
void hard_swish_kernel(float *input, float *output, size_t num) {
template <typename T> void hard_swish_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_hard_swish_kernel<<<gridsize, blocksize>>>(input, output, num);
_hard_swish_kernel<T><<<gridsize, blocksize>>>(input, output, num);
}
void tanh_kernel(float *input, float *output, size_t num) {
template <typename T> void tanh_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_tanh_kernel<<<gridsize, blocksize>>>(input, output, num);
_tanh_kernel<T><<<gridsize, blocksize>>>(input, output, num);
}
void abs_kernel(float *input, float *output, size_t num) {
template <typename T> void abs_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_abs_kernel<<<gridsize, blocksize>>>(input, output, num);
_abs_kernel<T><<<gridsize, blocksize>>>(input, output, num);
}
void sqrt_kernel(float *input, float *output, size_t num) {
template <typename T> void sqrt_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_sqrt_kernel<<<gridsize, blocksize>>>(input, output, num);
_sqrt_kernel<<<gridsize, blocksize>>>((T *)input, (T *)output, num);
}
void gelu_kernel(float *input, float *output, size_t num) {
template <typename T> void gelu_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_gelu_kernel<<<gridsize, blocksize>>>(input, output, num);
_gelu_kernel<T><<<gridsize, blocksize>>>(input, output, num);
}
void erf_kernel(float *input, float *output, size_t num) {
template <typename T> void erf_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_erf_kernel<<<gridsize, blocksize>>>(input, output, num);
_erf_kernel<T><<<gridsize, blocksize>>>(input, output, num);
}
void neg_kernel(float *input, float *output, size_t num) {
template <typename T> void neg_kernel(T *input, T *output, size_t num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_neg_kernel<<<gridsize, blocksize>>>(input, output, num);
_neg_kernel<T><<<gridsize, blocksize>>>(input, output, num);
}
void unary_kernel(const Operator &_op) {
auto op = as<UnaryObj>(_op);
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
size_t num = op->getOutput()->size();
if (op->getOpType() == OpType::Softmax) {
if (_op->getDType() == DataType::Float32) {
softmax_kernel<float>((float *)inputData, (float *)outputData, num);
} else {
IT_TODO_HALT();
}
} else if (op->getOpType() == OpType::Relu) {
if (_op->getDType() == DataType::Float32) {
relu_kernel<float>((float *)inputData, (float *)outputData, num);
} else {
IT_TODO_HALT();
}
} else if (op->getOpType() == OpType::Sigmoid) {
if (_op->getDType() == DataType::Float32) {
sigmoid_kernel<float>((float *)inputData, (float *)outputData, num);
} else {
IT_TODO_HALT();
}
} else if (op->getOpType() == OpType::HardSigmoid) {
if (_op->getDType() == DataType::Float32) {
hard_sigmoid_kernel<float>((float *)inputData, (float *)outputData,
num);
} else {
IT_TODO_HALT();
}
} else if (op->getOpType() == OpType::HardSwish) {
if (_op->getDType() == DataType::Float32) {
hard_swish_kernel<float>((float *)inputData, (float *)outputData,
num);
} else {
IT_TODO_HALT();
}
} else if (op->getOpType() == OpType::Tanh) {
if (_op->getDType() == DataType::Float32) {
tanh_kernel<float>((float *)inputData, (float *)outputData, num);
} else {
IT_TODO_HALT();
}
} else if (op->getOpType() == OpType::Abs) {
if (_op->getDType() == DataType::Float32) {
abs_kernel<float>((float *)inputData, (float *)outputData, num);
} else {
IT_TODO_HALT();
}
} else if (op->getOpType() == OpType::Sqrt) {
if (_op->getDType() == DataType::Float32) {
sqrt_kernel<float>((float *)inputData, (float *)outputData, num);
} else if (_op->getDType() == DataType::Float16) {
sqrt_kernel<half>((half *)inputData, (half *)outputData, num);
} else {
IT_TODO_HALT();
}
} else if (op->getOpType() == OpType::Gelu) {
if (_op->getDType() == DataType::Float32) {
gelu_kernel<float>((float *)inputData, (float *)outputData, num);
} else {
IT_TODO_HALT();
}
} else if (op->getOpType() == OpType::Neg) {
if (_op->getDType() == DataType::Float32) {
neg_kernel<float>((float *)inputData, (float *)outputData, num);
} else if (_op->getDType() == DataType::Float16) {
neg_kernel<half>((half *)inputData, (half *)outputData, num);
} else {
IT_TODO_HALT();
}
}
else if (op->getOpType() == OpType::Erf) {
if (_op->getDType() == DataType::Float32) {
erf_kernel<float>((float *)inputData, (float *)outputData, num);
} else {
IT_TODO_HALT();
}
} else
IT_TODO_HALT();
}
}; // namespace infini