From d1c913010f9a969a5b1a36357883ee08bd5b119a Mon Sep 17 00:00:00 2001
From: wendy12022
Date: Sat, 15 Oct 2022 16:53:58 +0800
Subject: [PATCH] ADD:reduce_mean operator and cuda kernel. (#47)

add new line at file ending.
---
 include/operators/reduce_mean.h            |  27 +++++
 src/kernels/cuda/element_wise.cc           |   9 +-
 src/kernels/cuda/pooling.cc                |   8 +-
 src/kernels/cuda/reduce_mean.cc            | 111 +++++++++++++++++++++
 src/operators/reduce_mean.cc               |  85 ++++++++++++++++
 test/kernels/cuda/test_cuda_pad.cc         |   2 +-
 test/kernels/cuda/test_cuda_reduce_mean.cc |  62 ++++++++++++
 test/operators/test_reduce_mean.cc         |  38 +++++++
 8 files changed, 331 insertions(+), 11 deletions(-)
 create mode 100644 include/operators/reduce_mean.h
 create mode 100644 src/kernels/cuda/reduce_mean.cc
 create mode 100644 src/operators/reduce_mean.cc
 create mode 100644 test/kernels/cuda/test_cuda_reduce_mean.cc
 create mode 100644 test/operators/test_reduce_mean.cc

diff --git a/include/operators/reduce_mean.h b/include/operators/reduce_mean.h
new file mode 100644
index 00000000..ba631e36
--- /dev/null
+++ b/include/operators/reduce_mean.h
@@ -0,0 +1,27 @@
+#pragma once
+#include "core/operator.h"
+
+namespace infini {
+class ReduceMeanObj : public OperatorObj {
+    set<int> axis; // axes to reduce
+    bool keepDims;
+
+  public:
+    ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
+                  const optional<vector<int>> &axis,
+                  bool keepDims = true);
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
+
+    std::string toString() const override;
+    int numInputs() const override { return 1; }
+    int numOutputs() const override { return 1; }
+
+    bool isReduced(int idx) const;
+    bool getKeepDims() const { return keepDims; }
+
+  private:
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+};
+
+} // namespace infini
diff --git a/src/kernels/cuda/element_wise.cc b/src/kernels/cuda/element_wise.cc
index c44287e7..97835746 100644
--- a/src/kernels/cuda/element_wise.cc
+++ b/src/kernels/cuda/element_wise.cc
@@ -45,11 +45,10 @@ class ElementWiseCudnn : public CudaKernelWithoutConfig {
             opDesc, getOpType(), CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN));
 
         auto [aAlpha, bAlpha, beta] = getAlphBeta();
-        cudnnStatus_t stat =
-            cudnnOpTensor(context->cudnnHandle(), opDesc, &aAlpha, aDesc, aData,
-                          &bAlpha, bDesc, bData, &beta, cDesc, cData);
-        if (stat != CUDNN_STATUS_SUCCESS)
-            return;
+
+        checkCudnnError(cudnnOpTensor(context->cudnnHandle(), opDesc, &aAlpha,
+                                      aDesc, aData, &bAlpha, bDesc, bData,
+                                      &beta, cDesc, cData));
 
         // Destroying descriptors in CUDA does not require a sync, but cuDNN
         // does not state whether a sync is required before destruction.
diff --git a/src/kernels/cuda/pooling.cc b/src/kernels/cuda/pooling.cc
index 47cf32e7..552690b9 100644
--- a/src/kernels/cuda/pooling.cc
+++ b/src/kernels/cuda/pooling.cc
@@ -9,7 +9,6 @@ class poolingCudnn : public CudaKernelWithoutConfig {
                  const RuntimeObj *_context) const override {
         auto op = as<PoolingObj>(_op);
         auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
-        cudnnStatus_t stat;
         void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const outData = (op->getOutput()->getRawDataPtr<void *>());
 
@@ -43,10 +42,9 @@ class poolingCudnn : public CudaKernelWithoutConfig {
                   "cuDNN output shape mismatches with OP output shape");
 
         float alpha = 1.f, beta = 0.f;
-        stat = cudnnPoolingForward(context->cudnnHandle(), poolingDesc, &alpha,
-                                   inDesc, inData, &beta, outDesc, outData);
-        if (stat != CUDNN_STATUS_SUCCESS)
-            return;
+        checkCudnnError(cudnnPoolingForward(context->cudnnHandle(),
+                                            poolingDesc, &alpha, inDesc,
+                                            inData, &beta, outDesc, outData));
 
         // Destroying descriptors in CUDA does not require a sync, but cuDNN
         // does not state whether a sync is required before destruction.
diff --git a/src/kernels/cuda/reduce_mean.cc b/src/kernels/cuda/reduce_mean.cc
new file mode 100644
index 00000000..e61f0019
--- /dev/null
+++ b/src/kernels/cuda/reduce_mean.cc
@@ -0,0 +1,111 @@
+#include "operators/reduce_mean.h"
+#include "cuda/cuda_kernel_wihtout_config.h"
+#include "cuda/cuda_runtime.h"
+
+namespace infini {
+class ReduceMeanCudnn : public CudaKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        auto op = as<ReduceMeanObj>(_op);
+        auto input = op->getInputs(0);
+        auto output = op->getOutput();
+        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+
+        // Each dimension of the output tensor C must match the corresponding
+        // dimension of the input tensor A or must be equal to 1. The
+        // dimensions equal to 1 indicate the dimensions of A to be reduced.
+        int nInDims = input->getDims().size();
+        IT_ASSERT(CUDNN_DIM_MAX >= nInDims);
+        int inDimArray[CUDNN_DIM_MAX], outDimArray[CUDNN_DIM_MAX],
+            inStrideArray[CUDNN_DIM_MAX], outStrideArray[CUDNN_DIM_MAX];
+        for (int i = 0; i < nInDims; ++i) {
+            inDimArray[i] = input->getDims()[i];
+            inStrideArray[i] = input->getStride()[i];
+        }
+        Shape d = output->getDims();
+        if (!op->getKeepDims()) {
+            d = input->getDims();
+            for (size_t i = 0; i < d.size(); ++i)
+                if (op->isReduced(i))
+                    d[i] = 1;
+        }
+        int stride = 1;
+        for (int i = nInDims - 1; i >= 0; --i) {
+            outDimArray[i] = d[i];
+            outStrideArray[i] = stride;
+            stride *= d[i];
+        }
+
+        // cudnnSetTensorNdDescriptor is used when nInDims > 3; otherwise, it
+        // is recommended to use cudnnSetTensor4dDescriptor and set the unused
+        // dimension sizes to 1.
+        // create tensor descriptors for input and output
+        cudnnTensorDescriptor_t inDesc;
+        checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
+        cudnnTensorDescriptor_t outDesc;
+        checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
+        if (nInDims > 3) {
+            checkCudnnError(cudnnSetTensorNdDescriptor(
+                inDesc, CUDNN_DATA_FLOAT, nInDims, inDimArray, inStrideArray));
+            checkCudnnError(
+                cudnnSetTensorNdDescriptor(outDesc, CUDNN_DATA_FLOAT, nInDims,
+                                           outDimArray, outStrideArray));
+        } else {
+            int idims[4] = {1, 1, 1, 1}, odims[4] = {1, 1, 1, 1};
+            for (int i = 0; i < nInDims; ++i) {
+                idims[4 - i - 1] = input->getDims()[nInDims - i - 1];
+            }
+            for (int i = 0; i < nInDims; ++i) {
+                odims[4 - i - 1] = d[nInDims - i - 1];
+            }
+
+            checkCudnnError(cudnnSetTensor4dDescriptor(
+                inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, idims[0],
+                idims[1], idims[2], idims[3]));
+            checkCudnnError(cudnnSetTensor4dDescriptor(
+                outDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, odims[0],
+                odims[1], odims[2], odims[3]));
+        }
+
+        // get reduce descriptor
+        cudnnReduceTensorDescriptor_t reduceDesc;
+        checkCudnnError(cudnnCreateReduceTensorDescriptor(&reduceDesc));
+        checkCudnnError(cudnnSetReduceTensorDescriptor(
+            reduceDesc, CUDNN_REDUCE_TENSOR_AVG, CUDNN_DATA_FLOAT,
+            CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES,
+            CUDNN_32BIT_INDICES));
+
+        // get workspace
+        size_t workspaceSize = 0;
+        checkCudnnError(
+            cudnnGetReductionWorkspaceSize(context->cudnnHandle(), reduceDesc,
+                                           inDesc, outDesc, &workspaceSize));
+        CudaPtr wsData = context->getWorkspace(workspaceSize);
+
+        // get index workspace
+        size_t idxWorkspaceSize = 0;
+        checkCudnnError(
+            cudnnGetReductionIndicesSize(context->cudnnHandle(), reduceDesc,
+                                         inDesc, outDesc, &idxWorkspaceSize));
+        CudaPtr idxWsData = context->getWorkspace(idxWorkspaceSize);
+
+        // reduce
+        float alpha = 1.f, beta = 0.f;
+        void *const inData = (input->getRawDataPtr<void *>());
+        void *const outData = (output->getRawDataPtr<void *>());
+        checkCudnnError(cudnnReduceTensor(context->cudnnHandle(), reduceDesc,
+                                          idxWsData, idxWorkspaceSize, wsData,
+                                          workspaceSize, &alpha, inDesc,
+                                          inData, &beta, outDesc, outData));
+
+        // Destroying descriptors in CUDA does not require a sync, but cuDNN
+        // does not state whether a sync is required before destruction.
+        checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
+        checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
+        checkCudnnError(cudnnDestroyReduceTensorDescriptor(reduceDesc));
+    }
+};
+
+REGISTER_KERNEL(Device::CUDA, OpType::ReduceMean, DataType::Float32,
+                ReduceMeanCudnn, "ReduceMean_cuDNN_CUDA_Float32");
+} // namespace infini
diff --git a/src/operators/reduce_mean.cc b/src/operators/reduce_mean.cc
new file mode 100644
index 00000000..3e627102
--- /dev/null
+++ b/src/operators/reduce_mean.cc
@@ -0,0 +1,85 @@
+#include "operators/reduce_mean.h"
+
+namespace infini {
+ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
+                             const optional<vector<int>> &_axis,
+                             bool keepDims)
+    : OperatorObj(OpType::ReduceMean, {input}, {output}), keepDims(keepDims) {
+
+    if (_axis != std::nullopt) {
+        IT_ASSERT((*_axis).size() <= input->getDims().size());
+        for (size_t j = 0; j < (*_axis).size(); ++j) {
+            int idx = (*_axis)[j];
+            if (idx < 0)
+                IT_TODO_HALT();
+            IT_ASSERT((size_t)idx < input->getDims().size());
+            axis.emplace(idx);
+        }
+    } else
+        for (size_t i = 0; i < input->getDims().size(); ++i)
+            axis.emplace(i);
+    IT_ASSERT(checkValid(graph));
+}
+
+bool ReduceMeanObj::isReduced(int idx) const {
+    return axis.find(idx) != axis.end();
+}
+
+optional<vector<Shape>>
+ReduceMeanObj::inferShape(const TensorVec &inputs) const {
+    auto dims = inputs[0]->getDims();
+
+    if (keepDims) {
+        Shape ret = dims;
+        for (auto it : axis)
+            ret[it] = 1;
+        return {{ret}};
+    } else {
+        Shape ret;
+        for (size_t i = 0; i < dims.size(); ++i) {
+            if (!isReduced(i))
+                ret.emplace_back(dims[i]);
+        }
+        if (ret.empty())
+            ret.emplace_back(1);
+        return {{ret}};
+    }
+}
+
+std::string ReduceMeanObj::toString() const {
+    std::ostringstream os;
+    os << "ReduceMean"
+       << "[" << getGuid() << "]";
+    os << "(";
+    os << vecToString(inputs[0]->getDims()) << ",";
+
+    std::string axisstr;
+    axisstr.append("[");
+    for (auto d : axis) {
+        axisstr.append(std::to_string(d));
+        axisstr.append(",");
+    }
+    if (!axis.empty())
+        axisstr.pop_back();
+    axisstr.append("]");
+    os << "axis=" << axisstr << ",";
+    os << "keepDims=" << keepDims << ",";
+    os << "input=" << inputs[0]->getGuid() << ",";
+    os << "output=" << outputs[0]->getGuid() << ")";
+    return os.str();
+}
+
+vector<int> ReduceMeanObj::getWorkloadVector() const {
+    vector<int> ret = inputs[0]->getDims();
+    ret.emplace(ret.begin(), enum_to_underlying(type));
+    ret.emplace_back((int)keepDims);
+    ret.insert(ret.end(), axis.begin(), axis.end());
+    return ret;
+}
+
+vector<int> ReduceMeanObj::getOpAttrVector() const {
+    vector<int> ret = {enum_to_underlying(type), (int)keepDims};
+    ret.insert(ret.end(), axis.begin(), axis.end());
+    return ret;
+}
+} // namespace infini
diff --git a/test/kernels/cuda/test_cuda_pad.cc b/test/kernels/cuda/test_cuda_pad.cc
index c0a96f94..e157114d 100644
--- a/test/kernels/cuda/test_cuda_pad.cc
+++ b/test/kernels/cuda/test_cuda_pad.cc
@@ -31,7 +31,7 @@ TEST(Pad, Cuda) {
     // clone CUDA output to CPU
     auto o = op->getOutput();
     auto cpuo = o->clone(cpuRuntime);
-    // cudaPrintTensor(o);
+
     // check results on CPU
     EXPECT_TRUE(cpuo->equalData(
         vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
diff --git a/test/kernels/cuda/test_cuda_reduce_mean.cc b/test/kernels/cuda/test_cuda_reduce_mean.cc
new file mode 100644
index 00000000..90356994
--- /dev/null
+++ b/test/kernels/cuda/test_cuda_reduce_mean.cc
@@ -0,0 +1,62 @@
+#include "core/graph.h"
+#include "core/kernel.h"
+#include "core/runtime.h"
+#include "cuda/cuda_runtime.h"
+#include "cuda/cuda_utility.h"
+#include "operators/reduce_mean.h"
"operators/reduce_mean.h" + +#include "test.h" + +namespace infini { + +void test_reducemean(const Shape &shape, const vector &data, + const optional> &axis, bool keepDims, + const vector &ExpectData) { + Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + auto cudaRuntime = make_ref(); + + // Build input data on CPU + Tensor icpu = make_ref(shape, DataType::Float32, cpuRuntime); + icpu->dataMalloc(); + icpu->copyData(data); + + // Build CUDA graph + Graph g = make_ref(cudaRuntime); + auto i = g->cloneTensor(icpu); + auto op = g->addOp(i, nullptr, axis, keepDims); + + // allocate CUDA memory + g->dataMalloc(); + + // Execute on CUDA + cudaRuntime->run(g); + + // clone CUDA output to CPU + auto o = op->getOutput(); + auto ocpu = o->clone(cpuRuntime); + + // check results on CPU + EXPECT_TRUE(ocpu->equalData(ExpectData)); +} + +TEST(CUDA_ReduceMean, run) { + test_reducemean(Shape{3, 2, 2}, + vector{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2}, + std::nullopt, true, vector{18.25}); + test_reducemean(Shape{1, 3, 2, 2, 1}, + vector{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2}, + std::nullopt, false, vector{18.25}); + + test_reducemean(Shape{2, 3, 2, 2}, + vector{0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23}, + vector{1, 2}, false, vector{5, 6, 17, 18}); + test_reducemean(Shape{2, 3, 2, 2, 1}, + vector{0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23}, + vector{1, 2}, true, vector{5, 6, 17, 18}); +} + +} // namespace infini diff --git a/test/operators/test_reduce_mean.cc b/test/operators/test_reduce_mean.cc new file mode 100644 index 00000000..c6f0784a --- /dev/null +++ b/test/operators/test_reduce_mean.cc @@ -0,0 +1,38 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/reduce_mean.h" + +#include "test.h" + +namespace infini { + +TEST(ReduceMean, ShapeInference) { + Runtime runtime = CpuRuntimeObj::getInstance(); + { + Graph g = make_ref(runtime); + Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32); + auto op = g->addOp(i, nullptr, std::nullopt, true); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 1, 1})); + } + { + Graph g = make_ref(runtime); + Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32); + auto op = g->addOp(i, nullptr, vector{1, 3}, true); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 3, 1})); + } + { + Graph g = make_ref(runtime); + Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32); + auto op = g->addOp(i, nullptr, std::nullopt, false); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1})); + } + { + Graph g = make_ref(runtime); + Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32); + auto op = g->addOp(i, nullptr, vector{1, 3}, false); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3})); + } +} + +} // namespace infini