ADD:reduce_mean operator and cuda kernel. (#47)

add new line at file ending.
This commit is contained in:
wendy12022 2022-10-15 16:53:58 +08:00 committed by GitHub
parent a4d6426589
commit d1c913010f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 331 additions and 11 deletions

View File

@ -0,0 +1,27 @@
#pragma once
#include "core/operator.h"
namespace infini {
class ReduceMeanObj : public OperatorObj {
set<int> axis; // axis to reduce
bool keepDims;
public:
ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
const optional<const vector<int>> &axis,
bool keepDims = true);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;
int numInputs() const override { return 1; }
int numOutputs() const override { return 1; }
bool isReduced(int idx) const;
bool getKeepDims() const { return keepDims; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
} // namespace infini

View File

@ -45,11 +45,10 @@ class ElementWiseCudnn : public CudaKernelWithoutConfig {
opDesc, getOpType(), CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN)); opDesc, getOpType(), CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN));
auto [aAlpha, bAlpha, beta] = getAlphBeta(); auto [aAlpha, bAlpha, beta] = getAlphBeta();
cudnnStatus_t stat =
cudnnOpTensor(context->cudnnHandle(), opDesc, &aAlpha, aDesc, aData, checkCudnnError(cudnnOpTensor(context->cudnnHandle(), opDesc, &aAlpha,
&bAlpha, bDesc, bData, &beta, cDesc, cData); aDesc, aData, &bAlpha, bDesc, bData,
if (stat != CUDNN_STATUS_SUCCESS) &beta, cDesc, cData));
return;
// Destories in CUDA does not require sync. But cuDNN does not state // Destories in CUDA does not require sync. But cuDNN does not state
// whether sync is required before destories. // whether sync is required before destories.

View File

@ -9,7 +9,6 @@ class poolingCudnn : public CudaKernelWithoutConfig {
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<PoolingObj>(_op); auto op = as<PoolingObj>(_op);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context); auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
cudnnStatus_t stat;
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>()); void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const outData = (op->getOutput()->getRawDataPtr<void *>()); void *const outData = (op->getOutput()->getRawDataPtr<void *>());
@ -43,10 +42,9 @@ class poolingCudnn : public CudaKernelWithoutConfig {
"cuDNN output shape mismatches with OP output shape"); "cuDNN output shape mismatches with OP output shape");
float alpha = 1.f, beta = 0.f; float alpha = 1.f, beta = 0.f;
stat = cudnnPoolingForward(context->cudnnHandle(), poolingDesc, &alpha, checkCudnnError(cudnnPoolingForward(context->cudnnHandle(), poolingDesc,
inDesc, inData, &beta, outDesc, outData); &alpha, inDesc, inData, &beta,
if (stat != CUDNN_STATUS_SUCCESS) outDesc, outData));
return;
// Destories in CUDA does not require sync. But cuDNN does not state // Destories in CUDA does not require sync. But cuDNN does not state
// whether sync is required before destories. // whether sync is required before destories.

View File

@ -0,0 +1,111 @@
#include "operators/reduce_mean.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
namespace infini {
class ReduceMeanCudnn : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<ReduceMeanObj>(_op);
auto input = op->getInputs(0);
auto output = op->getOutput();
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
// Each dimension of the output tensor C must match the corresponding
// dimension of the input tensor A or must be equal to 1. The dimensions
// equal to 1 indicate the dimensions of A to be reduced.
int nInDims = input->getDims().size();
IT_ASSERT(CUDNN_DIM_MAX >= nInDims);
int inDimArray[CUDNN_DIM_MAX], outDimArray[CUDNN_DIM_MAX],
inStrideArray[CUDNN_DIM_MAX], outStrideArray[CUDNN_DIM_MAX];
for (int i = 0; i < nInDims; ++i) {
inDimArray[i] = input->getDims()[i];
inStrideArray[i] = input->getStride()[i];
}
Shape d = output->getDims();
if (!op->getKeepDims()) {
d = input->getDims();
for (size_t i = 0; i < d.size(); ++i)
if (op->isReduced(i))
d[i] = 1;
}
int stride = 1;
for (int i = nInDims - 1; i >= 0; --i) {
outDimArray[i] = d[i];
outStrideArray[i] = stride;
stride *= d[i];
}
// cudnnSetTensorNdDescriptor is used when nDim>3, otherwise,it is
// recomended to use cudnnSetTensor4dDescriptor and set the unused
// dimension size to 1.
// get inputs outputs
cudnnTensorDescriptor_t inDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
cudnnTensorDescriptor_t outDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
if (nInDims > 3) {
checkCudnnError(cudnnSetTensorNdDescriptor(
inDesc, CUDNN_DATA_FLOAT, nInDims, inDimArray, inStrideArray));
checkCudnnError(
cudnnSetTensorNdDescriptor(outDesc, CUDNN_DATA_FLOAT, nInDims,
outDimArray, outStrideArray));
} else {
int idims[4] = {1, 1, 1, 1}, odims[4] = {1, 1, 1, 1};
for (int i = 0; i < nInDims; ++i) {
idims[4 - i - 1] = input->getDims()[nInDims - i - 1];
}
for (int i = 0; i < nInDims; ++i) {
odims[4 - i - 1] = d[nInDims - i - 1];
}
checkCudnnError(cudnnSetTensor4dDescriptor(
inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, idims[0], idims[1],
idims[2], idims[3]));
checkCudnnError(cudnnSetTensor4dDescriptor(
outDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, odims[0],
odims[1], odims[2], odims[3]));
}
// get reduce descriptor
cudnnReduceTensorDescriptor_t reduceDesc;
checkCudnnError(cudnnCreateReduceTensorDescriptor(&reduceDesc));
checkCudnnError(cudnnSetReduceTensorDescriptor(
reduceDesc, CUDNN_REDUCE_TENSOR_AVG, CUDNN_DATA_FLOAT,
CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES,
CUDNN_32BIT_INDICES));
// get workspace
size_t workspaceSize = 0;
checkCudnnError(
cudnnGetReductionWorkspaceSize(context->cudnnHandle(), reduceDesc,
inDesc, outDesc, &workspaceSize));
CudaPtr wsData = context->getWorkspace(workspaceSize);
// get index workspace
size_t idxWorkspaceSize = 0;
checkCudnnError(
cudnnGetReductionIndicesSize(context->cudnnHandle(), reduceDesc,
inDesc, outDesc, &idxWorkspaceSize));
CudaPtr idxWsData = context->getWorkspace(idxWorkspaceSize);
// reduce
float alpha = 1.f, beta = 0.f;
void *const inData = (input->getRawDataPtr<void *>());
void *const outData = (output->getRawDataPtr<void *>());
checkCudnnError(cudnnReduceTensor(context->cudnnHandle(), reduceDesc,
idxWsData, idxWorkspaceSize, wsData,
workspaceSize, &alpha, inDesc, inData,
&beta, outDesc, outData));
// Destories in CUDA does not require sync. But cuDNN does not state
// whether sync is required before destories.
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
checkCudnnError(cudnnDestroyReduceTensorDescriptor(reduceDesc));
}
};
REGISTER_KERNEL(Device::CUDA, OpType::ReduceMean, DataType::Float32,
ReduceMeanCudnn, "ReduceMean_cuDNN_CUDA_Float32");
}; // namespace infini

View File

@ -0,0 +1,85 @@
#include "operators/reduce_mean.h"
namespace infini {
ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
const optional<const vector<int>> &_axis,
bool keepDims)
: OperatorObj(OpType::ReduceMean, {input}, {output}), keepDims(keepDims) {
if (_axis != std::nullopt) {
IT_ASSERT((*_axis).size() <= input->getDims().size());
for (size_t j = 0; j < (*_axis).size(); ++j) {
int idx = (*_axis)[j];
if (idx < 0)
IT_TODO_HALT();
IT_ASSERT((size_t)idx < input->getDims().size());
axis.emplace(idx);
}
} else
for (size_t i = 0; i < input->getDims().size(); ++i)
axis.emplace(i);
IT_ASSERT(checkValid(graph));
}
bool ReduceMeanObj::isReduced(int idx) const {
return axis.find(idx) != axis.end();
}
optional<vector<Shape>>
ReduceMeanObj::inferShape(const TensorVec &inputs) const {
auto dims = inputs[0]->getDims();
if (keepDims) {
Shape ret = dims;
for (auto it : axis)
ret[it] = 1;
return {{ret}};
} else {
Shape ret;
for (size_t i = 0; i < dims.size(); ++i) {
if (!isReduced(i))
ret.emplace_back(dims[i]);
}
if (ret.size() == (size_t)0)
ret.emplace_back(1);
return {{ret}};
}
}
std::string ReduceMeanObj::toString() const {
std::ostringstream os;
os << "ReduceMean"
<< "[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
std::string axisstr;
axisstr.append("[");
for (auto d : axis) {
axisstr.append(std::to_string(d));
axisstr.append(",");
}
if (!axis.empty())
axisstr.pop_back();
axisstr.append("]");
os << "axis=" << axisstr << ",";
os << "keepDims=" << keepDims << ",";
os << "input=" << inputs[0]->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
vector<int> ReduceMeanObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims();
ret.emplace(ret.begin(), enum_to_underlying(type));
ret.emplace_back((int)keepDims);
ret.insert(ret.end(), axis.begin(), axis.end());
return ret;
}
vector<int> ReduceMeanObj::getOpAttrVector() const {
vector<int> ret = {enum_to_underlying(type), (int)keepDims};
ret.insert(ret.end(), axis.begin(), axis.end());
return ret;
}
} // namespace infini

View File

@ -31,7 +31,7 @@ TEST(Pad, Cuda) {
// clone CUDA output to CPU // clone CUDA output to CPU
auto o = op->getOutput(); auto o = op->getOutput();
auto cpuo = o->clone(cpuRuntime); auto cpuo = o->clone(cpuRuntime);
// cudaPrintTensor(o);
// check results on CPU // check results on CPU
EXPECT_TRUE(cpuo->equalData( EXPECT_TRUE(cpuo->equalData(
vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

View File

@ -0,0 +1,62 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/reduce_mean.h"
#include "test.h"
namespace infini {
void test_reducemean(const Shape &shape, const vector<float> &data,
const optional<const vector<int>> &axis, bool keepDims,
const vector<float> &ExpectData) {
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
auto cudaRuntime = make_ref<CudaRuntimeObj>();
// Build input data on CPU
Tensor icpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
icpu->dataMalloc();
icpu->copyData(data);
// Build CUDA graph
Graph g = make_ref<GraphObj>(cudaRuntime);
auto i = g->cloneTensor(icpu);
auto op = g->addOp<ReduceMeanObj>(i, nullptr, axis, keepDims);
// allocate CUDA memory
g->dataMalloc();
// Execute on CUDA
cudaRuntime->run(g);
// clone CUDA output to CPU
auto o = op->getOutput();
auto ocpu = o->clone(cpuRuntime);
// check results on CPU
EXPECT_TRUE(ocpu->equalData(ExpectData));
}
TEST(CUDA_ReduceMean, run) {
test_reducemean(Shape{3, 2, 2},
vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2},
std::nullopt, true, vector<float>{18.25});
test_reducemean(Shape{1, 3, 2, 2, 1},
vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2},
std::nullopt, false, vector<float>{18.25});
test_reducemean(Shape{2, 3, 2, 2},
vector<float>{0, 1, 2, 3, 4, 5, 6, 7,
8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23},
vector<int>{1, 2}, false, vector<float>{5, 6, 17, 18});
test_reducemean(Shape{2, 3, 2, 2, 1},
vector<float>{0, 1, 2, 3, 4, 5, 6, 7,
8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23},
vector<int>{1, 2}, true, vector<float>{5, 6, 17, 18});
}
} // namespace infini

View File

@ -0,0 +1,38 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/reduce_mean.h"
#include "test.h"
namespace infini {
TEST(ReduceMean, ShapeInference) {
Runtime runtime = CpuRuntimeObj::getInstance();
{
Graph g = make_ref<GraphObj>(runtime);
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
auto op = g->addOp<ReduceMeanObj>(i, nullptr, std::nullopt, true);
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 1, 1}));
}
{
Graph g = make_ref<GraphObj>(runtime);
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
auto op = g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{1, 3}, true);
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 3, 1}));
}
{
Graph g = make_ref<GraphObj>(runtime);
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
auto op = g->addOp<ReduceMeanObj>(i, nullptr, std::nullopt, false);
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1}));
}
{
Graph g = make_ref<GraphObj>(runtime);
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
auto op = g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{1, 3}, false);
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3}));
}
}
} // namespace infini