forked from jiuyuan/InfiniTensor

ADD:reduce_mean operator and cuda kernel. (#47)

add new line at file ending.

commit d1c913010f (parent a4d6426589)
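In short, the commit adds an operator class ReduceMeanObj, a cuDNN-backed CUDA kernel ReduceMeanCudnn, shape-inference and CUDA tests, and switches the element-wise and pooling kernels from manual cudnnStatus_t handling to checkCudnnError. Before the diff itself, here is a minimal sketch of the semantics the new operator implements: the mean is taken over a chosen set of axes, and the reduced axes collapse to extent 1. This is illustration only, not repo code; the reduceMean helper name is hypothetical.

// Illustration only (not repo code): reduce-mean over a set of axes on a
// row-major tensor, mirroring the semantics ReduceMeanObj is meant to provide.
#include <iostream>
#include <set>
#include <vector>

std::vector<float> reduceMean(const std::vector<float> &data,
                              const std::vector<int> &dims,
                              const std::set<int> &axes) {
    std::vector<int> outDims(dims);
    long long count = 1;
    for (int a : axes) {
        count *= dims[a];
        outDims[a] = 1; // reduced axes collapse to extent 1
    }
    long long outSize = 1;
    for (int d : outDims)
        outSize *= d;
    std::vector<float> out(outSize, 0.f);
    for (size_t idx = 0; idx < data.size(); ++idx) {
        // Recover each coordinate of idx, zero it if its axis is reduced,
        // then re-linearize against the collapsed output shape.
        size_t outIdx = 0, stride = data.size();
        for (size_t i = 0; i < dims.size(); ++i) {
            stride /= dims[i];
            size_t coord = (idx / stride) % dims[i];
            if (axes.count((int)i))
                coord = 0;
            outIdx = outIdx * outDims[i] + coord;
        }
        out[outIdx] += data[idx];
    }
    for (float &v : out)
        v /= (float)count;
    return out;
}

int main() {
    // Mean over all axes of a 3x2x2 tensor, the first case in the CUDA test:
    // (5+1+20+2+30+1+40+2+55+1+60+2) / 12 = 18.25
    auto r = reduceMean({5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2},
                        {3, 2, 2}, {0, 1, 2});
    std::cout << r[0] << "\n";
}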
@@ -0,0 +1,27 @@
#pragma once
#include "core/operator.h"

namespace infini {
class ReduceMeanObj : public OperatorObj {
    set<int> axis; // axis to reduce
    bool keepDims;

  public:
    ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
                  const optional<const vector<int>> &axis,
                  bool keepDims = true);
    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

    std::string toString() const override;
    int numInputs() const override { return 1; }
    int numOutputs() const override { return 1; }

    bool isReduced(int idx) const;
    bool getKeepDims() const { return keepDims; }

  private:
    vector<int> getWorkloadVector() const override;
    vector<int> getOpAttrVector() const override;
};

} // namespace infini
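In the rest of the commit the operator is always built through the graph rather than constructed directly; the shape-inference test added below does, for example:

Graph g = make_ref<GraphObj>(runtime);
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
// reduce over axes 1 and 3, keeping them as size-1 dimensions
auto op = g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{1, 3}, true);
// op->getOutput()->getDims() == Shape{2, 1, 3, 1}

Passing std::nullopt as the axis argument reduces over every dimension; negative axes are not handled yet and hit IT_TODO_HALT in the constructor.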
@@ -45,11 +45,10 @@ class ElementWiseCudnn : public CudaKernelWithoutConfig {
             opDesc, getOpType(), CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN));

         auto [aAlpha, bAlpha, beta] = getAlphBeta();
-        cudnnStatus_t stat =
-            cudnnOpTensor(context->cudnnHandle(), opDesc, &aAlpha, aDesc, aData,
-                          &bAlpha, bDesc, bData, &beta, cDesc, cData);
-        if (stat != CUDNN_STATUS_SUCCESS)
-            return;
+        checkCudnnError(cudnnOpTensor(context->cudnnHandle(), opDesc, &aAlpha,
+                                      aDesc, aData, &bAlpha, bDesc, bData,
+                                      &beta, cDesc, cData));

         // Destories in CUDA does not require sync. But cuDNN does not state
         // whether sync is required before destories.
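Both this hunk and the pooling hunks below replace a manually checked cudnnStatus_t, which silently returned on failure, with checkCudnnError, so a failing cuDNN call is reported instead of being ignored. The macro itself is not part of this diff; a typical shape for such a checker, assuming it aborts on error, is:

// Sketch of a cuDNN status checker (assumption; the repo's checkCudnnError
// may report errors differently).
#include <cstdio>
#include <cstdlib>
#include <cudnn.h>

#define CHECK_CUDNN(call)                                                      \
    do {                                                                       \
        cudnnStatus_t s = (call);                                              \
        if (s != CUDNN_STATUS_SUCCESS) {                                       \
            std::fprintf(stderr, "cuDNN error %s at %s:%d\n",                  \
                         cudnnGetErrorString(s), __FILE__, __LINE__);          \
            std::abort();                                                      \
        }                                                                      \
    } while (0)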
@@ -9,7 +9,6 @@ class poolingCudnn : public CudaKernelWithoutConfig {
                  const RuntimeObj *_context) const override {
         auto op = as<PoolingObj>(_op);
         auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
-        cudnnStatus_t stat;
         void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const outData = (op->getOutput()->getRawDataPtr<void *>());

@@ -43,10 +42,9 @@ class poolingCudnn : public CudaKernelWithoutConfig {
                   "cuDNN output shape mismatches with OP output shape");

         float alpha = 1.f, beta = 0.f;
-        stat = cudnnPoolingForward(context->cudnnHandle(), poolingDesc, &alpha,
-                                   inDesc, inData, &beta, outDesc, outData);
-        if (stat != CUDNN_STATUS_SUCCESS)
-            return;
+        checkCudnnError(cudnnPoolingForward(context->cudnnHandle(), poolingDesc,
+                                            &alpha, inDesc, inData, &beta,
+                                            outDesc, outData));

         // Destories in CUDA does not require sync. But cuDNN does not state
         // whether sync is required before destories.
@@ -0,0 +1,111 @@
#include "operators/reduce_mean.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"

namespace infini {
class ReduceMeanCudnn : public CudaKernelWithoutConfig {
    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        auto op = as<ReduceMeanObj>(_op);
        auto input = op->getInputs(0);
        auto output = op->getOutput();
        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);

        // Each dimension of the output tensor C must match the corresponding
        // dimension of the input tensor A or must be equal to 1. The dimensions
        // equal to 1 indicate the dimensions of A to be reduced.
        int nInDims = input->getDims().size();
        IT_ASSERT(CUDNN_DIM_MAX >= nInDims);
        int inDimArray[CUDNN_DIM_MAX], outDimArray[CUDNN_DIM_MAX],
            inStrideArray[CUDNN_DIM_MAX], outStrideArray[CUDNN_DIM_MAX];
        for (int i = 0; i < nInDims; ++i) {
            inDimArray[i] = input->getDims()[i];
            inStrideArray[i] = input->getStride()[i];
        }
        Shape d = output->getDims();
        if (!op->getKeepDims()) {
            d = input->getDims();
            for (size_t i = 0; i < d.size(); ++i)
                if (op->isReduced(i))
                    d[i] = 1;
        }
        int stride = 1;
        for (int i = nInDims - 1; i >= 0; --i) {
            outDimArray[i] = d[i];
            outStrideArray[i] = stride;
            stride *= d[i];
        }

        // cudnnSetTensorNdDescriptor is used when nDim>3, otherwise,it is
        // recomended to use cudnnSetTensor4dDescriptor and set the unused
        // dimension size to 1.
        // get inputs outputs
        cudnnTensorDescriptor_t inDesc;
        checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
        cudnnTensorDescriptor_t outDesc;
        checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
        if (nInDims > 3) {
            checkCudnnError(cudnnSetTensorNdDescriptor(
                inDesc, CUDNN_DATA_FLOAT, nInDims, inDimArray, inStrideArray));
            checkCudnnError(
                cudnnSetTensorNdDescriptor(outDesc, CUDNN_DATA_FLOAT, nInDims,
                                           outDimArray, outStrideArray));
        } else {
            int idims[4] = {1, 1, 1, 1}, odims[4] = {1, 1, 1, 1};
            for (int i = 0; i < nInDims; ++i) {
                idims[4 - i - 1] = input->getDims()[nInDims - i - 1];
            }
            for (int i = 0; i < nInDims; ++i) {
                odims[4 - i - 1] = d[nInDims - i - 1];
            }

            checkCudnnError(cudnnSetTensor4dDescriptor(
                inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, idims[0], idims[1],
                idims[2], idims[3]));
            checkCudnnError(cudnnSetTensor4dDescriptor(
                outDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, odims[0],
                odims[1], odims[2], odims[3]));
        }

        // get reduce descriptor
        cudnnReduceTensorDescriptor_t reduceDesc;
        checkCudnnError(cudnnCreateReduceTensorDescriptor(&reduceDesc));
        checkCudnnError(cudnnSetReduceTensorDescriptor(
            reduceDesc, CUDNN_REDUCE_TENSOR_AVG, CUDNN_DATA_FLOAT,
            CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES,
            CUDNN_32BIT_INDICES));

        // get workspace
        size_t workspaceSize = 0;
        checkCudnnError(
            cudnnGetReductionWorkspaceSize(context->cudnnHandle(), reduceDesc,
                                           inDesc, outDesc, &workspaceSize));
        CudaPtr wsData = context->getWorkspace(workspaceSize);

        // get index workspace
        size_t idxWorkspaceSize = 0;
        checkCudnnError(
            cudnnGetReductionIndicesSize(context->cudnnHandle(), reduceDesc,
                                         inDesc, outDesc, &idxWorkspaceSize));
        CudaPtr idxWsData = context->getWorkspace(idxWorkspaceSize);

        // reduce
        float alpha = 1.f, beta = 0.f;
        void *const inData = (input->getRawDataPtr<void *>());
        void *const outData = (output->getRawDataPtr<void *>());
        checkCudnnError(cudnnReduceTensor(context->cudnnHandle(), reduceDesc,
                                          idxWsData, idxWorkspaceSize, wsData,
                                          workspaceSize, &alpha, inDesc, inData,
                                          &beta, outDesc, outData));

        // Destories in CUDA does not require sync. But cuDNN does not state
        // whether sync is required before destories.
        checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
        checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
        checkCudnnError(cudnnDestroyReduceTensorDescriptor(reduceDesc));
    }
};

REGISTER_KERNEL(Device::CUDA, OpType::ReduceMean, DataType::Float32,
                ReduceMeanCudnn, "ReduceMean_cuDNN_CUDA_Float32");
}; // namespace infini
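A note on how this kernel drives cuDNN: cudnnReduceTensor decides which axes to reduce from the shape of the output descriptor, reducing exactly the dimensions whose output extent is 1, and CUDNN_REDUCE_TENSOR_AVG makes the reduction a mean. That is why the keepDims == false branch rebuilds a keep-dims shape d before filling outDimArray, and why tensors of rank 3 or less are right-aligned into a 4d NCHW descriptor padded with leading 1s. A small illustration of the descriptor-shape rule (hypothetical helper, not repo code):

// The cuDNN output descriptor always keeps reduced axes with extent 1,
// even when the operator itself will drop them (keepDims == false).
#include <set>
#include <vector>

std::vector<int> cudnnOutputDims(const std::vector<int> &inDims,
                                 const std::set<int> &reducedAxes) {
    std::vector<int> out(inDims);
    for (int a : reducedAxes)
        out[a] = 1;
    return out;
}
// Example: input {2, 3, 2, 2} reduced over axes {1, 2} gives descriptor dims
// {2, 1, 1, 2}; the operator later reports the shape as {2, 2}.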
@@ -0,0 +1,85 @@
#include "operators/reduce_mean.h"

namespace infini {
ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
                             const optional<const vector<int>> &_axis,
                             bool keepDims)
    : OperatorObj(OpType::ReduceMean, {input}, {output}), keepDims(keepDims) {

    if (_axis != std::nullopt) {
        IT_ASSERT((*_axis).size() <= input->getDims().size());
        for (size_t j = 0; j < (*_axis).size(); ++j) {
            int idx = (*_axis)[j];
            if (idx < 0)
                IT_TODO_HALT();
            IT_ASSERT((size_t)idx < input->getDims().size());
            axis.emplace(idx);
        }
    } else
        for (size_t i = 0; i < input->getDims().size(); ++i)
            axis.emplace(i);
    IT_ASSERT(checkValid(graph));
}

bool ReduceMeanObj::isReduced(int idx) const {
    return axis.find(idx) != axis.end();
}

optional<vector<Shape>>
ReduceMeanObj::inferShape(const TensorVec &inputs) const {
    auto dims = inputs[0]->getDims();

    if (keepDims) {
        Shape ret = dims;
        for (auto it : axis)
            ret[it] = 1;
        return {{ret}};
    } else {
        Shape ret;
        for (size_t i = 0; i < dims.size(); ++i) {
            if (!isReduced(i))
                ret.emplace_back(dims[i]);
        }
        if (ret.size() == (size_t)0)
            ret.emplace_back(1);
        return {{ret}};
    }
}

std::string ReduceMeanObj::toString() const {
    std::ostringstream os;
    os << "ReduceMean"
       << "[" << getGuid() << "]";
    os << "(";
    os << vecToString(inputs[0]->getDims()) << ",";

    std::string axisstr;
    axisstr.append("[");
    for (auto d : axis) {
        axisstr.append(std::to_string(d));
        axisstr.append(",");
    }
    if (!axis.empty())
        axisstr.pop_back();
    axisstr.append("]");
    os << "axis=" << axisstr << ",";
    os << "keepDims=" << keepDims << ",";
    os << "input=" << inputs[0]->getGuid() << ",";
    os << "output=" << outputs[0]->getGuid() << ")";
    return os.str();
}

vector<int> ReduceMeanObj::getWorkloadVector() const {
    vector<int> ret = inputs[0]->getDims();
    ret.emplace(ret.begin(), enum_to_underlying(type));
    ret.emplace_back((int)keepDims);
    ret.insert(ret.end(), axis.begin(), axis.end());
    return ret;
}

vector<int> ReduceMeanObj::getOpAttrVector() const {
    vector<int> ret = {enum_to_underlying(type), (int)keepDims};
    ret.insert(ret.end(), axis.begin(), axis.end());
    return ret;
}
} // namespace infini
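The shape rule in inferShape is small but worth restating: with keepDims the reduced entries are overwritten with 1, without it they are dropped, and reducing every axis yields {1} rather than an empty shape. A standalone restatement (illustration only, hypothetical names):

#include <set>
#include <vector>
using Shape = std::vector<int>;

Shape inferReduceMeanShape(const Shape &dims, const std::set<int> &axis,
                           bool keepDims) {
    Shape ret;
    for (size_t i = 0; i < dims.size(); ++i) {
        bool reduced = axis.count((int)i) != 0;
        if (keepDims)
            ret.push_back(reduced ? 1 : dims[i]); // reduced axes become 1
        else if (!reduced)
            ret.push_back(dims[i]); // reduced axes are dropped
    }
    if (ret.empty())
        ret.push_back(1); // reducing every axis leaves a rank-1 {1}
    return ret;
}
// {2, 3, 3, 4} with axes {1, 3}: keepDims -> {2, 1, 3, 1}, otherwise -> {2, 3};
// reducing all axes: keepDims -> {1, 1, 1, 1}, otherwise -> {1}.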
@@ -31,7 +31,7 @@ TEST(Pad, Cuda) {
     // clone CUDA output to CPU
     auto o = op->getOutput();
     auto cpuo = o->clone(cpuRuntime);
-    // cudaPrintTensor(o);
     // check results on CPU
     EXPECT_TRUE(cpuo->equalData(
         vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
@@ -0,0 +1,62 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/reduce_mean.h"

#include "test.h"

namespace infini {

void test_reducemean(const Shape &shape, const vector<float> &data,
                     const optional<const vector<int>> &axis, bool keepDims,
                     const vector<float> &ExpectData) {
    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
    auto cudaRuntime = make_ref<CudaRuntimeObj>();

    // Build input data on CPU
    Tensor icpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
    icpu->dataMalloc();
    icpu->copyData(data);

    // Build CUDA graph
    Graph g = make_ref<GraphObj>(cudaRuntime);
    auto i = g->cloneTensor(icpu);
    auto op = g->addOp<ReduceMeanObj>(i, nullptr, axis, keepDims);

    // allocate CUDA memory
    g->dataMalloc();

    // Execute on CUDA
    cudaRuntime->run(g);

    // clone CUDA output to CPU
    auto o = op->getOutput();
    auto ocpu = o->clone(cpuRuntime);

    // check results on CPU
    EXPECT_TRUE(ocpu->equalData(ExpectData));
}

TEST(CUDA_ReduceMean, run) {
    test_reducemean(Shape{3, 2, 2},
                    vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2},
                    std::nullopt, true, vector<float>{18.25});
    test_reducemean(Shape{1, 3, 2, 2, 1},
                    vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2},
                    std::nullopt, false, vector<float>{18.25});

    test_reducemean(Shape{2, 3, 2, 2},
                    vector<float>{0,  1,  2,  3,  4,  5,  6,  7,
                                  8,  9,  10, 11, 12, 13, 14, 15,
                                  16, 17, 18, 19, 20, 21, 22, 23},
                    vector<int>{1, 2}, false, vector<float>{5, 6, 17, 18});
    test_reducemean(Shape{2, 3, 2, 2, 1},
                    vector<float>{0,  1,  2,  3,  4,  5,  6,  7,
                                  8,  9,  10, 11, 12, 13, 14, 15,
                                  16, 17, 18, 19, 20, 21, 22, 23},
                    vector<int>{1, 2}, true, vector<float>{5, 6, 17, 18});
}

} // namespace infini
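The expected values in these tests are easy to check by hand. The first two cases average all twelve inputs: (5 + 1 + 20 + 2 + 30 + 1 + 40 + 2 + 55 + 1 + 60 + 2) / 12 = 219 / 12 = 18.25. The {2, 3, 2, 2} cases reduce over axes 1 and 2, so each output element is the mean of the six values sharing a batch index and a last-axis index: (0 + 2 + 4 + 6 + 8 + 10) / 6 = 5 and (1 + 3 + 5 + 7 + 9 + 11) / 6 = 6 for the first batch, and likewise 17 and 18 for the second.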
@@ -0,0 +1,38 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/reduce_mean.h"

#include "test.h"

namespace infini {

TEST(ReduceMean, ShapeInference) {
    Runtime runtime = CpuRuntimeObj::getInstance();
    {
        Graph g = make_ref<GraphObj>(runtime);
        Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
        auto op = g->addOp<ReduceMeanObj>(i, nullptr, std::nullopt, true);
        EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 1, 1, 1}));
    }
    {
        Graph g = make_ref<GraphObj>(runtime);
        Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
        auto op = g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{1, 3}, true);
        EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 3, 1}));
    }
    {
        Graph g = make_ref<GraphObj>(runtime);
        Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
        auto op = g->addOp<ReduceMeanObj>(i, nullptr, std::nullopt, false);
        EXPECT_EQ(op->getOutput()->getDims(), (Shape{1}));
    }
    {
        Graph g = make_ref<GraphObj>(runtime);
        Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
        auto op = g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{1, 3}, false);
        EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3}));
    }
}

} // namespace infini