Add activation operators and kernels

* add code for activation operation

* add code for activation operation on GPU

* add test code for activation operation

* add code for activation on GPU using cuDNN

* Chore: add constants.h and remove comments

Co-authored-by: wanghailu <wanghailu@qiyuanlab.com>
Co-authored-by: Liyan Zheng <liyan-zheng@outlook.com>
Hardy 2022-09-16 13:58:57 +08:00 committed by GitHub
parent 172d03d6f2
commit 6ac106cba4
9 changed files with 551 additions and 1 deletion

include/core/constants.h (new file, +5)

@@ -0,0 +1,5 @@
#pragma once
namespace infini {
constexpr double E_CONSTANT = 2.718281828459;
}
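
A side note on how E_CONSTANT is used: the CPU and CUDA kernels below compute exponentials as pow(E_CONSTANT, x), which matches exp(x) only up to the precision of the truncated literal (about 13 significant digits). A minimal standalone sketch of that equivalence, using only the standard <cmath> calls:

#include <cmath>
#include <cstdio>

int main() {
    constexpr double E_CONSTANT = 2.718281828459; // same literal as constants.h
    double x = 1.5;
    // pow(e, x) and exp(x) agree to roughly the precision of the constant.
    std::printf("%.12f\n%.12f\n", std::pow(E_CONSTANT, x), std::exp(x));
    return 0;
}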

(modified file)

@@ -32,6 +32,10 @@ enum class OpType {
BatchNorm = 200,
Softmax,
Activation,
Relu,
Sigmoid,
Tanh,
Abs,
Resize,
//
MemBound = 300,
@@ -75,6 +79,10 @@ class OpRegistry {
FOP(BatchNorm);
FOP(Softmax);
FOP(Activation);
FOP(Relu);
FOP(Sigmoid);
FOP(Tanh);
FOP(Abs);
//
FOP(MemBound);
default:
@@ -204,4 +212,4 @@ namespace std {
template <> struct hash<infini::OpPerfKey> {
size_t operator()(const infini::OpPerfKey &key) const { return key.hash; }
};
} // namespace std
} // namespace std
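
For context on the hunk above: FOP is presumably a small stringizing helper defined earlier in OpRegistry (its definition lies outside this hunk, likely inside OpRegistry::getOpName), mapping each OpType enumerator to a printable name. A hedged sketch of what such a helper would look like — an assumption, not the commit's actual code:

// Hypothetical FOP helper assumed by the switch shown above:
#define FOP(op)                                                                \
    case OpType::op:                                                           \
        return #op;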

include/cuda/cuda_unary.h (new file, +33)

@@ -0,0 +1,33 @@
#pragma once
#include "operators/unary.h"
namespace infini {
void softmax_kernel(float *input, float *output, int num);
void relu_kernel(float *input, float *output, int num);
void sigmoid_kernel(float *input, float *output, int num);
void tanh_kernel(float *input, float *output, int num);
void abs_kernel(float *input, float *output, int num);
void unary_kernel(const Operator &_op) {
auto op = as<UnaryObj>(_op);
float *const inputData = (op->getInputs(0)->getRawDataPtr<float *>());
float *const outputData = (op->getOutput()->getRawDataPtr<float *>());
auto dim = op->getInputs(0)->getDims();
int n = dim[0], c = dim[1], h = dim[2], w = dim[3];
if (op->getOpType() == OpType::Softmax)
softmax_kernel(inputData, outputData, n * c * h * w);
else if (op->getOpType() == OpType::Relu)
relu_kernel(inputData, outputData, n * c * h * w);
else if (op->getOpType() == OpType::Sigmoid)
sigmoid_kernel(inputData, outputData, n * c * h * w);
else if (op->getOpType() == OpType::Tanh)
tanh_kernel(inputData, outputData, n * c * h * w);
else if (op->getOpType() == OpType::Abs)
abs_kernel(inputData, outputData, n * c * h * w);
else
IT_TODO_HALT();
}
}; // namespace infini

include/operators/unary.h (new file, +31)

@@ -0,0 +1,31 @@
#pragma once
#include "core/operator.h"
namespace infini {
class UnaryObj : public OperatorObj {
public:
UnaryObj(OpType type, GraphObj *graph, Tensor input, Tensor output);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;
int numInputs() const override { return 1; }
int numOutputs() const override { return 1; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
#define DEFINE_UNARY_OBJ(prefix, type) \
class prefix##Obj : public UnaryObj { \
public: \
prefix##Obj(GraphObj *graph, Tensor input, Tensor output) \
: UnaryObj(type, graph, input, output) {} \
};
DEFINE_UNARY_OBJ(Relu, OpType::Relu)
DEFINE_UNARY_OBJ(Sigmoid, OpType::Sigmoid)
DEFINE_UNARY_OBJ(Tanh, OpType::Tanh)
DEFINE_UNARY_OBJ(Softmax, OpType::Softmax)
DEFINE_UNARY_OBJ(Abs, OpType::Abs)
}; // namespace infini
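
For reference, the DEFINE_UNARY_OBJ macro above just stamps out a thin subclass per operator; for example, DEFINE_UNARY_OBJ(Relu, OpType::Relu) expands to:

// Expansion of DEFINE_UNARY_OBJ(Relu, OpType::Relu):
class ReluObj : public UnaryObj {
  public:
    ReluObj(GraphObj *graph, Tensor input, Tensor output)
        : UnaryObj(OpType::Relu, graph, input, output) {}
};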

src/kernels/cpu/unary.cc (new file, +99)

@@ -0,0 +1,99 @@
#include "operators/unary.h"
#include "core/constants.h"
#include "core/kernel.h"
namespace infini {
template <typename T> class NativeUnary : public Kernel {
virtual T doCompute(T val) const = 0;
void compute(const Operator &_op, const PerfRecord &record,
const RuntimeObj *context) const override {
auto op = as<UnaryObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
T *outptr = op->getOutput()->getRawDataPtr<T *>();
auto outDim = op->getOutput()->getDims();
auto n = op->getOutput()->size();
for (size_t offset = 0; offset < n; offset++) {
outptr[offset] = doCompute(inptr[offset]);
}
}
void compute(const Operator &op, const RuntimeObj *context) const override {
compute(op, {}, context);
}
PerfRecord tune(const Operator &op,
const RuntimeObj *context) const override {
PerfRecord perfrcd(timeit([&]() { compute(op, context); }));
return perfrcd;
}
};
template <typename T> class NaiveSoftmax : public Kernel {
void compute(const Operator &_op, const PerfRecord &record,
const RuntimeObj *context) const override {
auto op = as<UnaryObj>(_op);
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
T *outptr = op->getOutput()->getRawDataPtr<T *>();
auto outDim = op->getOutput()->getDims();
auto n = op->getOutput()->size();
auto sum = T(0);
for (size_t offset = 0; offset < n; offset++) {
sum += pow(E_CONSTANT, inptr[offset]);
}
for (size_t offset = 0; offset < n; offset++) {
outptr[offset] = pow(E_CONSTANT, inptr[offset]) / sum;
}
}
void compute(const Operator &op, const RuntimeObj *context) const override {
compute(op, {}, context);
}
PerfRecord tune(const Operator &op,
const RuntimeObj *context) const override {
PerfRecord perfrcd(timeit([&]() { compute(op, context); }));
return perfrcd;
}
};
template <typename T> class NaiveRelu : public NativeUnary<T> {
T doCompute(T val) const override { return std::max(T(0), val); }
};
template <typename T> class NaiveSigmoid : public NativeUnary<T> {
T doCompute(T val) const override {
return 1 / (1 + pow(E_CONSTANT, -val));
}
};
template <typename T> class NaiveTanh : public NativeUnary<T> {
T doCompute(T val) const override {
return (pow(E_CONSTANT, val) - pow(E_CONSTANT, -val)) /
(pow(E_CONSTANT, val) + pow(E_CONSTANT, -val));
}
};
template <typename T> class NaiveAbs : public NativeUnary<T> {
T doCompute(T val) const override { return val < 0 ? -val : val; }
};
REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::UInt32,
NaiveRelu<uint32_t>, "reluNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::Float32, NaiveRelu<float>,
"reluNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::UInt32,
NaiveSigmoid<uint32_t>, "sigmoidNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::Float32,
NaiveSigmoid<float>, "sigmoidNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Tanh, DataType::UInt32,
NaiveTanh<uint32_t>, "tanhNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Tanh, DataType::Float32, NaiveTanh<float>,
"tanhNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::UInt32, NaiveAbs<uint32_t>,
"absNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::Float32, NaiveAbs<float>,
"absNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::UInt32,
NaiveSoftmax<uint32_t>, "softmaxNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::Float32,
NaiveSoftmax<float>, "softmaxNaive_CPU_float32");
}; // namespace infini
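
In math terms, the naive CPU kernels above implement the usual element-wise definitions (note that NaiveSoftmax normalizes over the entire flattened tensor rather than per row or per channel):

\mathrm{relu}(x) = \max(0, x), \qquad
\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}, \qquad
\tanh(x) = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}, \qquad
\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}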

src/kernels/cuda/unary.cc (new file, +195)

@@ -0,0 +1,195 @@
#include "operators/unary.h"
#include "core/kernel.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_unary.h"
namespace infini {
class UnaryCuda : public Kernel {
void compute(const Operator &_op, const PerfRecord &record,
const RuntimeObj *_context) const override {
unary_kernel(_op);
}
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
compute(_op, {}, _context);
}
// Premise: op must be idempotent, since tune() runs it multiple times.
PerfRecord tune(const Operator &_op,
const RuntimeObj *_context) const override {
PerfRecord ret;
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
ret.time = timeit([&]() { compute(_op, _context); },
[&]() { context->sync(); });
return ret;
}
};
class ActivationCudnn : public Kernel {
virtual cudnnActivationMode_t getOpType() const = 0;
virtual tuple<float, float> getAlphBeta() const { return {1.f, 0.f}; }
void compute(const Operator &_op, const PerfRecord &record,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
cudnnTensorDescriptor_t inputDesc, outputDesc;
auto dim = op->getInputs(0)->getDims();
if (dim.size() != 4)
IT_TODO_HALT();
int n = dim[0], c = dim[1], h = dim[2], w = dim[3];
// get inputs
checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(
inputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
// get outputs
checkCudnnError(cudnnCreateTensorDescriptor(&outputDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(
outputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
// get op descriptor
cudnnActivationDescriptor_t activationDesc;
checkCudnnError(cudnnCreateActivationDescriptor(&activationDesc));
checkCudnnError(cudnnSetActivationDescriptor(
activationDesc, getOpType(), CUDNN_NOT_PROPAGATE_NAN, 0.0));
auto [alpha, beta] = getAlphBeta();
cudnnStatus_t stat = cudnnActivationForward(
context->cudnnHandle(), activationDesc, &alpha, inputDesc,
inputData, &beta, outputDesc, outputData);
if (stat != CUDNN_STATUS_SUCCESS)
return;
// Destroying CUDA resources does not require a sync, but cuDNN does not
// state whether a sync is required before destroying its descriptors.
checkCudnnError(cudnnDestroyActivationDescriptor(activationDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(outputDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(inputDesc));
}
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
compute(_op, {}, _context);
}
// Premise: op must be idempotent, since tune() runs it multiple times.
PerfRecord tune(const Operator &_op,
const RuntimeObj *_context) const override {
PerfRecord ret;
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
ret.time = timeit([&]() { compute(_op, _context); },
[&]() { context->sync(); });
return ret;
}
};
class SoftmaxCudnn : public Kernel {
virtual cudnnSoftmaxAlgorithm_t getAlgorithmType() const = 0;
virtual cudnnSoftmaxMode_t getModeType() const = 0;
virtual tuple<float, float> getAlphBeta() const { return {1.f, 0.f}; }
void compute(const Operator &_op, const PerfRecord &record,
const RuntimeObj *_context) const override {
auto op = as<UnaryObj>(_op);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
cudnnTensorDescriptor_t inputDesc, outputDesc;
auto dim = op->getInputs(0)->getDims();
if (dim.size() != 4)
IT_TODO_HALT();
int n = dim[0], c = dim[1], h = dim[2], w = dim[3];
// get inputs
checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(
inputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
// get outputs
checkCudnnError(cudnnCreateTensorDescriptor(&outputDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(
outputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
auto [alpha, beta] = getAlphBeta();
cudnnStatus_t stat = cudnnSoftmaxForward(
context->cudnnHandle(), getAlgorithmType(), getModeType(), &alpha,
inputDesc, inputData, &beta, outputDesc, outputData);
if (stat != CUDNN_STATUS_SUCCESS)
return;
// Destroying CUDA resources does not require a sync, but cuDNN does not
// state whether a sync is required before destroying its descriptors.
checkCudnnError(cudnnDestroyTensorDescriptor(inputDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(outputDesc));
}
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
compute(_op, {}, _context);
}
// Premise: op must be idempotent, since tune() runs it multiple times.
PerfRecord tune(const Operator &_op,
const RuntimeObj *_context) const override {
PerfRecord ret;
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
ret.time = timeit([&]() { compute(_op, _context); },
[&]() { context->sync(); });
return ret;
}
};
class ReluCudnn : public ActivationCudnn {
cudnnActivationMode_t getOpType() const override {
return CUDNN_ACTIVATION_RELU;
}
};
class SigmoidCudnn : public ActivationCudnn {
cudnnActivationMode_t getOpType() const override {
return CUDNN_ACTIVATION_SIGMOID;
}
};
class TanhCudnn : public ActivationCudnn {
cudnnActivationMode_t getOpType() const override {
return CUDNN_ACTIVATION_TANH;
}
};
class NormalSoftmaxCudnn : public SoftmaxCudnn {
cudnnSoftmaxAlgorithm_t getAlgorithmType() const override {
return CUDNN_SOFTMAX_ACCURATE;
}
cudnnSoftmaxMode_t getModeType() const override {
return CUDNN_SOFTMAX_MODE_INSTANCE;
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32,
NormalSoftmaxCudnn, "Softmax_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Relu, DataType::Float32, ReluCudnn,
"Relu_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, DataType::Float32, SigmoidCudnn,
"Sigmoid_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Tanh, DataType::Float32, TanhCudnn,
"Tanh_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Abs, DataType::Float32, UnaryCuda,
"Abs_CUDA_Float32");
// REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, UnaryCuda,
// "Softmax_CUDA_Float32");
// REGISTER_KERNEL(Device::CUDA, OpType::Relu, DataType::Float32, UnaryCuda,
// "Relu_CUDA_Float32");
// REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, DataType::Float32, UnaryCuda,
// "Sigmoid_CUDA_Float32");
// REGISTER_KERNEL(Device::CUDA, OpType::Tanh, DataType::Float32, UnaryCuda,
// "Tanh_CUDA_Float32");
// REGISTER_KERNEL(Device::CUDA, OpType::Abs, DataType::Float32, UnaryCuda,
// "Abs_CUDA_Float32");
}; // namespace infini
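
A note on the {1.f, 0.f} pair returned by getAlphBeta above: cuDNN's forward calls blend their result with the existing output according to

y \leftarrow \alpha \cdot \mathrm{op}(x) + \beta \cdot y

so alpha = 1 and beta = 0 simply overwrite the output tensor, which is the behavior these kernels want.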

src/kernels/cuda/unary.cu (new file, +94)

@@ -0,0 +1,94 @@
#include "core/common.h"
#include "core/constants.h"
#include "cuda/cuda_common.h"
#include <math.h>
using infini::E_CONSTANT;
constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
__global__ void _softmax_kernel1(float *input, float *output, int n) {
float sum = 0.0f;
for (size_t i = 0; i < n; ++i) {
sum += pow(E_CONSTANT, input[i]);
}
*output = sum;
}
__global__ void _softmax_kernel2(float *input, float *output, int n) {
float sum = *output;
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (int i = index; i < n; i += stride) {
output[i] = pow(E_CONSTANT, input[i]) / sum;
}
}
__global__ void _relu_kernel(float *input, float *output, int n) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (int i = index; i < n; i += stride) {
output[i] = max(input[i], float(0));
}
}
__global__ void _sigmoid_kernel(float *input, float *output, int n) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (int i = index; i < n; i += stride) {
output[i] = 1 / (1 + pow(E_CONSTANT, -input[i]));
}
}
__global__ void _tanh_kernel(float *input, float *output, int n) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (int i = index; i < n; i += stride) {
output[i] = (pow(E_CONSTANT, input[i]) - pow(E_CONSTANT, -input[i])) /
(pow(E_CONSTANT, input[i]) + pow(E_CONSTANT, -input[i]));
}
}
__global__ void _abs_kernel(float *input, float *output, int n) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (int i = index; i < n; i += stride) {
output[i] = input[i] < 0 ? -input[i] : input[i];
}
}
namespace infini {
void softmax_kernel(float *input, float *output, int num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_softmax_kernel1<<<1, 1>>>(input, output, num);
_softmax_kernel2<<<gridsize, blocksize>>>(input, output, num); // <<<blocks, threads>>>
}
void relu_kernel(float *input, float *output, int num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_relu_kernel<<<gridsize, blocksize>>>(input, output, num);
}
void sigmoid_kernel(float *input, float *output, int num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_sigmoid_kernel<<<gridsize, blocksize>>>(input, output, num);
}
void tanh_kernel(float *input, float *output, int num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_tanh_kernel<<<gridsize, blocksize>>>(input, output, num);
}
void abs_kernel(float *input, float *output, int num) {
int blocksize = block_work_size();
int gridsize = (num + block_work_size() - 1) / block_work_size();
_abs_kernel<<<gridsize, blocksize>>>(input, output, num);
}
}; // namespace infini
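
A minimal host-side sketch of how one of these wrappers could be driven in isolation (illustrative only: the buffer size and synchronization here are assumptions, not part of the commit, which instead reaches unary_kernel through the UnaryCuda kernel class):

#include <cuda_runtime.h>
#include <vector>

namespace infini {
void relu_kernel(float *input, float *output, int num); // declared in cuda/cuda_unary.h
}

int main() {
    const int num = 1 << 10;
    std::vector<float> host(num, -1.0f);
    float *in = nullptr, *out = nullptr;
    cudaMalloc((void **)&in, num * sizeof(float));
    cudaMalloc((void **)&out, num * sizeof(float));
    cudaMemcpy(in, host.data(), num * sizeof(float), cudaMemcpyHostToDevice);
    infini::relu_kernel(in, out, num); // launches _relu_kernel on the default stream
    cudaDeviceSynchronize();           // wait for the kernel before copying back
    cudaMemcpy(host.data(), out, num * sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(in);
    cudaFree(out);
    return 0;
}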

src/operators/unary.cc (new file, +35)

@@ -0,0 +1,35 @@
#include "operators/unary.h"
namespace infini {
UnaryObj::UnaryObj(OpType type, GraphObj *graph, Tensor input, Tensor output)
: OperatorObj(type, {input}, {output}) {
IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>> UnaryObj::inferShape(const TensorVec &inputs) const {
const auto A = inputs[0];
return {{A->getDims()}};
}
std::string UnaryObj::toString() const {
std::ostringstream os;
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "input=" << inputs[0]->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
vector<int> UnaryObj::getWorkloadVector() const {
vector<int> ret{enum_to_underlying(type)};
const Shape shape = outputs[0]->getDims();
ret.insert(ret.end(), shape.begin(), shape.end());
return ret;
}
vector<int> UnaryObj::getOpAttrVector() const {
return {enum_to_underlying(type)};
}
}; // namespace infini

(new test file, +50)

@@ -0,0 +1,50 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/unary.h"
#include "test.h"
namespace infini {
template <class T>
void testUnary(const std::function<void(void *, size_t, DataType)> &generator,
const Shape &shape) {
// Runtime
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
auto cudaRuntime = make_ref<CudaRuntimeObj>();
// Build input data on CPU
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
inputCpu->dataMalloc();
inputCpu->setData(generator);
// GPU
Graph cudaGraph = make_ref<GraphObj>(cudaRuntime);
auto inputGpu = cudaGraph->cloneTensor(inputCpu);
auto gpuOp = cudaGraph->addOp<T>(inputGpu, nullptr);
cudaGraph->dataMalloc();
cudaRuntime->run(cudaGraph);
auto outputGpu = gpuOp->getOutput();
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
// CPU
Graph cpuGraph = make_ref<GraphObj>(cpuRuntime);
auto cpuOp = cpuGraph->addOp<T>(inputCpu, nullptr);
cpuGraph->dataMalloc();
cpuRuntime->run(cpuGraph);
auto outputCpu = cpuOp->getOutput();
// Check
EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu));
}
TEST(Unary, CuDNN) {
testUnary<ReluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testUnary<SoftmaxObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testUnary<AbsObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
}
} // namespace infini