forked from jiuyuan/InfiniTensor
Add activation operators and kernels
* add code for activation operation
* add code for activation operation on GPU
* add test code for activation operation
* add code for activation on GPU, using cuDNN
* Chore: add constants.h and remove comments

Co-authored-by: wanghailu <wanghailu@qiyuanlab.com>
Co-authored-by: Liyan Zheng <liyan-zheng@outlook.com>
parent 172d03d6f2
commit 6ac106cba4
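For context, a minimal sketch of how the new operators are driven through the graph API, mirroring the test added at the end of this commit (the runtime, generator, and shape values are taken from that test, not fixed by the API):

    // Sketch only; error handling omitted. All calls appear in the test below.
    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
    auto cudaRuntime = make_ref<CudaRuntimeObj>();

    // Host-side input tensor filled by a data generator.
    Tensor inputCpu =
        make_ref<TensorObj>(Shape{1, 2, 2, 3}, DataType::Float32, cpuRuntime);
    inputCpu->dataMalloc();
    inputCpu->setData(IncrementalGenerator());

    // One-op GPU graph; passing nullptr lets the graph infer the output.
    Graph g = make_ref<GraphObj>(cudaRuntime);
    auto op = g->addOp<ReluObj>(g->cloneTensor(inputCpu), nullptr);
    g->dataMalloc();
    cudaRuntime->run(g); // dispatches the registered Relu_CUDA_Float32 kernel
    auto outputCpu = op->getOutput()->clone(cpuRuntime); // copy back to host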
@@ -0,0 +1,5 @@
#pragma once

namespace infini {
constexpr double E_CONSTANT = 2.718281828459;
}
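E_CONSTANT is Euler's number e. The naive kernels in this commit evaluate e^x as pow(E_CONSTANT, x) (std::exp would be the more direct choice); for reference, the element-wise functions they implement are:

    \mathrm{relu}(x) = \max(0, x), \quad
    \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}, \quad
    \tanh(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}, \quad
    \mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}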
@@ -32,6 +32,10 @@ enum class OpType {
    BatchNorm = 200,
    Softmax,
    Activation,
    Relu,
    Sigmoid,
    Tanh,
    Abs,
    Resize,
    //
    MemBound = 300,
@@ -75,6 +79,10 @@ class OpRegistry {
        FOP(BatchNorm);
        FOP(Softmax);
        FOP(Activation);
        FOP(Relu);
        FOP(Sigmoid);
        FOP(Tanh);
        FOP(Abs);
        //
        FOP(MemBound);
    default:
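The FOP macro itself is outside this hunk. Judging from the default: label that follows, it is presumably a switch-case helper mapping each OpType value to its printable name; a hypothetical reconstruction, not part of this diff:

    // Hypothetical; the real definition lives elsewhere in the header.
    #define FOP(op)                                                            \
        case OpType::op:                                                       \
            return #op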
@@ -0,0 +1,33 @@
#pragma once

#include "operators/unary.h"

namespace infini {
void softmax_kernel(float *input, float *output, int num);
void relu_kernel(float *input, float *output, int num);
void sigmoid_kernel(float *input, float *output, int num);
void tanh_kernel(float *input, float *output, int num);
void abs_kernel(float *input, float *output, int num);

void unary_kernel(const Operator &_op) {
    auto op = as<UnaryObj>(_op);
    float *const inputData = (op->getInputs(0)->getRawDataPtr<float *>());
    float *const outputData = (op->getOutput()->getRawDataPtr<float *>());

    auto dim = op->getInputs(0)->getDims();
    int n = dim[0], c = dim[1], h = dim[2], w = dim[3];
    if (op->getOpType() == OpType::Softmax)
        softmax_kernel(inputData, outputData, n * c * h * w);
    else if (op->getOpType() == OpType::Relu)
        relu_kernel(inputData, outputData, n * c * h * w);
    else if (op->getOpType() == OpType::Sigmoid)
        sigmoid_kernel(inputData, outputData, n * c * h * w);
    else if (op->getOpType() == OpType::Tanh)
        tanh_kernel(inputData, outputData, n * c * h * w);
    else if (op->getOpType() == OpType::Abs)
        abs_kernel(inputData, outputData, n * c * h * w);
    else
        IT_TODO_HALT();
}

}; // namespace infini
@@ -0,0 +1,31 @@
#pragma once
#include "core/operator.h"

namespace infini {
class UnaryObj : public OperatorObj {
  public:
    UnaryObj(OpType type, GraphObj *graph, Tensor input, Tensor output);
    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

    std::string toString() const override;
    int numInputs() const override { return 1; }
    int numOutputs() const override { return 1; }

  private:
    vector<int> getWorkloadVector() const override;
    vector<int> getOpAttrVector() const override;
};

#define DEFINE_UNARY_OBJ(prefix, type)                                         \
    class prefix##Obj : public UnaryObj {                                      \
      public:                                                                  \
        prefix##Obj(GraphObj *graph, Tensor input, Tensor output)              \
            : UnaryObj(type, graph, input, output) {}                          \
    };

DEFINE_UNARY_OBJ(Relu, OpType::Relu)
DEFINE_UNARY_OBJ(Sigmoid, OpType::Sigmoid)
DEFINE_UNARY_OBJ(Tanh, OpType::Tanh)
DEFINE_UNARY_OBJ(Softmax, OpType::Softmax)
DEFINE_UNARY_OBJ(Abs, OpType::Abs)
}; // namespace infini
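Since DEFINE_UNARY_OBJ is pure token pasting, each invocation produces a thin constructor-only wrapper; for example, DEFINE_UNARY_OBJ(Relu, OpType::Relu) expands to:

    class ReluObj : public UnaryObj {
      public:
        ReluObj(GraphObj *graph, Tensor input, Tensor output)
            : UnaryObj(OpType::Relu, graph, input, output) {}
    };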
@@ -0,0 +1,99 @@
#include "operators/unary.h"
#include "core/constants.h"
#include "core/kernel.h"

namespace infini {
template <typename T> class NativeUnary : public Kernel {
    virtual T doCompute(T val) const = 0;
    void compute(const Operator &_op, const PerfRecord &record,
                 const RuntimeObj *context) const override {
        auto op = as<UnaryObj>(_op);
        T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
        T *outptr = op->getOutput()->getRawDataPtr<T *>();

        auto outDim = op->getOutput()->getDims();
        auto n = op->getOutput()->size();
        for (size_t offset = 0; offset < n; offset++) {
            outptr[offset] = doCompute(inptr[offset]);
        }
    }

    void compute(const Operator &op, const RuntimeObj *context) const override {
        compute(op, {}, context);
    }

    PerfRecord tune(const Operator &op,
                    const RuntimeObj *context) const override {
        PerfRecord perfrcd(timeit([&]() { compute(op, context); }));
        return perfrcd;
    }
};

template <typename T> class NaiveSoftmax : public Kernel {
    void compute(const Operator &_op, const PerfRecord &record,
                 const RuntimeObj *context) const override {
        auto op = as<UnaryObj>(_op);
        T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
        T *outptr = op->getOutput()->getRawDataPtr<T *>();

        auto outDim = op->getOutput()->getDims();
        auto n = op->getOutput()->size();
        auto sum = T(0);
        for (size_t offset = 0; offset < n; offset++) {
            sum += pow(E_CONSTANT, inptr[offset]);
        }
        for (size_t offset = 0; offset < n; offset++) {
            outptr[offset] = pow(E_CONSTANT, inptr[offset]) / sum;
        }
    }

    void compute(const Operator &op, const RuntimeObj *context) const override {
        compute(op, {}, context);
    }

    PerfRecord tune(const Operator &op,
                    const RuntimeObj *context) const override {
        PerfRecord perfrcd(timeit([&]() { compute(op, context); }));
        return perfrcd;
    }
};

template <typename T> class NaiveRelu : public NativeUnary<T> {
    T doCompute(T val) const override { return std::max(T(0), val); }
};
template <typename T> class NaiveSigmoid : public NativeUnary<T> {
    T doCompute(T val) const override {
        return 1 / (1 + pow(E_CONSTANT, -val));
    }
};
template <typename T> class NaiveTanh : public NativeUnary<T> {
    T doCompute(T val) const override {
        return (pow(E_CONSTANT, val) - pow(E_CONSTANT, -val)) /
               (pow(E_CONSTANT, val) + pow(E_CONSTANT, -val));
    }
};
template <typename T> class NaiveAbs : public NativeUnary<T> {
    T doCompute(T val) const override { return val < 0 ? -val : val; }
};

REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::UInt32,
                NaiveRelu<uint32_t>, "reluNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Relu, DataType::Float32, NaiveRelu<float>,
                "reluNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::UInt32,
                NaiveSigmoid<uint32_t>, "sigmoidNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Sigmoid, DataType::Float32,
                NaiveSigmoid<float>, "sigmoidNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Tanh, DataType::UInt32,
                NaiveTanh<uint32_t>, "tanhNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Tanh, DataType::Float32, NaiveTanh<float>,
                "tanhNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::UInt32, NaiveAbs<uint32_t>,
                "absNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::Float32, NaiveAbs<float>,
                "absNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::UInt32,
                NaiveSoftmax<uint32_t>, "softmaxNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::Float32,
                NaiveSoftmax<float>, "softmaxNaive_CPU_float32");
}; // namespace infini
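As a sanity check on the two-pass NaiveSoftmax loop, a standalone sketch with hypothetical inputs (plain C++, independent of the framework types):

    #include <cmath>
    #include <cstdio>

    // Same algorithm as NaiveSoftmax<float>: one pass for the sum of
    // exponentials over the whole tensor, one pass to normalize.
    int main() {
        float in[3] = {0.f, 1.f, 2.f}, out[3];
        float sum = 0.f;
        for (int i = 0; i < 3; i++)
            sum += std::exp(in[i]); // 1 + e + e^2 ≈ 11.107
        for (int i = 0; i < 3; i++)
            out[i] = std::exp(in[i]) / sum; // ≈ {0.090, 0.245, 0.665}
        for (int i = 0; i < 3; i++)
            std::printf("%f\n", out[i]);
        return 0;
    }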
@@ -0,0 +1,195 @@
#include "operators/unary.h"
#include "core/kernel.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_unary.h"

namespace infini {

class UnaryCuda : public Kernel {
    void compute(const Operator &_op, const PerfRecord &record,
                 const RuntimeObj *_context) const override {
        unary_kernel(_op);
    }

    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        compute(_op, {}, _context);
    }
    // Premise: op is idempotent since it is called multiple times.
    PerfRecord tune(const Operator &_op,
                    const RuntimeObj *_context) const override {
        PerfRecord ret;
        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
        ret.time = timeit([&]() { compute(_op, _context); },
                          [&]() { context->sync(); });
        return ret;
    }
};

class ActivationCudnn : public Kernel {
    virtual cudnnActivationMode_t getOpType() const = 0;
    virtual tuple<float, float> getAlphBeta() const { return {1.f, 0.f}; }
    void compute(const Operator &_op, const PerfRecord &record,
                 const RuntimeObj *_context) const override {
        auto op = as<UnaryObj>(_op);
        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);

        void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
        void *const outputData = (op->getOutput()->getRawDataPtr<void *>());

        cudnnTensorDescriptor_t inputDesc, outputDesc;
        auto dim = op->getInputs(0)->getDims();
        if (dim.size() != 4)
            IT_TODO_HALT();
        int n = dim[0], c = dim[1], h = dim[2], w = dim[3];

        // get inputs
        checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc));
        checkCudnnError(cudnnSetTensor4dDescriptor(
            inputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));

        // get outputs
        checkCudnnError(cudnnCreateTensorDescriptor(&outputDesc));
        checkCudnnError(cudnnSetTensor4dDescriptor(
            outputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));

        // get op descriptor
        cudnnActivationDescriptor_t activationDesc;
        checkCudnnError(cudnnCreateActivationDescriptor(&activationDesc));
        checkCudnnError(cudnnSetActivationDescriptor(
            activationDesc, getOpType(), CUDNN_NOT_PROPAGATE_NAN, 0.0));

        auto [alpha, beta] = getAlphBeta();
        cudnnStatus_t stat = cudnnActivationForward(
            context->cudnnHandle(), activationDesc, &alpha, inputDesc,
            inputData, &beta, outputDesc, outputData);
        if (stat != CUDNN_STATUS_SUCCESS)
            return;

        // Destroying descriptors in CUDA does not require a sync. But cuDNN
        // does not state whether a sync is required before destroy calls.
        checkCudnnError(cudnnDestroyActivationDescriptor(activationDesc));
        checkCudnnError(cudnnDestroyTensorDescriptor(outputDesc));
        checkCudnnError(cudnnDestroyTensorDescriptor(inputDesc));
    }

    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        compute(_op, {}, _context);
    }
    // Premise: op is idempotent since it is called multiple times.
    PerfRecord tune(const Operator &_op,
                    const RuntimeObj *_context) const override {
        PerfRecord ret;
        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
        ret.time = timeit([&]() { compute(_op, _context); },
                          [&]() { context->sync(); });
        return ret;
    }
};

class SoftmaxCudnn : public Kernel {
    virtual cudnnSoftmaxAlgorithm_t getAlgorithmType() const = 0;
    virtual cudnnSoftmaxMode_t getModeType() const = 0;
    virtual tuple<float, float> getAlphBeta() const { return {1.f, 0.f}; }
    void compute(const Operator &_op, const PerfRecord &record,
                 const RuntimeObj *_context) const override {
        auto op = as<UnaryObj>(_op);
        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);

        void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
        void *const outputData = (op->getOutput()->getRawDataPtr<void *>());

        cudnnTensorDescriptor_t inputDesc, outputDesc;
        auto dim = op->getInputs(0)->getDims();
        if (dim.size() != 4)
            IT_TODO_HALT();
        int n = dim[0], c = dim[1], h = dim[2], w = dim[3];

        // get inputs
        checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc));
        checkCudnnError(cudnnSetTensor4dDescriptor(
            inputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));

        // get outputs
        checkCudnnError(cudnnCreateTensorDescriptor(&outputDesc));
        checkCudnnError(cudnnSetTensor4dDescriptor(
            outputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));

        auto [alpha, beta] = getAlphBeta();
        cudnnStatus_t stat = cudnnSoftmaxForward(
            context->cudnnHandle(), getAlgorithmType(), getModeType(), &alpha,
            inputDesc, inputData, &beta, outputDesc, outputData);
        if (stat != CUDNN_STATUS_SUCCESS)
            return;

        // Destroying descriptors in CUDA does not require a sync. But cuDNN
        // does not state whether a sync is required before destroy calls.
        checkCudnnError(cudnnDestroyTensorDescriptor(inputDesc));
        checkCudnnError(cudnnDestroyTensorDescriptor(outputDesc));
    }

    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        compute(_op, {}, _context);
    }
    // Premise: op is idempotent since it is called multiple times.
    PerfRecord tune(const Operator &_op,
                    const RuntimeObj *_context) const override {
        PerfRecord ret;
        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
        ret.time = timeit([&]() { compute(_op, _context); },
                          [&]() { context->sync(); });
        return ret;
    }
};

class ReluCudnn : public ActivationCudnn {
    cudnnActivationMode_t getOpType() const override {
        return CUDNN_ACTIVATION_RELU;
    }
};

class SigmoidCudnn : public ActivationCudnn {
    cudnnActivationMode_t getOpType() const override {
        return CUDNN_ACTIVATION_SIGMOID;
    }
};

class TanhCudnn : public ActivationCudnn {
    cudnnActivationMode_t getOpType() const override {
        return CUDNN_ACTIVATION_TANH;
    }
};

class NormalSoftmaxCudnn : public SoftmaxCudnn {
    cudnnSoftmaxAlgorithm_t getAlgorithmType() const override {
        return CUDNN_SOFTMAX_ACCURATE;
    }
    cudnnSoftmaxMode_t getModeType() const override {
        return CUDNN_SOFTMAX_MODE_INSTANCE;
    }
};

REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32,
                NormalSoftmaxCudnn, "Softmax_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Relu, DataType::Float32, ReluCudnn,
                "Relu_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, DataType::Float32, SigmoidCudnn,
                "Sigmoid_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Tanh, DataType::Float32, TanhCudnn,
                "Tanh_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Abs, DataType::Float32, UnaryCuda,
                "Abs_CUDA_Float32");

// REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, UnaryCuda,
//                 "Softmax_CUDA_Float32");
// REGISTER_KERNEL(Device::CUDA, OpType::Relu, DataType::Float32, UnaryCuda,
//                 "Relu_CUDA_Float32");
// REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, DataType::Float32, UnaryCuda,
//                 "Sigmoid_CUDA_Float32");
// REGISTER_KERNEL(Device::CUDA, OpType::Tanh, DataType::Float32, UnaryCuda,
//                 "Tanh_CUDA_Float32");
// REGISTER_KERNEL(Device::CUDA, OpType::Abs, DataType::Float32, UnaryCuda,
//                 "Abs_CUDA_Float32");
}; // namespace infini
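A note on the {1.f, 0.f} pair returned by getAlphBeta(): cuDNN forward routines blend the freshly computed result into the destination buffer as

    y = \alpha \cdot \mathrm{op}(x) + \beta \cdot y_{\text{prior}}

so alpha = 1, beta = 0 simply overwrites the output. CUDNN_SOFTMAX_MODE_INSTANCE, chosen by NormalSoftmaxCudnn, normalizes per image across the C, H, W dimensions, which coincides with the whole-tensor CPU reference whenever n = 1 (as in the test in this commit).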
@@ -0,0 +1,94 @@
#include "core/common.h"
#include "core/constants.h"
#include "cuda/cuda_common.h"
#include <math.h>

using infini::E_CONSTANT;
constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); }

// Pass 1 of softmax: a single thread reduces the sum of exponentials into
// *sum. (The original code wrote the sum into output[0], which pass 2 then
// raced against while overwriting it; a separate buffer avoids that.)
__global__ void _softmax_kernel1(float *input, float *sum, int n) {
    float s = 0.0f;
    for (int i = 0; i < n; ++i) {
        s += pow(E_CONSTANT, input[i]);
    }
    *sum = s;
}

// Pass 2 of softmax: normalize each element by the precomputed sum.
__global__ void _softmax_kernel2(float *input, float *output, float *sum,
                                 int n) {
    int index = threadIdx.x + blockIdx.x * blockDim.x;
    int stride = blockDim.x * gridDim.x;
    for (int i = index; i < n; i += stride) {
        output[i] = pow(E_CONSTANT, input[i]) / *sum;
    }
}

__global__ void _relu_kernel(float *input, float *output, int n) {
    int index = threadIdx.x + blockIdx.x * blockDim.x;
    int stride = blockDim.x * gridDim.x;
    for (int i = index; i < n; i += stride) {
        output[i] = max(input[i], float(0));
    }
}

__global__ void _sigmoid_kernel(float *input, float *output, int n) {
    int index = threadIdx.x + blockIdx.x * blockDim.x;
    int stride = blockDim.x * gridDim.x;
    for (int i = index; i < n; i += stride) {
        output[i] = 1 / (1 + pow(E_CONSTANT, -input[i]));
    }
}

__global__ void _tanh_kernel(float *input, float *output, int n) {
    int index = threadIdx.x + blockIdx.x * blockDim.x;
    int stride = blockDim.x * gridDim.x;
    for (int i = index; i < n; i += stride) {
        output[i] = (pow(E_CONSTANT, input[i]) - pow(E_CONSTANT, -input[i])) /
                    (pow(E_CONSTANT, input[i]) + pow(E_CONSTANT, -input[i]));
    }
}

__global__ void _abs_kernel(float *input, float *output, int n) {
    int index = threadIdx.x + blockIdx.x * blockDim.x;
    int stride = blockDim.x * gridDim.x;
    for (int i = index; i < n; i += stride) {
        output[i] = input[i] < 0 ? -input[i] : input[i];
    }
}

namespace infini {
// Note: kernel launches take <<<gridDim, blockDim>>>; the original code
// passed the two arguments in the opposite order.
void softmax_kernel(float *input, float *output, int num) {
    int blocksize = block_work_size();
    int gridsize = (num + block_work_size() - 1) / block_work_size();
    float *sum;
    cudaMalloc(&sum, sizeof(float));
    _softmax_kernel1<<<1, 1>>>(input, sum, num);
    _softmax_kernel2<<<gridsize, blocksize>>>(input, output, sum, num);
    cudaFree(sum);
}
void relu_kernel(float *input, float *output, int num) {
    int blocksize = block_work_size();
    int gridsize = (num + block_work_size() - 1) / block_work_size();
    _relu_kernel<<<gridsize, blocksize>>>(input, output, num);
}
void sigmoid_kernel(float *input, float *output, int num) {
    int blocksize = block_work_size();
    int gridsize = (num + block_work_size() - 1) / block_work_size();
    _sigmoid_kernel<<<gridsize, blocksize>>>(input, output, num);
}
void tanh_kernel(float *input, float *output, int num) {
    int blocksize = block_work_size();
    int gridsize = (num + block_work_size() - 1) / block_work_size();
    _tanh_kernel<<<gridsize, blocksize>>>(input, output, num);
}
void abs_kernel(float *input, float *output, int num) {
    int blocksize = block_work_size();
    int gridsize = (num + block_work_size() - 1) / block_work_size();
    _abs_kernel<<<gridsize, blocksize>>>(input, output, num);
}

}; // namespace infini
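A quick check of the launch arithmetic with a hypothetical element count:

    // For num = 1000:
    //   blocksize = block_work_size() = 128 * 4  = 512 threads per block
    //   gridsize  = (1000 + 512 - 1) / 512       = 2 blocks
    // 2 * 512 = 1024 threads >= 1000, so the grid-stride loop in each kernel
    // runs at most one iteration per thread here; the loop still keeps the
    // kernels correct for any smaller grid.

Note that because blocksize is taken from block_work_size() rather than num_threads(), the thread_work_size() = 4 factor effectively pads the grid instead of giving each thread four elements.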
@@ -0,0 +1,35 @@
#include "operators/unary.h"

namespace infini {
UnaryObj::UnaryObj(OpType type, GraphObj *graph, Tensor input, Tensor output)
    : OperatorObj(type, {input}, {output}) {
    IT_ASSERT(checkValid(graph));
}

optional<vector<Shape>> UnaryObj::inferShape(const TensorVec &inputs) const {
    const auto A = inputs[0];
    return {{A->getDims()}};
}

std::string UnaryObj::toString() const {
    std::ostringstream os;
    os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
    os << "(";
    os << vecToString(inputs[0]->getDims()) << ",";
    os << "input=" << inputs[0]->getGuid() << ",";
    os << "output=" << outputs[0]->getGuid() << ")";
    return os.str();
}

vector<int> UnaryObj::getWorkloadVector() const {
    vector<int> ret{enum_to_underlying(type)};
    const Shape shape = outputs[0]->getDims();
    ret.insert(ret.end(), shape.begin(), shape.end());
    return ret;
}

vector<int> UnaryObj::getOpAttrVector() const {
    return {enum_to_underlying(type)};
}

}; // namespace infini
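A worked example of the two keying vectors (hypothetical shape): for a Relu whose output dims are {1, 2, 2, 3},

    // getWorkloadVector() -> {enum_to_underlying(OpType::Relu), 1, 2, 2, 3}
    // getOpAttrVector()   -> {enum_to_underlying(OpType::Relu)}

i.e. the workload vector identifies the op plus its output shape, while the attribute vector identifies the op type alone.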
@@ -0,0 +1,50 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/unary.h"

#include "test.h"

namespace infini {

template <class T>
void testUnary(const std::function<void(void *, size_t, DataType)> &generator,
               const Shape &shape) {
    // Runtime
    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
    auto cudaRuntime = make_ref<CudaRuntimeObj>();

    // Build input data on CPU
    Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
    inputCpu->dataMalloc();
    inputCpu->setData(generator);

    // GPU
    Graph cudaGraph = make_ref<GraphObj>(cudaRuntime);
    auto inputGpu = cudaGraph->cloneTensor(inputCpu);
    auto gpuOp = cudaGraph->addOp<T>(inputGpu, nullptr);
    cudaGraph->dataMalloc();
    cudaRuntime->run(cudaGraph);
    auto outputGpu = gpuOp->getOutput();
    auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
    // CPU
    Graph cpuGraph = make_ref<GraphObj>(cpuRuntime);
    auto cpuOp = cpuGraph->addOp<T>(inputCpu, nullptr);
    cpuGraph->dataMalloc();
    cpuRuntime->run(cpuGraph);
    auto outputCpu = cpuOp->getOutput();
    // Check
    EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu));
}

TEST(Unary, CuDNN) {
    testUnary<ReluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
    testUnary<SoftmaxObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
    testUnary<AbsObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
    testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
    testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
}

} // namespace infini