diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8220b3ea..251ad0b0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -14,7 +14,7 @@ set(DEFAULT_BUILD_TYPE "RelWithDebInfo")
 
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_EXTENSIONS OFF) # -std=gnu++11 when on, -std=c++11 when off
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror -Wno-error=deprecated-declarations")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror -Wno-error=deprecated-declarations")
 set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
 set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion
 
diff --git a/include/core/kernel.h b/include/core/kernel.h
index 4d1a97d5..52a10344 100644
--- a/include/core/kernel.h
+++ b/include/core/kernel.h
@@ -64,7 +64,9 @@ class KernelRegistry {
         return true;
     }
     Kernel *getKernel(const KernelAttrs &kernelAttrs) const {
-        return std::get<0>(kernels.at(kernelAttrs));
+        auto it = kernels.find(kernelAttrs);
+        IT_ASSERT(it != kernels.end(), "Kernel not found.");
+        return std::get<0>(it->second);
     }
     const KernelRecord &getKernelItem(const KernelAttrs &kernelAttrs) const {
         return kernels.at(kernelAttrs);
diff --git a/include/core/operator.h b/include/core/operator.h
index c9523bf8..0575f705 100644
--- a/include/core/operator.h
+++ b/include/core/operator.h
@@ -169,13 +169,14 @@ class OperatorObj : public Object {
     const TensorVec &getInputs() const { return inputs; }
     // TensorVec getOutputs() { return outputs; }
     const TensorVec &getOutputs() const { return outputs; }
-    Tensor getInputs(size_t i) { return inputs.at(i); }
+    Tensor getInputs(size_t i) const { return inputs.at(i); }
     Tensor getOutput() const {
         IT_ASSERT(outputs.size() == 1, "Unimplemented");
         return outputs[0];
     }
     OpType getOpType() const { return type; }
-
+    // HACK: set correct data type
+    DataType getDType() const { return getInputs(0)->getDType(); }
     virtual int numInputs() const = 0;
     virtual int numOutputs() const = 0;
 
diff --git a/include/core/tensor.h b/include/core/tensor.h
index 05925ec7..41de4168 100644
--- a/include/core/tensor.h
+++ b/include/core/tensor.h
@@ -49,8 +49,18 @@ class TensorObj : public TensorBaseObj {
     void copyData(const Tensor &src) { copyData(src.get()); }
     void setData(
         const std::function<void(void *, size_t, DataType)> &generator) const {
+        IT_ASSERT(data != nullptr);
+        if (!runtime->isCpu()) {
+            IT_TODO_HALT();
+        }
         generator(data->getPtr<void *>(), size(), dtype);
     }
+    Tensor clone(Runtime runtime) {
+        auto obj = make_ref<TensorObj>(shape, dtype, runtime);
+        obj->dataMalloc();
+        obj->copyData(this);
+        return obj;
+    }
     void printData() const;
     bool equalData(const Tensor &rhs) const;
 
diff --git a/include/operators/pooling.h b/include/operators/pooling.h
new file mode 100644
index 00000000..b0336b46
--- /dev/null
+++ b/include/operators/pooling.h
@@ -0,0 +1,54 @@
+#pragma once
+#include "core/operator.h"
+
+namespace infini {
+
+class PoolingObj : public OperatorObj {
+  private:
+    int kh, kw;
+    int dh, dw;
+    int ph, pw;
+    int sh, sw;
+    int n, c, h, w;
+
+  public:
+    PoolingObj(GraphObj *graph, OpType optype, Tensor input, Tensor output,
+               int kh, int kw, int dh, int dw, int ph, int pw, int sh, int sw);
+
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
+    std::string toString() const override;
+    int numInputs() const override { return 1; }
+    int numOutputs() const override { return 1; }
+
+    int getKh() const { return kh; }
+    int getKw() const { return kw; }
+    int getDh() const { return dh; }
+    int getDw() const { return dw; }
+    int getPh() const { return ph; }
+    int getPw() const { return pw; }
+    int getSh() const { return sh; }
+    int getSw() const { return sw; }
+
+    auto getPadStrideDilation() const { return tuple(ph, pw, sh, sw, dh, dw); }
+    auto getNCHWRS() const { return tuple(n, c, h, w, kh, kw); }
+
+  private:
+    vector<int> getWorkloadVector() const override;
+    vector<int> getOpAttrVector() const override;
+};
+
+class MaxPoolObj : public PoolingObj {
+  public:
+    MaxPoolObj(GraphObj *graph, Tensor input, Tensor output, int kh, int kw,
+               int dh, int dw, int ph, int pw, int sh, int sw)
+        : PoolingObj(graph, OpType::MaxPool, input, output, kh, kw, dh, dw, ph,
+                     pw, sh, sw) {}
+};
+class AvgPoolObj : public PoolingObj {
+  public:
+    AvgPoolObj(GraphObj *graph, Tensor input, Tensor output, int kh, int kw,
+               int dh, int dw, int ph, int pw, int sh, int sw)
+        : PoolingObj(graph, OpType::AvgPool, input, output, kh, kw, dh, dw, ph,
+                     pw, sh, sw) {}
+};
+}; // namespace infini
\ No newline at end of file
diff --git a/src/core/operator.cc b/src/core/operator.cc
index ec7df9b1..b6617d66 100644
--- a/src/core/operator.cc
+++ b/src/core/operator.cc
@@ -62,7 +62,7 @@ bool OperatorObj::checkValid(GraphObj *graph) {
             IT_ASSERT(!outputs[i]);
             outputs[i] = graph->addTensor(shapes[i], dataTypes[i]);
         }
-    } else { // if graph is not empty, check outputs match inferred shapes
+    } else { // if outputs have been created, check their shapes
         for (size_t i = 0; i < shapes.size(); ++i) {
             if (shapes[i] != outputs[i]->getDims())
                 return false;
@@ -86,4 +86,4 @@ vector<DataType> OperatorObj::inferDataType() const {
     return inferDataType(inputs);
 }
 
-} // namespace infini
\ No newline at end of file
+} // namespace infini
diff --git a/src/core/runtime.cc b/src/core/runtime.cc
index 93bd1ea7..a256b285 100644
--- a/src/core/runtime.cc
+++ b/src/core/runtime.cc
@@ -22,9 +22,7 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
     std::map<OpType, int> opCnt;
 
     for (auto &op : graph->getOperators()) {
-        // HACK: set correct data type
-        auto kernelAttrs =
-            KernelAttrs{device, op->getOpType(), DataType::UInt32};
+        auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()};
         Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
         auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
         std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey);
@@ -72,9 +70,7 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const {
     std::map<OpType, int> opCnt;
 
     for (auto &op : graph->getOperators()) {
-        // HACK: set correct data type
-        auto kernelAttrs =
-            KernelAttrs{device, op->getOpType(), DataType::UInt32};
+        auto kernelAttrs = KernelAttrs{device, op->getOpType(), op->getDType()};
         Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
         auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
         std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey);
@@ -146,4 +142,4 @@ void CpuRuntimeObj::copyBlobInsideRuntime(void *dst, void *src,
     memcpy(dst, src, bytes);
 }
 
-} // namespace infini
\ No newline at end of file
+} // namespace infini
diff --git a/src/kernels/cpu/pooling.cc b/src/kernels/cpu/pooling.cc
new file mode 100644
index 00000000..585f6574
--- /dev/null
+++ b/src/kernels/cpu/pooling.cc
@@ -0,0 +1,91 @@
+#include "operators/pooling.h"
+#include "core/kernel.h"
+
+namespace infini {
+template <typename T> class NativePooling : public Kernel {
+    virtual T getPoolingValue(int kh, int kw, int posh, int posw, int ih,
+                              int iw, T *inptr) const = 0;
+    void compute(const Operator &_op, const PerfRecord &record,
+                 const RuntimeObj *context) const override {
+        auto op = as<PoolingObj>(_op);
+        T *inptr = op->getInputs(0)->getRawDataPtr<T *>();
+        T *outptr = op->getOutput()->getRawDataPtr<T *>();
+        const auto [n, c, ih, iw, kh, kw] = op->getNCHWRS();
+        const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
+        if (dh != 1 || dw != 1)
+            IT_TODO_HALT(); // To support dilated pooling
+        auto outDim = op->getOutput()->getDims();
+        int oh = outDim[2], ow = outDim[3];
+        for (auto i = 0; i < n; i++) {
+            for (auto j = 0; j < c; j++) {
+                auto inoffset = i * (c * ih * iw) + j * ih * iw;
+                for (auto h = 0; h < oh; h++) {
+                    for (auto w = 0; w < ow; w++) {
+                        T val =
+                            getPoolingValue(kh, kw, h * sh - ph, w * sw - pw,
+                                            ih, iw, inptr + inoffset);
+                        auto outoffset =
+                            w + h * ow + j * (oh * ow) + i * (c * oh * ow);
+                        outptr[outoffset] = val;
+                    }
+                }
+            }
+        }
+    }
+
+    void compute(const Operator &op, const RuntimeObj *context) const override {
+        compute(op, {}, context);
+    }
+
+    PerfRecord tune(const Operator &op,
+                    const RuntimeObj *context) const override {
+        PerfRecord perfrcd(timeit([&]() { compute(op, context); }));
+        return perfrcd;
+    }
+};
+
+template <typename T> class NaiveMaxPool : public NativePooling<T> {
+    T getPoolingValue(int kh, int kw, int posh, int posw, int ih, int iw,
+                      T *inptr) const override {
+        T maxval = 0;
+        for (auto k = 0; k < kh; k++) {
+            for (auto l = 0; l < kw; l++) {
+                auto inPosH = posh + k;
+                auto inPosW = posw + l;
+                if (inPosH < 0 || inPosH >= ih || inPosW < 0 || inPosW >= iw)
+                    continue;
+                auto offset = (posh + k) * iw + posw + l;
+                auto val = inptr[offset];
+                if (maxval < val)
+                    maxval = val;
+            }
+        }
+        return maxval;
+    }
+};
+
+template <typename T> class NaiveAvgPool : public NativePooling<T> {
+    T getPoolingValue(int kh, int kw, int posh, int posw, int ih, int iw,
+                      T *inptr) const override {
+        T sum = 0;
+        for (auto k = 0; k < kh; k++) {
+            for (auto l = 0; l < kw; l++) {
+                auto inPosH = posh + k;
+                auto inPosW = posw + l;
+                if (inPosH < 0 || inPosH >= ih || inPosW < 0 || inPosW >= iw)
+                    continue;
+                auto offset = (posh + k) * iw + posw + l;
+                sum += inptr[offset];
+            }
+        }
+        return T(sum / (kh * kw));
+    }
+};
+
+REGISTER_KERNEL(Device::CPU, OpType::MaxPool, DataType::UInt32,
+                NaiveMaxPool<uint32_t>, "maxPoolNaive_CPU_uint32");
+REGISTER_KERNEL(Device::CPU, OpType::MaxPool, DataType::Float32,
+                NaiveMaxPool<float>, "maxPoolNaive_CPU_float32");
+REGISTER_KERNEL(Device::CPU, OpType::AvgPool, DataType::Float32,
+                NaiveAvgPool<float>, "AvgPoolNaive_CPU_float32");
+} // namespace infini
\ No newline at end of file
diff --git a/src/kernels/cuda/pooling.cc b/src/kernels/cuda/pooling.cc
new file mode 100644
index 00000000..a551ef76
--- /dev/null
+++ b/src/kernels/cuda/pooling.cc
@@ -0,0 +1,89 @@
+#include "operators/pooling.h"
+#include "core/kernel.h"
+#include "cuda/cuda_runtime.h"
+
+namespace infini {
+class poolingCudnn : public Kernel {
+    virtual cudnnPoolingMode_t getPoolingMode() const = 0;
+    void compute(const Operator &_op, const PerfRecord &record,
+                 const RuntimeObj *_context) const override {
+        auto op = as<PoolingObj>(_op);
+        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+        cudnnStatus_t stat;
+        void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
+        void *const outData = (op->getOutput()->getRawDataPtr<void *>());
+
+        const auto [n, c, h, w, kh, kw] = op->getNCHWRS();
+        const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
+
+        // get inputs
+        cudnnTensorDescriptor_t inDesc;
+        checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
+        checkCudnnError(cudnnSetTensor4dDescriptor(
+            inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
+
+        // get maxpool descriptor
+        cudnnPoolingDescriptor_t poolingDesc;
+        checkCudnnError(cudnnCreatePoolingDescriptor(&poolingDesc));
+        checkCudnnError(cudnnSetPooling2dDescriptor(
+            poolingDesc, getPoolingMode(), CUDNN_NOT_PROPAGATE_NAN, kh, kw, ph,
+            pw, sh, sw));
+
+        // get outputs
+        int outn, outc, outh, outw;
+        checkCudnnError(cudnnGetPooling2dForwardOutputDim(
+            poolingDesc, inDesc, &outn, &outc, &outh, &outw));
+        cudnnTensorDescriptor_t outDesc;
+        checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
+        checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW,
+                                                   CUDNN_DATA_FLOAT, outn, outc,
+                                                   outh, outw));
+        IT_ASSERT((vector<int>{outn, outc, outh, outw}) ==
+                      op->getOutput()->getDims(),
+                  "cuDNN output shape mismatches with OP output shape");
+
+        float alpha = 1.f, beta = 0.f;
+        stat = cudnnPoolingForward(context->cudnnHandle(), poolingDesc, &alpha,
+                                   inDesc, inData, &beta, outDesc, outData);
+        if (stat != CUDNN_STATUS_SUCCESS)
+            return;
+
+        // Destroying descriptors in CUDA does not require a sync, but cuDNN
+        // does not state whether a sync is required before the destroy calls.
+        checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
+        checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
+        checkCudnnError(cudnnDestroyPoolingDescriptor(poolingDesc));
+    }
+
+    void compute(const Operator &_op,
+                 const RuntimeObj *_context) const override {
+        compute(_op, {}, _context);
+    }
+    // Premise: op is idempotent since it is called multiple times.
+    PerfRecord tune(const Operator &_op,
+                    const RuntimeObj *_context) const override {
+        PerfRecord ret;
+        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+        ret.time = timeit([&]() { compute(_op, _context); },
+                          [&]() { context->sync(); });
+        return ret;
+    }
+};
+
+class maxPoolCudnn : public poolingCudnn {
+    cudnnPoolingMode_t getPoolingMode() const override {
+        return CUDNN_POOLING_MAX;
+    }
+};
+
+class avgPoolCudnn : public poolingCudnn {
+    cudnnPoolingMode_t getPoolingMode() const override {
+        return CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
+    }
+};
+
+REGISTER_KERNEL(Device::CUDA, OpType::MaxPool, DataType::Float32, maxPoolCudnn,
+                "MaxPool_cuDNN_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::AvgPool, DataType::Float32, avgPoolCudnn,
+                "AvgPool_cuDNN_CUDA_Float32");
+}; // namespace infini
\ No newline at end of file
diff --git a/src/operators/pooling.cc b/src/operators/pooling.cc
new file mode 100644
index 00000000..0fcc5416
--- /dev/null
+++ b/src/operators/pooling.cc
@@ -0,0 +1,52 @@
+#include "operators/pooling.h"
+
+namespace infini {
+
+PoolingObj::PoolingObj(GraphObj *graph, OpType optype, Tensor input,
+                       Tensor output, int kh, int kw, int dh, int dw, int ph,
+                       int pw, int sh, int sw)
+    : OperatorObj(optype, {input}, {output}), kh(kh), kw(kw), dh(dh), dw(dw),
+      ph(ph), pw(pw), sh(sh), sw(sw) {
+    n = input->getDims()[0];
+    c = input->getDims()[1];
+    h = input->getDims()[2], w = input->getDims()[3];
+
+    IT_ASSERT(checkValid(graph));
+}
+
+optional<vector<Shape>> PoolingObj::inferShape(const TensorVec &inputs) const {
+    const auto &input = inputs[0];
+    auto h = input->getDims()[input->getDims().size() - 2],
+         w = input->getDims()[input->getDims().size() - 1];
+    int oh = (h - (kh - sh) + ph * 2) / sh;
+    int ow = (w - (kw - sw) + pw * 2) / sw;
+    auto ret = input->getDims();
+    ret[input->getDims().size() - 2] = oh;
+    ret[input->getDims().size() - 1] = ow;
+    return {{ret}};
+}
+
+std::string PoolingObj::toString() const {
+    std::ostringstream os;
+    os << "Maxpool[" << getGuid() << "]";
+    os << "(";
+    os << "k=[" << kh << "," << kw << "],";
+    os << "p=[" << ph << "," << pw << "],";
+    os << "s=[" << sh << "," << sw << "],";
+    os << "d=[" << dh << "," << dw << "],";
dh << "," << dw << "],"; + os << "input=" << inputs[0]->getGuid() << ","; + os << "output=" << outputs[0]->getGuid() << ")"; + return os.str(); +} + +vector PoolingObj::getWorkloadVector() const { + return { + enum_to_underlying(type), n, c, h, w, kh, kw, ph, pw, sh, sw, dh, dw}; +} + +vector PoolingObj::getOpAttrVector() const { + IT_TODO_HALT(); + return {enum_to_underlying(type), kh, kw, ph, pw, sh, sw, dh, dw}; +} + +}; // namespace infini \ No newline at end of file diff --git a/test/core/test_hash.cc b/test/core/test_hash.cc index 3c244acb..8c1e659a 100644 --- a/test/core/test_hash.cc +++ b/test/core/test_hash.cc @@ -14,8 +14,8 @@ TEST(Hash, OperatorHash) { Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32); auto matmul = g->addOpWithOutputs(i0, w0, o0); key1 = matmul->getOpPerfKey(); - EXPECT_NE(key1.hash, 0); - EXPECT_GT(key1.attrs.size(), 5); + EXPECT_NE(key1.hash, (HashType)0); + EXPECT_GT(key1.attrs.size(), (size_t)5); } { // build with addOp Graph g = make_ref(nullptr); @@ -23,7 +23,7 @@ TEST(Hash, OperatorHash) { Tensor w0 = g->addTensor({2, 3, 4}, DataType::UInt32); auto matmul = g->addOp(i0, w0, nullptr); key2 = matmul->getOpPerfKey(); - EXPECT_NE(key2.hash, 0); + EXPECT_NE(key2.hash, (HashType)0); } EXPECT_NE(key1.hash, key2.hash); } diff --git a/test/operators/test_pooling.cc b/test/operators/test_pooling.cc new file mode 100644 index 00000000..8d1c0c00 --- /dev/null +++ b/test/operators/test_pooling.cc @@ -0,0 +1,122 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/pooling.h" +#include "test.h" + +namespace infini { +using KDPS = vector; +using ExpectOutput = vector; +TEST(MaxPool, ShapeInference) { + Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + { + Graph g = make_ref(cpuRuntime); + Tensor i = g->addTensor({1, 64, 162, 162}, DataType::UInt32); + const int kh = 3, kw = 3, dh = 1, dw = 1, ph = 0, pw = 0, sh = 2, + sw = 2; + auto op = + g->addOp(i, nullptr, kh, kw, dh, dw, ph, pw, sh, sw); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 64, 80, 80})); + } + + { // dilation & stride + Graph g = make_ref(cpuRuntime); + Tensor i = g->addTensor({1, 64, 162, 162}, DataType::UInt32); + auto op = g->addOp(i, nullptr, 4, 3, 1, 1, 2, 1, 1, 2); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 64, 163, 81})); + } +} + +TEST(MaxPool, NaiveCPU) { + Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Graph g = make_ref(cpuRuntime); + Tensor i = g->addTensor({1, 2, 5, 5}, DataType::UInt32); + auto op = g->addOp(i, nullptr, 3, 3, 1, 1, 1, 1, 2, 2); + + g->dataMalloc(); + i->setData(IncrementalGenerator()); + cpuRuntime->run(g, true, true); + double perfTime = cpuRuntime->getPerfTime(g); + // The example matmul takes 0.0036ms with one core + EXPECT_GT(perfTime, 0); + EXPECT_LT(perfTime, 5); + // check answer + vector ans = {6, 8, 9, 16, 18, 19, 21, 23, 24, + 31, 33, 34, 41, 43, 44, 46, 48, 49}; + EXPECT_TRUE(op->getOutput()->equalData(ans)); +} + +TEST(AvgPool, NaiveCPU) { + Runtime cpuRuntime = CpuRuntimeObj::getInstance(); + Graph g = make_ref(cpuRuntime); + Tensor i = g->addTensor({1, 2, 5, 5}, DataType::Float32); + auto op = g->addOp(i, nullptr, 3, 3, 1, 1, 1, 1, 2, 2); + + g->dataMalloc(); + i->setData(IncrementalGenerator()); + cpuRuntime->run(g, true, true); + + // check answer + vector ans = { + 1.33333337, 3.0000, 2.66666675, 7.0000, 12.0000, 9.0000, + 8.0000, 13.0000, 9.33333302, 12.444447, 19.666666, 13.7777777, + 23.666666, 37.0000, 25.666666, 19.1111107, 29.666666, 
+    EXPECT_TRUE(op->getOutput()->equalData(ans));
+
+    double perfTime = cpuRuntime->getPerfTime(g);
+    // The example matmul takes 0.0036ms with one core
+    EXPECT_GT(perfTime, 0);
+    EXPECT_LT(perfTime, 5);
+}
+
+template <class T>
+void testPoolCudnn(
+    const std::function<void(void *, size_t, DataType)> &generator,
+    const Shape &shape, const KDPS &kdps, const ExpectOutput &ansVec) {
+    EXPECT_TRUE(kdps.size() == 8);
+    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+    auto cudaRuntime = make_ref<CudaRuntimeObj>();
+
+    // Build input data on CPU
+    Tensor i0cpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
+    i0cpu->dataMalloc();
+    i0cpu->setData(generator);
+
+    // Build CUDA graph
+    Graph g = make_ref<GraphObj>(cudaRuntime);
+    auto i0 = g->cloneTensor(i0cpu);
+    auto pool = g->addOp<T>(i0, nullptr, kdps[0], kdps[1], kdps[2], kdps[3],
+                            kdps[4], kdps[5], kdps[6], kdps[7]);
+
+    // allocate CUDA memory
+    g->dataMalloc();
+
+    // Execute on CUDA
+    cudaRuntime->run(g);
+
+    // clone CUDA output to CPU
+    auto o0 = pool->getOutput();
+    auto cpuo0 = o0->clone(cpuRuntime);
+
+    // check results on CPU
+    EXPECT_TRUE(cpuo0->equalData(ansVec));
+}
+
+TEST(MaxPool, CuDNN) {
+    testPoolCudnn<MaxPoolObj>(IncrementalGenerator(), Shape{1, 2, 5, 5},
+                              KDPS{3, 3, 1, 1, 1, 1, 2, 2},
+                              ExpectOutput{6, 8, 9, 16, 18, 19, 21, 23, 24, 31,
+                                           33, 34, 41, 43, 44, 46, 48, 49});
+}
+
+TEST(AvgPool, CuDNN) {
+    testPoolCudnn<AvgPoolObj>(
+        IncrementalGenerator(), Shape{1, 2, 5, 5}, KDPS{3, 3, 1, 1, 1, 1, 2, 2},
+        ExpectOutput{1.333333, 3.0000, 2.666667, 7.0000, 12.0000, 9.0000,
+                     8.0000, 13.0000, 9.333333, 12.44444, 19.666667, 13.777778,
+                     23.666667, 37.0000, 25.666667, 19.111111, 29.666667,
+                     20.444444});
+}
+
+} // namespace infini
\ No newline at end of file