forked from jiuyuan/InfiniTensor
Add maxpool and avgpool operators (#17)
* ADD:maxpool&&avgpool operators. add OperatorObj::getDType() clang format FIX:timeit API has changed. * Fix: Tensor::getInputs is const method * Chore Co-authored-by: Liyan Zheng <liyan-zheng@outlook.com>
This commit is contained in:
parent
bd63f738dc
commit
48293576c0
|
@ -14,7 +14,7 @@ set(DEFAULT_BUILD_TYPE "RelWithDebInfo")
|
||||||
|
|
||||||
set(CMAKE_CXX_STANDARD 17)
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
set(CMAKE_CXX_EXTENSIONS OFF) # -std=gnu++11 when on, -std=c++11 when off
|
set(CMAKE_CXX_EXTENSIONS OFF) # -std=gnu++11 when on, -std=c++11 when off
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror -Wno-error=deprecated-declarations")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror -Wno-error=deprecated-declarations")
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
|
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
|
||||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion
|
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion
|
||||||
|
|
||||||
|
|
|
@ -64,7 +64,9 @@ class KernelRegistry {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
Kernel *getKernel(const KernelAttrs &kernelAttrs) const {
|
Kernel *getKernel(const KernelAttrs &kernelAttrs) const {
|
||||||
return std::get<0>(kernels.at(kernelAttrs));
|
auto it = kernels.find(kernelAttrs);
|
||||||
|
IT_ASSERT(it != kernels.end(), "Kernel not found.");
|
||||||
|
return std::get<0>(it->second);
|
||||||
}
|
}
|
||||||
const KernelRecord &getKernelItem(const KernelAttrs &kernelAttrs) const {
|
const KernelRecord &getKernelItem(const KernelAttrs &kernelAttrs) const {
|
||||||
return kernels.at(kernelAttrs);
|
return kernels.at(kernelAttrs);
|
||||||
|
|
|
@ -169,13 +169,14 @@ class OperatorObj : public Object {
|
||||||
const TensorVec &getInputs() const { return inputs; }
|
const TensorVec &getInputs() const { return inputs; }
|
||||||
// TensorVec getOutputs() { return outputs; }
|
// TensorVec getOutputs() { return outputs; }
|
||||||
const TensorVec &getOutputs() const { return outputs; }
|
const TensorVec &getOutputs() const { return outputs; }
|
||||||
Tensor getInputs(size_t i) { return inputs.at(i); }
|
Tensor getInputs(size_t i) const { return inputs.at(i); }
|
||||||
Tensor getOutput() const {
|
Tensor getOutput() const {
|
||||||
IT_ASSERT(outputs.size() == 1, "Unimplemented");
|
IT_ASSERT(outputs.size() == 1, "Unimplemented");
|
||||||
return outputs[0];
|
return outputs[0];
|
||||||
}
|
}
|
||||||
OpType getOpType() const { return type; }
|
OpType getOpType() const { return type; }
|
||||||
|
// HACK: set correct data type
|
||||||
|
DataType getDType() const { return getInputs(0)->getDType(); }
|
||||||
virtual int numInputs() const = 0;
|
virtual int numInputs() const = 0;
|
||||||
virtual int numOutputs() const = 0;
|
virtual int numOutputs() const = 0;
|
||||||
|
|
||||||
|
|
|
@ -49,8 +49,18 @@ class TensorObj : public TensorBaseObj {
|
||||||
void copyData(const Tensor &src) { copyData(src.get()); }
|
void copyData(const Tensor &src) { copyData(src.get()); }
|
||||||
void setData(
|
void setData(
|
||||||
const std::function<void(void *, size_t, DataType)> &generator) const {
|
const std::function<void(void *, size_t, DataType)> &generator) const {
|
||||||
|
IT_ASSERT(data != nullptr);
|
||||||
|
if (!runtime->isCpu()) {
|
||||||
|
IT_TODO_HALT();
|
||||||
|
}
|
||||||
generator(data->getPtr<void *>(), size(), dtype);
|
generator(data->getPtr<void *>(), size(), dtype);
|
||||||
}
|
}
|
||||||
|
// Builds a new tensor bound to `runtime` with this tensor's shape and dtype,
// allocates its storage via dataMalloc(), fills it through copyData(this),
// and returns the new tensor.
Tensor clone(Runtime runtime) {
    Tensor duplicate = make_ref<TensorObj>(shape, dtype, runtime);
    duplicate->dataMalloc();
    duplicate->copyData(this);
    return duplicate;
}
|
||||||
|
|
||||||
void printData() const;
|
void printData() const;
|
||||||
bool equalData(const Tensor &rhs) const;
|
bool equalData(const Tensor &rhs) const;
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/operator.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
class PoolingObj : public OperatorObj {
|
||||||
|
private:
|
||||||
|
int kh, kw;
|
||||||
|
int dh, dw;
|
||||||
|
int ph, pw;
|
||||||
|
int sh, sw;
|
||||||
|
int n, c, h, w;
|
||||||
|
|
||||||
|
public:
|
||||||
|
PoolingObj(GraphObj *graph, OpType optype, Tensor input, Tensor output,
|
||||||
|
int kh, int kw, int dh, int dw, int ph, int pw, int sh, int sw);
|
||||||
|
|
||||||
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||||
|
std::string toString() const override;
|
||||||
|
int numInputs() const override { return 1; }
|
||||||
|
int numOutputs() const override { return 1; }
|
||||||
|
|
||||||
|
int getKh() const { return kh; }
|
||||||
|
int getKw() const { return kw; }
|
||||||
|
int getDh() const { return dh; }
|
||||||
|
int getDw() const { return dw; }
|
||||||
|
int getPh() const { return ph; }
|
||||||
|
int getPw() const { return pw; }
|
||||||
|
int getSh() const { return sh; }
|
||||||
|
int getSw() const { return sw; }
|
||||||
|
|
||||||
|
auto getPadStrideDilation() const { return tuple(ph, pw, sh, sw, dh, dw); }
|
||||||
|
auto getNCHWRS() const { return tuple(n, c, h, w, kh, kw); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
vector<int> getWorkloadVector() const override;
|
||||||
|
vector<int> getOpAttrVector() const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MaxPoolObj : public PoolingObj {
|
||||||
|
public:
|
||||||
|
MaxPoolObj(GraphObj *graph, Tensor input, Tensor output, int kh, int kw,
|
||||||
|
int dh, int dw, int ph, int pw, int sh, int sw)
|
||||||
|
: PoolingObj(graph, OpType::MaxPool, input, output, kh, kw, dh, dw, ph,
|
||||||
|
pw, sh, sw) {}
|
||||||
|
};
|
||||||
|
class AvgPoolObj : public PoolingObj {
|
||||||
|
public:
|
||||||
|
AvgPoolObj(GraphObj *graph, Tensor input, Tensor output, int kh, int kw,
|
||||||
|
int dh, int dw, int ph, int pw, int sh, int sw)
|
||||||
|
: PoolingObj(graph, OpType::AvgPool, input, output, kh, kw, dh, dw, ph,
|
||||||
|
pw, sh, sw) {}
|
||||||
|
};
|
||||||
|
}; // namespace infini
|
|
@ -62,7 +62,7 @@ bool OperatorObj::checkValid(GraphObj *graph) {
|
||||||
IT_ASSERT(!outputs[i]);
|
IT_ASSERT(!outputs[i]);
|
||||||
outputs[i] = graph->addTensor(shapes[i], dataTypes[i]);
|
outputs[i] = graph->addTensor(shapes[i], dataTypes[i]);
|
||||||
}
|
}
|
||||||
} else { // if graph is not empty, check outputs match inferred shapes
|
} else { // if outputs have been created, check their shapes
|
||||||
for (size_t i = 0; i < shapes.size(); ++i) {
|
for (size_t i = 0; i < shapes.size(); ++i) {
|
||||||
if (shapes[i] != outputs[i]->getDims())
|
if (shapes[i] != outputs[i]->getDims())
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -22,9 +22,7 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
|
||||||
std::map<OpType, int> opCnt;
|
std::map<OpType, int> opCnt;
|
||||||
|
|
||||||
for (auto &op : graph->getOperators()) {
|
for (auto &op : graph->getOperators()) {
|
||||||
// HACK: set correct data type
|
auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()};
|
||||||
auto kernelAttrs =
|
|
||||||
KernelAttrs{device, op->getOpType(), DataType::UInt32};
|
|
||||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||||
std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey);
|
std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey);
|
||||||
|
@ -72,9 +70,7 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const {
|
||||||
std::map<OpType, int> opCnt;
|
std::map<OpType, int> opCnt;
|
||||||
|
|
||||||
for (auto &op : graph->getOperators()) {
|
for (auto &op : graph->getOperators()) {
|
||||||
// HACK: set correct data type
|
auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()};
|
||||||
auto kernelAttrs =
|
|
||||||
KernelAttrs{device, op->getOpType(), DataType::UInt32};
|
|
||||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||||
std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey);
|
std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey);
|
||||||
|
|
|
@ -0,0 +1,91 @@
|
||||||
|
#include "operators/pooling.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
template <typename T> class NativePooling : public Kernel {
|
||||||
|
virtual T getPoolingValue(int kh, int kw, int posh, int posw, int ih,
|
||||||
|
int iw, T *inptr) const = 0;
|
||||||
|
void compute(const Operator &_op, const PerfRecord &record,
|
||||||
|
const RuntimeObj *context) const override {
|
||||||
|
auto op = as<PoolingObj>(_op);
|
||||||
|
T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
|
||||||
|
T *outptr = op->getOutput()->getRawDataPtr<T *>();
|
||||||
|
const auto [n, c, ih, iw, kh, kw] = op->getNCHWRS();
|
||||||
|
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||||
|
if (dh != 1 || dw != 1)
|
||||||
|
IT_TODO_HALT(); // To support dailated pooling
|
||||||
|
auto outDim = op->getOutput()->getDims();
|
||||||
|
int oh = outDim[2], ow = outDim[3];
|
||||||
|
for (auto i = 0; i < n; i++) {
|
||||||
|
for (auto j = 0; j < c; j++) {
|
||||||
|
auto inoffset = i * (c * ih * iw) + j * ih * iw;
|
||||||
|
for (auto h = 0; h < oh; h++) {
|
||||||
|
for (auto w = 0; w < ow; w++) {
|
||||||
|
T val =
|
||||||
|
getPoolingValue(kh, kw, h * sh - ph, w * sw - pw,
|
||||||
|
ih, iw, inptr + inoffset);
|
||||||
|
auto outoffset =
|
||||||
|
w + h * ow + j * (oh * ow) + i * (c * oh * ow);
|
||||||
|
outptr[outoffset] = val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void compute(const Operator &op, const RuntimeObj *context) const override {
|
||||||
|
compute(op, {}, context);
|
||||||
|
}
|
||||||
|
|
||||||
|
PerfRecord tune(const Operator &op,
|
||||||
|
const RuntimeObj *context) const override {
|
||||||
|
PerfRecord perfrcd(timeit([&]() { compute(op, context); }));
|
||||||
|
return perfrcd;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
 * Naive CPU max pooling: returns the largest in-bounds element of a window.
 *
 * Positions that fall outside the ih x iw plane (padding) are skipped.
 * The running maximum is seeded from the first in-bounds element rather
 * than from 0, so windows containing only negative values yield their true
 * maximum. If no position of the window is in bounds, 0 is returned.
 */
template <typename T> class NaiveMaxPool : public NativePooling<T> {
    T getPoolingValue(int kh, int kw, int posh, int posw, int ih, int iw,
                      T *inptr) const override {
        T maxval = 0;
        bool seeded = false; // true once maxval holds a real input element
        for (int k = 0; k < kh; k++) {
            const int inPosH = posh + k;
            if (inPosH < 0 || inPosH >= ih)
                continue; // whole row lies in the padding
            for (int l = 0; l < kw; l++) {
                const int inPosW = posw + l;
                if (inPosW < 0 || inPosW >= iw)
                    continue; // padding column
                const T val = inptr[inPosH * iw + inPosW];
                if (!seeded || maxval < val) {
                    maxval = val;
                    seeded = true;
                }
            }
        }
        return maxval;
    }
};
|
||||||
|
|
||||||
|
/**
 * Naive CPU average pooling.
 *
 * Sums the in-bounds elements of the window and divides by the full window
 * area kh * kw, so padded positions count toward the divisor while
 * contributing nothing to the sum.
 */
template <typename T> class NaiveAvgPool : public NativePooling<T> {
    T getPoolingValue(int kh, int kw, int posh, int posw, int ih, int iw,
                      T *inptr) const override {
        T total = 0;
        for (int row = 0; row < kh; ++row) {
            const int inPosH = posh + row;
            for (int col = 0; col < kw; ++col) {
                const int inPosW = posw + col;
                const bool insidePlane =
                    inPosH >= 0 && inPosH < ih && inPosW >= 0 && inPosW < iw;
                if (insidePlane)
                    total += inptr[inPosH * iw + inPosW];
            }
        }
        return T(total / (kh * kw));
    }
};
|
||||||
|
|
||||||
|
// CPU kernel registrations: MaxPool for UInt32 and Float32, AvgPool for
// Float32 only.
REGISTER_KERNEL(Device::CPU, OpType::MaxPool, DataType::UInt32,
                NaiveMaxPool<uint32_t>, "maxPoolNaive_CPU_uint32");
REGISTER_KERNEL(Device::CPU, OpType::MaxPool, DataType::Float32,
                NaiveMaxPool<float>, "maxPoolNaive_CPU_float32");
REGISTER_KERNEL(Device::CPU, OpType::AvgPool, DataType::Float32,
                NaiveAvgPool<float>, "AvgPoolNaive_CPU_float32");
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,89 @@
|
||||||
|
#include "operators/pooling.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
class poolingCudnn : public Kernel {
|
||||||
|
virtual cudnnPoolingMode_t getPoolingMode() const = 0;
|
||||||
|
void compute(const Operator &_op, const PerfRecord &record,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
auto op = as<PoolingObj>(_op);
|
||||||
|
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||||
|
cudnnStatus_t stat;
|
||||||
|
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||||
|
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
|
||||||
|
|
||||||
|
const auto [n, c, h, w, kh, kw] = op->getNCHWRS();
|
||||||
|
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||||
|
|
||||||
|
// get inputs
|
||||||
|
cudnnTensorDescriptor_t inDesc;
|
||||||
|
checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
|
||||||
|
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||||
|
inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
|
||||||
|
|
||||||
|
// get maxpool descriptor
|
||||||
|
cudnnPoolingDescriptor_t poolingDesc;
|
||||||
|
checkCudnnError(cudnnCreatePoolingDescriptor(&poolingDesc));
|
||||||
|
checkCudnnError(cudnnSetPooling2dDescriptor(
|
||||||
|
poolingDesc, getPoolingMode(), CUDNN_NOT_PROPAGATE_NAN, kh, kw, ph,
|
||||||
|
pw, sh, sw));
|
||||||
|
|
||||||
|
// get outputs
|
||||||
|
int outn, outc, outh, outw;
|
||||||
|
checkCudnnError(cudnnGetPooling2dForwardOutputDim(
|
||||||
|
poolingDesc, inDesc, &outn, &outc, &outh, &outw));
|
||||||
|
cudnnTensorDescriptor_t outDesc;
|
||||||
|
checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
|
||||||
|
checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW,
|
||||||
|
CUDNN_DATA_FLOAT, outn, outc,
|
||||||
|
outh, outw));
|
||||||
|
IT_ASSERT((vector{outn, outc, outh, outw}) ==
|
||||||
|
op->getOutput()->getDims(),
|
||||||
|
"cuDNN output shape mismatches with OP output shape");
|
||||||
|
|
||||||
|
float alpha = 1.f, beta = 0.f;
|
||||||
|
stat = cudnnPoolingForward(context->cudnnHandle(), poolingDesc, &alpha,
|
||||||
|
inDesc, inData, &beta, outDesc, outData);
|
||||||
|
if (stat != CUDNN_STATUS_SUCCESS)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// Destories in CUDA does not require sync. But cuDNN does not state
|
||||||
|
// whether sync is required before destories.
|
||||||
|
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
|
||||||
|
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
|
||||||
|
checkCudnnError(cudnnDestroyPoolingDescriptor(poolingDesc));
|
||||||
|
}
|
||||||
|
|
||||||
|
void compute(const Operator &_op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
compute(_op, {}, _context);
|
||||||
|
}
|
||||||
|
// Premise: op is idempotent since it is called multiple times.
|
||||||
|
PerfRecord tune(const Operator &_op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
PerfRecord ret;
|
||||||
|
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||||
|
ret.time = timeit([&]() { compute(_op, _context); },
|
||||||
|
[&]() { context->sync(); });
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class maxPoolCudnn : public poolingCudnn {
|
||||||
|
cudnnPoolingMode_t getPoolingMode() const override {
|
||||||
|
return CUDNN_POOLING_MAX;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class avgPoolCudnn : public poolingCudnn {
|
||||||
|
cudnnPoolingMode_t getPoolingMode() const override {
|
||||||
|
return CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::MaxPool, DataType::Float32, maxPoolCudnn,
|
||||||
|
"MaxPool_cuDNN_CUDA_Float32");
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::AvgPool, DataType::Float32, avgPoolCudnn,
|
||||||
|
"AvgPool_cuDNN_CUDA_Float32");
|
||||||
|
}; // namespace infini
|
|
@ -0,0 +1,52 @@
|
||||||
|
#include "operators/pooling.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
// Stores the pooling parameters, caches the first four input dimensions as
// n, c, h, w, and then asserts that checkValid(graph) succeeds.
// NOTE(review): dimensions 0..3 are read unconditionally, so the input is
// presumably expected to have at least four dimensions - confirm with callers.
PoolingObj::PoolingObj(GraphObj *graph, OpType optype, Tensor input,
                       Tensor output, int kh, int kw, int dh, int dw, int ph,
                       int pw, int sh, int sw)
    : OperatorObj(optype, {input}, {output}), kh(kh), kw(kw), dh(dh), dw(dw),
      ph(ph), pw(pw), sh(sh), sw(sw) {
    const auto &inDims = input->getDims();
    n = inDims[0];
    c = inDims[1];
    h = inDims[2];
    w = inDims[3];

    IT_ASSERT(checkValid(graph));
}
|
||||||
|
|
||||||
|
// Derives the output shape from inputs[0]: all dimensions are copied, and
// the last two are replaced by (in - (k - s) + 2 * p) / s using integer
// division. Dilation (dh / dw) does not appear in this formula.
optional<vector<Shape>> PoolingObj::inferShape(const TensorVec &inputs) const {
    const auto &input = inputs[0];
    auto outShape = input->getDims();
    const auto rank = outShape.size();
    const auto inH = outShape[rank - 2];
    const auto inW = outShape[rank - 1];
    const int oh = (inH - (kh - sh) + ph * 2) / sh;
    const int ow = (inW - (kw - sw) + pw * 2) / sw;
    outShape[rank - 2] = oh;
    outShape[rank - 1] = ow;
    return {{outShape}};
}
|
||||||
|
|
||||||
|
std::string PoolingObj::toString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "Maxpool[" << getGuid() << "]";
|
||||||
|
os << "(";
|
||||||
|
os << "k=[" << kh << "," << kw << "],";
|
||||||
|
os << "p=[" << ph << "," << pw << "],";
|
||||||
|
os << "s=[" << sh << "," << sw << "],";
|
||||||
|
os << "d=[" << dh << "," << dw << "],";
|
||||||
|
os << "input=" << inputs[0]->getGuid() << ",";
|
||||||
|
os << "output=" << outputs[0]->getGuid() << ")";
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flattens the operator type, the cached input dimensions and every pooling
// parameter into one integer vector, in the order:
// type, n, c, h, w, kh, kw, ph, pw, sh, sw, dh, dw.
vector<int> PoolingObj::getWorkloadVector() const {
    vector<int> workload{enum_to_underlying(type),
                         n,
                         c,
                         h,
                         w,
                         kh,
                         kw,
                         ph,
                         pw,
                         sh,
                         sw,
                         dh,
                         dw};
    return workload;
}
|
||||||
|
|
||||||
|
// Invokes IT_TODO_HALT() before building the attribute vector
// (type, kh, kw, ph, pw, sh, sw, dh, dw).
// NOTE(review): presumably the macro never returns, making the statements
// after it unreachable - confirm against its definition.
vector<int> PoolingObj::getOpAttrVector() const {
    IT_TODO_HALT();
    vector<int> attrs{
        enum_to_underlying(type), kh, kw, ph, pw, sh, sw, dh, dw};
    return attrs;
}
|
||||||
|
|
||||||
|
}; // namespace infini
|
|
@ -14,8 +14,8 @@ TEST(Hash, OperatorHash) {
|
||||||
Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
|
Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
|
||||||
auto matmul = g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
|
auto matmul = g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
|
||||||
key1 = matmul->getOpPerfKey();
|
key1 = matmul->getOpPerfKey();
|
||||||
EXPECT_NE(key1.hash, 0);
|
EXPECT_NE(key1.hash, (HashType)0);
|
||||||
EXPECT_GT(key1.attrs.size(), 5);
|
EXPECT_GT(key1.attrs.size(), (size_t)5);
|
||||||
}
|
}
|
||||||
{ // build with addOp
|
{ // build with addOp
|
||||||
Graph g = make_ref<GraphObj>(nullptr);
|
Graph g = make_ref<GraphObj>(nullptr);
|
||||||
|
@ -23,7 +23,7 @@ TEST(Hash, OperatorHash) {
|
||||||
Tensor w0 = g->addTensor({2, 3, 4}, DataType::UInt32);
|
Tensor w0 = g->addTensor({2, 3, 4}, DataType::UInt32);
|
||||||
auto matmul = g->addOp<MatmulObj>(i0, w0, nullptr);
|
auto matmul = g->addOp<MatmulObj>(i0, w0, nullptr);
|
||||||
key2 = matmul->getOpPerfKey();
|
key2 = matmul->getOpPerfKey();
|
||||||
EXPECT_NE(key2.hash, 0);
|
EXPECT_NE(key2.hash, (HashType)0);
|
||||||
}
|
}
|
||||||
EXPECT_NE(key1.hash, key2.hash);
|
EXPECT_NE(key1.hash, key2.hash);
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,122 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/cuda_utility.h"
|
||||||
|
#include "operators/pooling.h"
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
using KDPS = vector<int>;
|
||||||
|
using ExpectOutput = vector<float>;
|
||||||
|
// Checks the output shape inferred for MaxPool when no output tensor is
// passed to addOp.
TEST(MaxPool, ShapeInference) {
    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
    { // 3x3 window, stride 2, no padding
        Graph graph = make_ref<GraphObj>(cpuRuntime);
        Tensor input = graph->addTensor({1, 64, 162, 162}, DataType::UInt32);
        const int kh = 3, kw = 3, dh = 1, dw = 1, ph = 0, pw = 0, sh = 2,
                  sw = 2;
        auto pool = graph->addOp<MaxPoolObj>(input, nullptr, kh, kw, dh, dw,
                                             ph, pw, sh, sw);
        EXPECT_EQ(pool->getOutput()->getDims(), (Shape{1, 64, 80, 80}));
    }

    { // asymmetric window, padding and stride (dilation stays at 1)
        Graph graph = make_ref<GraphObj>(cpuRuntime);
        Tensor input = graph->addTensor({1, 64, 162, 162}, DataType::UInt32);
        auto pool =
            graph->addOp<MaxPoolObj>(input, nullptr, 4, 3, 1, 1, 2, 1, 1, 2);
        EXPECT_EQ(pool->getOutput()->getDims(), (Shape{1, 64, 163, 81}));
    }
}
|
||||||
|
|
||||||
|
// Runs the naive CPU MaxPool kernel on an incrementally filled 1x2x5x5
// UInt32 tensor and checks both the reported time and the pooled values.
TEST(MaxPool, NaiveCPU) {
    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
    Graph graph = make_ref<GraphObj>(cpuRuntime);
    Tensor input = graph->addTensor({1, 2, 5, 5}, DataType::UInt32);
    auto pool =
        graph->addOp<MaxPoolObj>(input, nullptr, 3, 3, 1, 1, 1, 1, 2, 2);

    graph->dataMalloc();
    input->setData(IncrementalGenerator());
    cpuRuntime->run(graph, true, true);

    // The measured time must be positive and below the bound of 5.
    double perfTime = cpuRuntime->getPerfTime(graph);
    EXPECT_GT(perfTime, 0);
    EXPECT_LT(perfTime, 5);

    // Expected output, two 3x3 planes.
    vector<uint32_t> expected = {6,  8,  9,  16, 18, 19, 21, 23, 24,
                                 31, 33, 34, 41, 43, 44, 46, 48, 49};
    EXPECT_TRUE(pool->getOutput()->equalData(expected));
}
|
||||||
|
|
||||||
|
// Runs the naive CPU AvgPool kernel on an incrementally filled 1x2x5x5
// Float32 tensor and checks the pooled values and the reported time.
TEST(AvgPool, NaiveCPU) {
    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
    Graph graph = make_ref<GraphObj>(cpuRuntime);
    Tensor input = graph->addTensor({1, 2, 5, 5}, DataType::Float32);
    auto pool =
        graph->addOp<AvgPoolObj>(input, nullptr, 3, 3, 1, 1, 1, 1, 2, 2);

    graph->dataMalloc();
    input->setData(IncrementalGenerator());
    cpuRuntime->run(graph, true, true);

    // Expected output, two 3x3 planes.
    vector<float> expected = {
        1.33333337, 3.0000,  2.66666675, 7.0000,     12.0000,   9.0000,
        8.0000,     13.0000, 9.33333302, 12.444447,  19.666666, 13.7777777,
        23.666666,  37.0000, 25.666666,  19.1111107, 29.666666, 20.4444447};
    EXPECT_TRUE(pool->getOutput()->equalData(expected));

    // The measured time must be positive and below the bound of 5.
    double perfTime = cpuRuntime->getPerfTime(graph);
    EXPECT_GT(perfTime, 0);
    EXPECT_LT(perfTime, 5);
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
void testPoolCudnn(
|
||||||
|
const std::function<void(void *, size_t, DataType)> &generator,
|
||||||
|
const Shape &shape, const KDPS &kdps, const ExpectOutput &ansVec) {
|
||||||
|
EXPECT_TRUE(kdps.size() == 8);
|
||||||
|
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
|
||||||
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
|
|
||||||
|
// Build input data on CPU
|
||||||
|
Tensor i0cpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||||
|
i0cpu->dataMalloc();
|
||||||
|
i0cpu->setData(generator);
|
||||||
|
|
||||||
|
// Build CUDA graph
|
||||||
|
Graph g = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
auto i0 = g->cloneTensor(i0cpu);
|
||||||
|
auto pool = g->addOp<T>(i0, nullptr, kdps[0], kdps[1], kdps[2], kdps[3],
|
||||||
|
kdps[4], kdps[5], kdps[6], kdps[7]);
|
||||||
|
|
||||||
|
// allocate CUDA memory
|
||||||
|
g->dataMalloc();
|
||||||
|
|
||||||
|
// Execute on CUDA
|
||||||
|
cudaRuntime->run(g);
|
||||||
|
|
||||||
|
// clone CUDA output to CPU
|
||||||
|
auto o0 = pool->getOutput();
|
||||||
|
auto cpuo0 = o0->clone(cpuRuntime);
|
||||||
|
|
||||||
|
// check results on CPU
|
||||||
|
EXPECT_TRUE(cpuo0->equalData(ansVec));
|
||||||
|
}
|
||||||
|
|
||||||
|
// cuDNN MaxPool on an incrementally filled 1x2x5x5 input
// (3x3 window, dilation 1, padding 1, stride 2).
TEST(MaxPool, CuDNN) {
    const Shape inputShape{1, 2, 5, 5};
    const KDPS params{3, 3, 1, 1, 1, 1, 2, 2};
    const ExpectOutput expected{6,  8,  9,  16, 18, 19, 21, 23, 24,
                                31, 33, 34, 41, 43, 44, 46, 48, 49};
    testPoolCudnn<MaxPoolObj>(IncrementalGenerator(), inputShape, params,
                              expected);
}
|
||||||
|
|
||||||
|
// cuDNN AvgPool on an incrementally filled 1x2x5x5 input
// (3x3 window, dilation 1, padding 1, stride 2).
TEST(AvgPool, CuDNN) {
    const Shape inputShape{1, 2, 5, 5};
    const KDPS params{3, 3, 1, 1, 1, 1, 2, 2};
    const ExpectOutput expected{
        1.333333,  3.0000,  2.666667,  7.0000,    12.0000,   9.0000,
        8.0000,    13.0000, 9.333333,  12.44444,  19.666667, 13.777778,
        23.666667, 37.0000, 25.666667, 19.111111, 29.666667, 20.444444};
    testPoolCudnn<AvgPoolObj>(IncrementalGenerator(), inputShape, params,
                              expected);
}
|
||||||
|
|
||||||
|
} // namespace infini
|
Loading…
Reference in New Issue