diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 5e1e7626..675693bd 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -36,6 +36,11 @@ jobs:
     - name: Install libdw
       run: sudo apt-get update && sudo apt-get install libdw-dev
 
+    - name: Install Python dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install numpy==1.22.2 onnxruntime
+
     # - name: Cache protobuf
     #   id: cache-protobuf
     #   uses: actions/cache@v3
@@ -79,7 +84,4 @@ jobs:
 
     - name: Test onnx frontend
       run: |
-        python -m pip install --upgrade pip
-        pip install onnxruntime
-        pip install numpy==1.22.2
         make test-onnx
diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h
index 7996cabc..7f29cdd6 100644
--- a/include/core/graph_handler.h
+++ b/include/core/graph_handler.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "core/graph.h"
+#include "core/operator.h"
 #include "core/runtime.h"
 #include
 #include
@@ -69,6 +70,8 @@ class GraphHandlerObj {
     Tensor identity(Tensor x, Tensor y);
     Tensor flatten(Tensor s, Tensor y, int axis);
     Tensor pRelu(Tensor x, Tensor slope, Tensor y);
+    // Elu activation. When `y` is null the handler creates the output tensor.
+    Tensor elu(Tensor x, Tensor y, float alpha);
     Tensor clip(Tensor x, Tensor y, std::optional min,
                 std::optional max);
     Tensor transpose(Tensor data, Tensor transposed, Shape perm);
diff --git a/include/cuda/cuda_unary.h b/include/cuda/cuda_unary.h
index 3be4bbae..80a1e8b9 100644
--- a/include/cuda/cuda_unary.h
+++ b/include/cuda/cuda_unary.h
@@ -17,10 +17,14 @@ template void hard_sigmoid_kernel(T *input, T *output, size_t num);
 template void hard_swish_kernel(T *input, T *output, size_t num);
 template
 void leaky_relu_kernel(T *input, T *output, size_t num, float alpha);
 
 template
 void cast_kernel(INPUT *input, OUTPUT *output, size_t num);
 
+// Elementwise Elu over `size` floats:
+// output[i] = input[i] >= 0 ? input[i] : alpha * (exp(input[i]) - 1).
+void elu_kernel(const float *input, float *output, size_t size, float alpha);
+
 void unary_kernel(const Operator &_op);
 
 }; // namespace infini
diff --git a/include/operators/unary.h b/include/operators/unary.h
index 5af161f7..b9651ea3 100644
--- a/include/operators/unary.h
+++ b/include/operators/unary.h
@@ -47,6 +47,29 @@ class ClipObj : public OperatorObj {
     vector getOpAttrVector() const override;
 };
 
+/**
+ * @brief Elu activation: out = x for x >= 0, alpha * (exp(x) - 1) otherwise.
+ *
+ * Takes one input tensor and produces one output tensor of the same shape.
+ */
+class EluObj : public OperatorObj {
+  public:
+    EluObj(GraphObj *graph, Tensor input, Tensor output, float alpha);
+    OP_CLONE(EluObj);
+
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
+    std::string toString() const override;
+    int numInputs() const override { return 1; }
+    int numOutputs() const override { return 1; }
+    float getAlpha() const { return alpha; }
+
+  private:
+    float alpha; // read through getAlpha()
+
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+};
+
 class HardtanhObj : public OperatorObj {
   public:
     HardtanhObj(GraphObj *graph, Tensor input, Tensor output, float min,
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index 19b2f22b..17da96be 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -180,6 +180,12 @@ class OnnxStub:
                         d[0],
                         d[1],
                     )
+            elif node.op_type == "Elu":
+                attributes = _parse_attribute(node, {"alpha": 1.0})
+                alpha = attributes["alpha"]
+                tensors[node.output[0]] = self.handler.elu(
+                    tensors[node.input[0]], tensors.get(node.output[0]), alpha
+                )
             elif node.op_type == "ConvTranspose":
                 attributes = _parse_attribute(
                     node,
@@ -1174,6 +1180,13 @@ class OnnxStub:
                         group=op.inputs()[0].shape()[1] // op.inputs()[1].shape()[1],
                     )
                 )
+            elif ty == backend.OpTypeId.Elu:
+                alpha = backend.elu_alpha_of(op)
+                ctx.push_node(
+                    make_node(
+                        "Elu", inputs, outputs, name, alpha=alpha
+                    )
+                )
             elif ty == backend.OpTypeId.ConvTranspose:
                 ph, pw, sh, sw, dh, dw, oph, opw = backend.conv_trans_attrs_of(op)
                 ctx.push_node(
diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc
index 87294623..f0d1687f 100644
--- a/src/core/graph_handler.cc
+++ b/src/core/graph_handler.cc
@@ -36,6 +36,17 @@ namespace infini {
 static DataType dtype_repr_convert(int);
static CastType inferCastType(Tensor input, int to); +Tensor GraphHandlerObj::elu(Tensor input, Tensor output, float alpha) { + if (output) { + g->addOpWithOutputs(std::move(input), output, alpha); + return output; + } else { + auto new_output = g->addTensor(input->getDims(), input->getDType()); + g->addOpWithOutputs(std::move(input), new_output, alpha); + return new_output; + } +} + Tensor GraphHandlerObj::tensor(Shape dims, int dtype) { return g->addTensor(std::move(dims), dtype_repr_convert(dtype)); } diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 0d22c2dc..a106074f 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -120,6 +120,7 @@ void export_values(py::module &m) { .VALUE(OpType, Where) .VALUE(OpType, DepthToSpace) .VALUE(OpType, LRN) + .VALUE(OpType, Elu) .export_values(); #undef VALUE @@ -203,6 +204,12 @@ static std::tuple matmul_attrs_of(Operator op) { return std::make_tuple(matmul->getTransA(), matmul->getTransB()); } +static float elu_alpha_of(Operator op) { + IT_ASSERT(op->getOpType() == OpType::Elu); + auto elu = dynamic_cast(op.get()); + return elu->getAlpha(); +} + static std::tuple batch_norm_attrs_of(Operator op) { IT_ASSERT(op->getOpType() == OpType::BatchNormalization); auto batchnorm = dynamic_cast(op.get()); @@ -368,7 +375,8 @@ void export_functions(py::module &m) { .FUNCTION(depth_to_space_attrs_of) .FUNCTION(squeeze_axes_of) .FUNCTION(unsqueeze_axes_of) - .FUNCTION(lrn_attrs_of); + .FUNCTION(lrn_attrs_of) + .FUNCTION(elu_alpha_of); #undef FUNCTION } @@ -501,6 +509,7 @@ void init_graph_builder(py::module &m) { policy::reference); py::class_(m, "GraphHandler") .def(py::init()) + .def("elu", &Handler::elu, policy::move) .def("tensor", &Handler::tensor, policy::move) .def("conv", &Handler::conv, policy::move) .def("convTransposed2d", &Handler::convTransposed2d, policy::move) diff --git a/src/kernels/cuda/unary.cc b/src/kernels/cuda/unary.cc index 1ba1ecf6..9f8018e7 100644 --- 
a/src/kernels/cuda/unary.cc +++ b/src/kernels/cuda/unary.cc @@ -13,6 +13,20 @@ class UnaryCuda : public CudaKernelWithoutConfig { } }; +class EluCuda : public CudaKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + + void *const inputData = (op->getInputs(0)->getRawDataPtr()); + void *const outputData = (op->getOutput()->getRawDataPtr()); + + size_t size = op->getInputs(0)->size(); + elu_kernel((float *)inputData, (float *)outputData, size, + op->getAlpha()); + } +}; + class CastCuda : public CudaKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { @@ -192,6 +206,7 @@ class TanhCudnn : public ActivationCudnn { REGISTER_KERNEL(Device::CUDA, OpType::Relu, ReluCudnn, "Relu_CUDA"); REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, SigmoidCudnn, "Sigmoid_CUDA"); +REGISTER_KERNEL(Device::CUDA, OpType::Elu, EluCuda, "Elu_CUDA"); REGISTER_KERNEL(Device::CUDA, OpType::HardSigmoid, UnaryCuda, "Hard_Sigmoid_CUDA"); REGISTER_KERNEL(Device::CUDA, OpType::HardSwish, UnaryCuda, "Hard_Swish_CUDA"); diff --git a/src/kernels/cuda/unary.cu b/src/kernels/cuda/unary.cu index 3871d9a0..077406a3 100644 --- a/src/kernels/cuda/unary.cu +++ b/src/kernels/cuda/unary.cu @@ -94,6 +94,15 @@ __global__ void _sqrt_kernel(half *input, half *output, size_t n) { } } +__global__ void _elu_kernel(const float *input, float *output, int size, float alpha) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + + if (index < size) { + float x = input[index]; + output[index] = (x >= 0) ? 
x : alpha * (expf(x) - 1); + } +} + template __global__ void _gelu_kernel(T *input, T *output, size_t n) { int index = threadIdx.x + blockIdx.x * blockDim.x; @@ -360,6 +369,12 @@ void leaky_relu_kernel(T *input, T *output, size_t num, float alphaValue) { alphaValue); } +void elu_kernel(const float *input, float *output, int size, float alpha) { + int blocksize = 32 * 16; + int gridsize = (size + blocksize - 1) / blocksize; + _elu_kernel<<>>(input, output, size, alpha); +} + template void cast_kernel(float *input, half *output, size_t num); template void cast_kernel(half *input, float *output, size_t num); template void cast_kernel(float *input, int32_t *output, diff --git a/src/operators/unary.cc b/src/operators/unary.cc index 404291a9..6dbf73c3 100644 --- a/src/operators/unary.cc +++ b/src/operators/unary.cc @@ -342,4 +342,33 @@ vector LogObj::getWorkloadVector() const { vector LogObj::getOpAttrVector() const { return {type.underlying()}; } +EluObj::EluObj(GraphObj *graph, Tensor input, Tensor output, float alpha) + : OperatorObj(OpType::Elu, {input}, {output}), alpha(alpha) { + IT_ASSERT(checkValid(graph)); +} + +optional> EluObj::inferShape(const TensorVec &inputs) { + return {{inputs[0]->getDims()}}; +} + +std::string EluObj::toString() const { + std::ostringstream os; + os << "Elu[" << getGuid() << "]"; + os << "("; + os << "input=" << inputs[0]->getGuid() << ","; + os << "alpha=" << alpha << ","; + os << "output=" << outputs[0]->getGuid() << ")"; + return os.str(); +} + +vector EluObj::getWorkloadVector() const { + vector ret = getOutput()->getDims(); + ret.emplace(ret.begin(), type.underlying()); + return ret; +} + +vector EluObj::getOpAttrVector() const { + return {type.underlying(), static_cast(alpha)}; +} + }; // namespace infini diff --git a/test/kernels/cuda/test_cuda_unary.cc b/test/kernels/cuda/test_cuda_unary.cc index 07b10616..ea7b29a0 100644 --- a/test/kernels/cuda/test_cuda_unary.cc +++ b/test/kernels/cuda/test_cuda_unary.cc @@ -96,6 +96,29 @@ 
TEST(LeakyRelu, Cuda_WithAlpha) { -0.015, -0.01, 1.0, 2.0, 3.0})); } +TEST(Elu, Cuda) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor({2, 2, 3, 1}, DataType::Float32); + gCpu->dataMalloc(); + input->setData(IncrementalGenerator()); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto inputGpu = gCuda->cloneTensor(input); + auto op = gCuda->addOp(inputGpu, nullptr, 1.0f); + gCuda->dataMalloc(); + inputGpu->setData(IncrementalGenerator()); + cudaRuntime->run(gCuda); + + auto oCpu = gCpu->cloneTensor(op->getOutput()); + oCpu->printData(); + EXPECT_TRUE(oCpu->equalData( + vector{0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.})); +} + TEST(cuDNN_Unary, run) { testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3}); testUnary(IncrementalGenerator(), Shape{1, 2, 2, 3});