Add ELU operator (#237)
* Add ELU operator * Format code using clang-format * Format code using clang-format * Format code using clang-format * Format code using clang-format * fix test * build.yml * Update cuda_unary.h * Update unary.h * Update unary.cc * Update unary.cc --------- Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent
62b8866022
commit
d0cfd1e40a
|
@ -36,6 +36,11 @@ jobs:
|
|||
- name: Install libdw
|
||||
run: sudo apt-get update && sudo apt-get install libdw-dev
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install numpy==1.22.2 onnxruntime
|
||||
|
||||
# - name: Cache protobuf
|
||||
# id: cache-protobuf
|
||||
# uses: actions/cache@v3
|
||||
|
@ -79,7 +84,4 @@ jobs:
|
|||
|
||||
- name: Test onnx frontend
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install onnxruntime
|
||||
pip install numpy==1.22.2
|
||||
make test-onnx
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "core/graph.h"
|
||||
#include "core/operator.h"
|
||||
#include "core/runtime.h"
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
|
@ -69,6 +70,7 @@ class GraphHandlerObj {
|
|||
Tensor identity(Tensor x, Tensor y);
|
||||
Tensor flatten(Tensor s, Tensor y, int axis);
|
||||
Tensor pRelu(Tensor x, Tensor slope, Tensor y);
|
||||
Tensor elu(Tensor x, Tensor y, float alpha);
|
||||
Tensor clip(Tensor x, Tensor y, std::optional<float> min,
|
||||
std::optional<float> max);
|
||||
Tensor transpose(Tensor data, Tensor transposed, Shape perm);
|
||||
|
|
|
@ -17,10 +17,9 @@ template <typename T> void hard_sigmoid_kernel(T *input, T *output, size_t num);
|
|||
template <typename T> void hard_swish_kernel(T *input, T *output, size_t num);
|
||||
template <typename T>
|
||||
void leaky_relu_kernel(T *input, T *output, size_t num, float alpha);
|
||||
|
||||
template <typename INPUT, typename OUTPUT>
|
||||
void cast_kernel(INPUT *input, OUTPUT *output, size_t num);
|
||||
|
||||
void elu_kernel(const float *input, float *output, size_t size, float alpha);
|
||||
void unary_kernel(const Operator &_op);
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -47,6 +47,23 @@ class ClipObj : public OperatorObj {
|
|||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
class EluObj : public OperatorObj {
|
||||
public:
|
||||
EluObj(GraphObj *graph, Tensor input, Tensor output, float alpha);
|
||||
OP_CLONE(EluObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return inputs.size(); }
|
||||
int numOutputs() const override { return 1; }
|
||||
float getAlpha() const { return alpha; }
|
||||
float alpha;
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
class HardtanhObj : public OperatorObj {
|
||||
public:
|
||||
HardtanhObj(GraphObj *graph, Tensor input, Tensor output, float min,
|
||||
|
|
|
@ -180,6 +180,12 @@ class OnnxStub:
|
|||
d[0],
|
||||
d[1],
|
||||
)
|
||||
elif node.op_type == "Elu":
|
||||
attributes = _parse_attribute(node, {"alpha": 1.0})
|
||||
alpha = attributes["alpha"]
|
||||
tensors[node.output[0]] = self.handler.elu(
|
||||
tensors[node.input[0]], tensors.get(node.output[0]), alpha
|
||||
)
|
||||
elif node.op_type == "ConvTranspose":
|
||||
attributes = _parse_attribute(
|
||||
node,
|
||||
|
@ -1174,6 +1180,13 @@ class OnnxStub:
|
|||
group=op.inputs()[0].shape()[1] // op.inputs()[1].shape()[1],
|
||||
)
|
||||
)
|
||||
elif ty == backend.OpTypeId.Elu:
|
||||
alpha = backend.elu_alpha_of(op)
|
||||
ctx.push_node(
|
||||
make_node(
|
||||
"Elu", inputs, outputs, name, alpha=alpha
|
||||
)
|
||||
)
|
||||
elif ty == backend.OpTypeId.ConvTranspose:
|
||||
ph, pw, sh, sw, dh, dw, oph, opw = backend.conv_trans_attrs_of(op)
|
||||
ctx.push_node(
|
||||
|
|
|
@ -36,6 +36,17 @@ namespace infini {
|
|||
static DataType dtype_repr_convert(int);
|
||||
static CastType inferCastType(Tensor input, int to);
|
||||
|
||||
// Appends an Elu operator to the graph and returns its output tensor.
// When the caller supplies no output tensor, a new one is created with the
// same dims and dtype as the input.
Tensor GraphHandlerObj::elu(Tensor input, Tensor output, float alpha) {
    if (!output) {
        // Read dims/dtype before `input` is moved into the operator.
        auto created = g->addTensor(input->getDims(), input->getDType());
        g->addOpWithOutputs<EluObj>(std::move(input), created, alpha);
        return created;
    }

    g->addOpWithOutputs<EluObj>(std::move(input), output, alpha);
    return output;
}
|
||||
|
||||
// Adds a tensor with the given dims to the graph. The integer `dtype` code is
// translated to a DataType by dtype_repr_convert before the tensor is created.
Tensor GraphHandlerObj::tensor(Shape dims, int dtype) {
    auto created = g->addTensor(std::move(dims), dtype_repr_convert(dtype));
    return created;
}
|
||||
|
|
|
@ -120,6 +120,7 @@ void export_values(py::module &m) {
|
|||
.VALUE(OpType, Where)
|
||||
.VALUE(OpType, DepthToSpace)
|
||||
.VALUE(OpType, LRN)
|
||||
.VALUE(OpType, Elu)
|
||||
.export_values();
|
||||
|
||||
#undef VALUE
|
||||
|
@ -203,6 +204,12 @@ static std::tuple<bool, bool> matmul_attrs_of(Operator op) {
|
|||
return std::make_tuple(matmul->getTransA(), matmul->getTransB());
|
||||
}
|
||||
|
||||
// Returns the alpha attribute of an Elu operator.
// Asserts that `op` really is an Elu operator before down-casting it.
static float elu_alpha_of(Operator op) {
    IT_ASSERT(op->getOpType() == OpType::Elu);
    const auto *eluOp = dynamic_cast<const EluObj *>(op.get());
    const float alpha = eluOp->getAlpha();
    return alpha;
}
|
||||
|
||||
static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::BatchNormalization);
|
||||
auto batchnorm = dynamic_cast<const BatchNormObj *>(op.get());
|
||||
|
@ -368,7 +375,8 @@ void export_functions(py::module &m) {
|
|||
.FUNCTION(depth_to_space_attrs_of)
|
||||
.FUNCTION(squeeze_axes_of)
|
||||
.FUNCTION(unsqueeze_axes_of)
|
||||
.FUNCTION(lrn_attrs_of);
|
||||
.FUNCTION(lrn_attrs_of)
|
||||
.FUNCTION(elu_alpha_of);
|
||||
#undef FUNCTION
|
||||
}
|
||||
|
||||
|
@ -501,6 +509,7 @@ void init_graph_builder(py::module &m) {
|
|||
policy::reference);
|
||||
py::class_<Handler>(m, "GraphHandler")
|
||||
.def(py::init<Runtime>())
|
||||
.def("elu", &Handler::elu, policy::move)
|
||||
.def("tensor", &Handler::tensor, policy::move)
|
||||
.def("conv", &Handler::conv, policy::move)
|
||||
.def("convTransposed2d", &Handler::convTransposed2d, policy::move)
|
||||
|
|
|
@ -13,6 +13,20 @@ class UnaryCuda : public CudaKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
class EluCuda : public CudaKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<EluObj>(_op);
|
||||
|
||||
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
size_t size = op->getInputs(0)->size();
|
||||
elu_kernel((float *)inputData, (float *)outputData, size,
|
||||
op->getAlpha());
|
||||
}
|
||||
};
|
||||
|
||||
class CastCuda : public CudaKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
|
@ -192,6 +206,7 @@ class TanhCudnn : public ActivationCudnn {
|
|||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Relu, ReluCudnn, "Relu_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, SigmoidCudnn, "Sigmoid_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Elu, EluCuda, "Elu_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::HardSigmoid, UnaryCuda,
|
||||
"Hard_Sigmoid_CUDA");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::HardSwish, UnaryCuda, "Hard_Swish_CUDA");
|
||||
|
|
|
@ -94,6 +94,15 @@ __global__ void _sqrt_kernel(half *input, half *output, size_t n) {
|
|||
}
|
||||
}
|
||||
|
||||
// Element-wise ELU over a flat float buffer:
//   output[i] = input[i]                   if input[i] >= 0
//   output[i] = alpha * (exp(input[i]) - 1) otherwise
// Each thread handles at most one element; threads whose flat index is not
// below `size` do nothing.
//
// `size` and the flat index are size_t (like the other kernels in this file)
// so buffers with more than INT_MAX elements do not overflow the index.
__global__ void _elu_kernel(const float *input, float *output, size_t size,
                            float alpha) {
    // Widen before multiplying: blockIdx.x * blockDim.x is otherwise computed
    // in 32-bit unsigned arithmetic.
    const size_t index =
        static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    if (index < size) {
        const float x = input[index];
        // expm1f(x) evaluates exp(x) - 1 without the cancellation error that
        // expf(x) - 1 suffers for x close to zero.
        output[index] = (x >= 0) ? x : alpha * expm1f(x);
    }
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void _gelu_kernel(T *input, T *output, size_t n) {
|
||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
|
@ -360,6 +369,12 @@ void leaky_relu_kernel(T *input, T *output, size_t num, float alphaValue) {
|
|||
alphaValue);
|
||||
}
|
||||
|
||||
// Host-side launcher for _elu_kernel.
//
// input/output: device buffers holding `size` floats each.
// size:         number of elements; declared as size_t so this definition
//               matches the prototype in the CUDA unary header and the
//               size_t argument passed by the caller.
// alpha:        ELU alpha attribute forwarded to the kernel.
void elu_kernel(const float *input, float *output, size_t size, float alpha) {
    // Nothing to compute; also avoids launching a grid of zero blocks,
    // which CUDA rejects as an invalid configuration.
    if (size == 0) {
        return;
    }

    constexpr int blocksize = 32 * 16;
    // Ceil-divide so the last, partially filled block is still launched.
    const int gridsize = static_cast<int>((size + blocksize - 1) / blocksize);
    _elu_kernel<<<gridsize, blocksize>>>(input, output, size, alpha);
}
|
||||
|
||||
template void cast_kernel<float, half>(float *input, half *output, size_t num);
|
||||
template void cast_kernel<half, float>(half *input, float *output, size_t num);
|
||||
template void cast_kernel<float, int32_t>(float *input, int32_t *output,
|
||||
|
|
|
@ -342,4 +342,33 @@ vector<int> LogObj::getWorkloadVector() const {
|
|||
|
||||
// Attribute vector of the operator: a single entry, the underlying value of
// its `type` member.
vector<int> LogObj::getOpAttrVector() const {
    vector<int> attrs;
    attrs.push_back(type.underlying());
    return attrs;
}
|
||||
|
||||
// Registers the operator as OpType::Elu with one input and one output,
// stores alpha, and asserts that the operator is valid for `graph`.
EluObj::EluObj(GraphObj *graph, Tensor input, Tensor output, float alpha)
    : OperatorObj(OpType::Elu, {input}, {output}), alpha{alpha} {
    IT_ASSERT(checkValid(graph));
}
|
||||
|
||||
// The single output has the same dims as the first input.
optional<vector<Shape>> EluObj::inferShape(const TensorVec &inputs) {
    const auto &source = inputs[0];
    vector<Shape> shapes{source->getDims()};
    return shapes;
}
|
||||
|
||||
// Formats the operator as
//   Elu[<guid>](input=<guid>,alpha=<alpha>,output=<guid>)
std::string EluObj::toString() const {
    std::ostringstream repr;
    repr << "Elu[" << getGuid() << "]"
         << "("
         << "input=" << inputs[0]->getGuid() << ","
         << "alpha=" << alpha << ","
         << "output=" << outputs[0]->getGuid() << ")";
    return repr.str();
}
|
||||
|
||||
// Workload vector: the operator type's underlying value followed by the
// output tensor's dims.
vector<int> EluObj::getWorkloadVector() const {
    const vector<int> dims = getOutput()->getDims();

    vector<int> workload;
    workload.reserve(dims.size() + 1);
    workload.push_back(type.underlying());
    workload.insert(workload.end(), dims.begin(), dims.end());
    return workload;
}
|
||||
|
||||
// Attribute vector: the operator type's underlying value followed by alpha
// converted to int.
// NOTE(review): the int conversion truncates alpha, so e.g. 0.25 and 0.75
// both contribute 0 here — confirm that consumers of this vector do not need
// to tell such operators apart.
vector<int> EluObj::getOpAttrVector() const {
    vector<int> attrs;
    attrs.push_back(type.underlying());
    attrs.push_back(static_cast<int>(alpha));
    return attrs;
}
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -96,6 +96,29 @@ TEST(LeakyRelu, Cuda_WithAlpha) {
|
|||
-0.015, -0.01, 1.0, 2.0, 3.0}));
|
||||
}
|
||||
|
||||
// Runs Elu with alpha = 1 on the CUDA runtime and compares the result,
// copied back to the CPU graph, against the expected values.
// NOTE(review): the expected values are all non-negative, so this presumably
// only exercises the pass-through branch of ELU — consider adding negative
// inputs.
TEST(Elu, Cuda) {
    // Host side: graph used for the reference input and the copied-back output.
    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
    Graph hostGraph = make_ref<GraphObj>(cpuRuntime);

    auto hostInput = hostGraph->addTensor({2, 2, 3, 1}, DataType::Float32);
    hostGraph->dataMalloc();
    hostInput->setData(IncrementalGenerator());

    // Device side: clone the input and attach the Elu operator; the output
    // tensor is left for the graph to create (nullptr).
    auto gpuRuntime = make_ref<CudaRuntimeObj>();
    Graph deviceGraph = make_ref<GraphObj>(gpuRuntime);

    auto deviceInput = deviceGraph->cloneTensor(hostInput);
    auto eluOp = deviceGraph->addOp<EluObj>(deviceInput, nullptr, 1.0f);
    deviceGraph->dataMalloc();
    // Input data is (re)written after the device graph's dataMalloc().
    deviceInput->setData(IncrementalGenerator());
    gpuRuntime->run(deviceGraph);

    // Bring the result back to the host graph and check it.
    auto hostOutput = hostGraph->cloneTensor(eluOp->getOutput());
    hostOutput->printData();
    EXPECT_TRUE(hostOutput->equalData(
        vector<float>{0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.}));
}
|
||||
|
||||
TEST(cuDNN_Unary, run) {
|
||||
testUnary<ReluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<SiluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
|
|
Loading…
Reference in New Issue