From 8f67a5cc762e553e7d54f6161c74e6c4311fdc14 Mon Sep 17 00:00:00 2001 From: zhengly123 Date: Mon, 19 Sep 2022 15:05:39 +0800 Subject: [PATCH] Add: ConvTransposed (#33) * Add: convTransposed2d operator * Fix: IT_ASSERT namespace * Add: nullptr check in as for Ref * Fix: conv transpose operator and kernel * Fix: makes PerfEngine singleton * Add: ConvTransposed test * Fix: rebase to master (PerfRecord shared_ptr) * Revert: Ref with nullptr check Co-authored-by: Liyan Zheng --- include/core/common.h | 2 +- include/core/perf_engine.h | 4 + include/core/ref.h | 1 + include/operators/conv.h | 84 +++++-- src/core/runtime.cc | 4 +- src/cuda/cuda_runtime.cc | 2 +- src/kernels/cuda/conv_transposed.cc | 289 ++++++++++++++++++++++ src/operators/conv.cc | 147 ++++++++--- test/operators/test_conv_transposed_2d.cc | 115 +++++++++ 9 files changed, 594 insertions(+), 54 deletions(-) create mode 100644 src/kernels/cuda/conv_transposed.cc create mode 100644 test/operators/test_conv_transposed_2d.cc diff --git a/include/core/common.h b/include/core/common.h index 600b44ef..6bdb92a3 100644 --- a/include/core/common.h +++ b/include/core/common.h @@ -42,7 +42,7 @@ using HashType = uint64_t; // compatible with std::hash #define _IT_ASSERT_2(name, info) \ (static_cast(name) \ ? void(0) \ - : throw infini::Exception( \ + : throw ::infini::Exception( \ std::string("[") + __FILE__ + ":" + std::to_string(__LINE__) + \ "] Assertion failed (" + #name + "): " + #info)) #define _IT_ASSERT_1(name) _IT_ASSERT_2(name, ""); diff --git a/include/core/perf_engine.h b/include/core/perf_engine.h index 563ad704..4689cbcf 100644 --- a/include/core/perf_engine.h +++ b/include/core/perf_engine.h @@ -9,6 +9,10 @@ class PerfEngine { // TODO: Key should be OpPerfKey + Context(maybe implicat) to support // multiple candiate kernels. using Key = std::pair; + PerfEngine() = default; + // PerfEngine is singleton + PerfEngine(PerfEngine &other) = delete; + PerfEngine &operator=(PerfEngine const &) = delete; private: map data; diff --git a/include/core/ref.h b/include/core/ref.h index f5ba4e89..76357818 100644 --- a/include/core/ref.h +++ b/include/core/ref.h @@ -1,4 +1,5 @@ #pragma once +#include "core/common.h" #include // hash #include #include diff --git a/include/operators/conv.h b/include/operators/conv.h index 841d1351..95bbf0bf 100644 --- a/include/operators/conv.h +++ b/include/operators/conv.h @@ -3,7 +3,7 @@ namespace infini { -class ConvObj : public OperatorObj { +class ConvBaseObj : public OperatorObj { public: // When PaddingMode is Other, ConvObj will use padding size (ph, pw) // Otherwise, padding size (ph, pw) will be computed by padding mode @@ -13,34 +13,33 @@ class ConvObj : public OperatorObj { Valid, }; - private: + protected: int ph, pw; int sh, sw; int dh, dw; - ActType act; PaddingMode padding; - // auxiliary attributes - int n, c, h, w, f, r, s; + // auxiliary attributes. Descripitions stand on a forward perspective, + // i.e., convTransposed2d is not regarded as the backward of conv2d. + int n; // batch size + int c; // input/output channel for conv2d/convTransposed2d + int h, w; // input shape (same for conv2d and convTranposed2d) + int f; // output/input channel for conv2d/convTransposed2d + int r, s; // weight shape public: // Constructors for explicitly setting padding size - ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, int ph, - int pw, int sh = 1, int sw = 1, int dh = 1, int dw = 1, - Tensor bias = nullptr, ActType act = ActType::None); - // Constructors for setting padding mode - ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, - PaddingMode mode = PaddingMode::Same, int sh = 1, int sw = 1, - int dh = 1, int dw = 1, Tensor bias = nullptr, - ActType act = ActType::None); - - optional> inferShape(const TensorVec &inputs) const override; + ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output, int ph, int pw, + int sh, int sw, int dh, int dw, const Tensor &inputInConvFWD, + const Tensor &weightInConvFWD); + ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output, + PaddingMode mode, int sh, int sw, int dh, int dw, + const Tensor &inputInConvFWD, const Tensor &weightInConvFWD); std::string toString() const override; int numInputs() const override { return 2; } int numOutputs() const override { return 1; } Tensor getBias() const { return inputs[2]; } - ActType getAct() const { return act; } PaddingMode getPaddingMode() const { return padding; } pair inferPaddingSize() const; @@ -53,7 +52,7 @@ class ConvObj : public OperatorObj { auto getNCHWFRS() const { return tuple(n, c, h, w, f, r, s); } auto getPadStrideDilation() const { return tuple(ph, pw, sh, sw, dh, dw); } int getChannelPerGroup() const { return inputs[1]->getDims()[1]; } - int getNumGroups() const { return c / getChannelPerGroup(); } + virtual int getNumGroups() const = 0; private: vector getWorkloadVector() const override; @@ -62,7 +61,56 @@ class ConvObj : public OperatorObj { * @brief Set the Auxilary Attributes: nchwrfs and padding (ph, pw) if * padding mode is set. This function should be called in constructor. */ - void setAuxilaryAttributes(PaddingMode mode); + virtual void setAuxilaryAttributes(PaddingMode mode) = 0; +}; + +class ConvObj : public ConvBaseObj { + private: + ActType act; + + public: + ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, int ph, + int pw, int sh = 1, int sw = 1, int dh = 1, int dw = 1, + Tensor bias = nullptr, ActType act = ActType::None); + // Constructors for setting padding mode + ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, + PaddingMode mode = PaddingMode::Same, int sh = 1, int sw = 1, + int dh = 1, int dw = 1, Tensor bias = nullptr, + ActType act = ActType::None); + + optional> inferShape(const TensorVec &inputs) const override; + ActType getAct() const { return act; } + int getNumGroups() const override { return c / getChannelPerGroup(); } + + private: + void setAuxilaryAttributes(PaddingMode mode) override; +}; + +class ConvTransposed2dObj : public ConvBaseObj { + private: + int oph, opw; + int group; + ActType act; + + public: + ConvTransposed2dObj(GraphObj *graph, Tensor input, Tensor weight, + Tensor output, int ph, int pw, int sh = 1, int sw = 1, + int dh = 1, int dw = 1, int oph = 0, int opw = 0, + int group = 1, Tensor bias = nullptr, + ActType act = ActType::None); + // Constructors for setting padding mode + ConvTransposed2dObj(GraphObj *graph, Tensor input, Tensor weight, + Tensor output, PaddingMode mode = PaddingMode::Same, + int sh = 1, int sw = 1, int dh = 1, int dw = 1, + int oph = 0, int opw = 0, int group = 1, + Tensor bias = nullptr, ActType act = ActType::None); + + optional> inferShape(const TensorVec &inputs) const override; + ActType getAct() const { return act; } + int getNumGroups() const override { return group; } + + private: + void setAuxilaryAttributes(PaddingMode mode) override; }; } // namespace infini diff --git a/src/core/runtime.cc b/src/core/runtime.cc index 1f32726b..449c997e 100644 --- a/src/core/runtime.cc +++ b/src/core/runtime.cc @@ -15,7 +15,7 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const { if (!tune && profiling) IT_TODO_HALT(); const auto &kernelRegistry = KernelRegistry::getInstance(); - auto perfEngine = PerfEngine::getInstance(); + auto &perfEngine = PerfEngine::getInstance(); // Statistics double totalTime = 0; std::map opTime; @@ -63,7 +63,7 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const { double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const { const auto &kernelRegistry = KernelRegistry::getInstance(); - auto perfEngine = PerfEngine::getInstance(); + auto &perfEngine = PerfEngine::getInstance(); // Statistics double totalTime = 0; std::map opTime; diff --git a/src/cuda/cuda_runtime.cc b/src/cuda/cuda_runtime.cc index 1212c923..fbfdbbbe 100644 --- a/src/cuda/cuda_runtime.cc +++ b/src/cuda/cuda_runtime.cc @@ -7,7 +7,7 @@ namespace infini { void CudaRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false, bool profiling = false) const { const auto &kernelRegistry = KernelRegistry::getInstance(); - auto perfEngine = PerfEngine::getInstance(); + auto &perfEngine = PerfEngine::getInstance(); double totalTime = 0; std::map opTime; std::map opCnt; diff --git a/src/kernels/cuda/conv_transposed.cc b/src/kernels/cuda/conv_transposed.cc new file mode 100644 index 00000000..6f379006 --- /dev/null +++ b/src/kernels/cuda/conv_transposed.cc @@ -0,0 +1,289 @@ +#include "core/kernel.h" +#include "cuda/cuda_runtime.h" +#include "operators/conv.h" +#include +#include +#include +#include +namespace infini { + +struct ConvTransposedCuDnnPerfRecordObj : public PerfRecordObj { + int algo = 0; // cudnnConvolutionBwdDataAlgo_t + int mode = 1; + size_t workspaceSize = 100000; + bool fuseAct = false; +}; +using ConvTransposedCuDnnPerfRecord = Ref; + +static constexpr int N_ALGO = 6; +static_assert(N_ALGO == int(CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT), + "Unsupported cuDNN version"); +static const cudnnConvolutionBwdDataAlgo_t ALGOS[N_ALGO] = { + CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, /* non-deterministic */ + CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED}; +static const char algo_name[N_ALGO][50] = { + "CUDNN_CONVOLUTION_BWD_DATA_ALGO_0", /* non-deterministic */ + "CUDNN_CONVOLUTION_BWD_DATA_ALGO_1", + "CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT", + "CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING", + "CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD", + "CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED"}; +static const char math_types[3][50] = {"CUDNN_DEFAULT_MATH", + "CUDNN_TENSOR_OP_MATH", + "CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION"}; +static constexpr int N_MODE = 2; +static constexpr cudnnConvolutionMode_t MODES[N_MODE] = { + CUDNN_CONVOLUTION, CUDNN_CROSS_CORRELATION}; + +class convBackwardDataCudnn : public Kernel { + + std::tuple + createCuDNNDescriptor( + const Ref &op, + const ConvTransposedCuDnnPerfRecordObj &record) const { + void *const inData = (op->getInputs(0)->getRawDataPtr()); + void *const knData = (op->getInputs(1)->getRawDataPtr()); + if (op->getInputs().size() > 2) // Bias is not supported yet + IT_TODO_HALT(); + // void *const biasData = (op->getInputs(2)->getRawDataPtr()); + void *const outData = (op->getOutput()->getRawDataPtr()); + + const auto [n, c, h, w, f, r, s] = op->getNCHWFRS(); + const int channelsPerGrp = op->getChannelPerGroup(); + const int g = op->getNumGroups(); + const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); + // IT_ASSERT(g == 1, "Group convolution is not supported yet"); + + // get inputs + cudnnTensorDescriptor_t inDesc; + checkCudnnError(cudnnCreateTensorDescriptor(&inDesc)); + checkCudnnError(cudnnSetTensor4dDescriptor( + inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, f, h, w)); + + // get kernels + cudnnFilterDescriptor_t knDesc; + checkCudnnError(cudnnCreateFilterDescriptor(&knDesc)); + checkCudnnError(cudnnSetFilter4dDescriptor(knDesc, CUDNN_DATA_FLOAT, + CUDNN_TENSOR_NCHW, f, + channelsPerGrp, r, s)); + // get bias + cudnnTensorDescriptor_t biasDesc; + checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc)); + checkCudnnError(cudnnSetTensor4dDescriptor( + biasDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, f, 1, 1)); + + // get convlution descriptor + cudnnConvolutionDescriptor_t convDesc; + checkCudnnError(cudnnCreateConvolutionDescriptor(&convDesc)); + // TODO: CUDNN_CONVOLUTION is a tunable argument + checkCudnnError(cudnnSetConvolution2dDescriptor( + convDesc, ph, pw, sh, sw, dh, dw, MODES[record.mode], + CUDNN_DATA_FLOAT)); + if (g > 1) { + checkCudnnError(cudnnSetConvolutionGroupCount(convDesc, g)); + } + + // get activation descriptor + cudnnActivationDescriptor_t actDesc; + checkCudnnError(cudnnCreateActivationDescriptor(&actDesc)); + // NOT_PROPAGATE_NAN is requierd by + // cudnnConvolotionBiasActivationForward + switch (op->getAct()) { + case ActType::Relu: + checkCudnnError(cudnnSetActivationDescriptor( + actDesc, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0)); + break; + case ActType::Sigmoid: + checkCudnnError(cudnnSetActivationDescriptor( + actDesc, CUDNN_ACTIVATION_SIGMOID, CUDNN_NOT_PROPAGATE_NAN, 0)); + break; + case ActType::None: + checkCudnnError( + cudnnSetActivationDescriptor(actDesc, CUDNN_ACTIVATION_IDENTITY, + CUDNN_NOT_PROPAGATE_NAN, 0)); + break; + default: + assert(false); + } + + const auto &outputShape = op->getOutput()->getDims(); + cudnnTensorDescriptor_t outDesc; + checkCudnnError(cudnnCreateTensorDescriptor(&outDesc)); + checkCudnnError(cudnnSetTensor4dDescriptor( + outDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, outputShape[0], + outputShape[1], outputShape[2], outputShape[3])); + return tuple(inData, knData, outData, inDesc, knDesc, biasDesc, + convDesc, actDesc, outDesc); + } + + bool cuDNNUnfused(const Ref &op, + const ConvTransposedCuDnnPerfRecordObj &record, + const CudaRuntimeObj *context) const { + cudnnStatus_t stat; + + const auto &[inData, knData, outData, inDesc, knDesc, biasDesc, + convDesc, actDesc, outDesc] = + createCuDNNDescriptor(op, record); + size_t wsSize = record.workspaceSize; + CudaPtr wsData = context->getWorkspace(wsSize); + float alpha = 1.f, beta = 0.f; + + stat = cudnnConvolutionBackwardData( + context->cudnnHandle(), &alpha, knDesc, knData, inDesc, inData, + convDesc, ALGOS[record.algo], wsData, wsSize, &beta, outDesc, + outData); + if (stat != CUDNN_STATUS_SUCCESS) + return false; + // TODO: + // // bias + // if (bias != nullptr) { + // auto sz = op.getOutputs()[0]->size(); + // // TODO: element wise + // t += sz * 2 / 400; + // } + // // act + // if (act != None) { + // stat = cudnnActivationForward(cudnnHandle(), actDesc, + // &alpha, inDesc, inData, + // &beta, outDesc, outData); + // checkCudaError(cudaDeviceSynchronize()); + // end = ch::high_resolution_clock::now(); + // if (stat != CUDNN_STATUS_SUCCESS) { + // durtime = INFINITY; + // break; + // } + // t += + // ch::duration_cast>(end - + // beg).count() * 1000; // ms + // } + + // best = ConvResult{durtime, ALGOS[i], wsSize, false}; + + // // w/ bias & act + // for (int j = 0; j < rounds + warmupRounds; ++j) { + // cudnnStatus_t stat; + // if (j == warmupRounds) { + // checkCudaError(cudaDeviceSynchronize()); + // beg = ch::high_resolution_clock::now(); + // } + // stat = cudnnConvolutionBiasActivationForward( + // cudnnHandle(), &alpha, inDesc, inData, knDesc, + // knData, convDesc, ALGOS[i], wsData, wsSize, &beta, + // outDesc, outData, biasDesc, biasData, actDesc, + // outDesc, outData); + // if (stat != CUDNN_STATUS_SUCCESS) { + // // checkCudnnError(stat); + // // Do not checkCudnnError since not all algorithms + // are + // // supported + // durtime_fuse = INFINITY; + // break; + // } + // } + + // Destories in CUDA does not require sync. But cuDNN does not + // state whether sync is required before destories. + checkCudnnError(cudnnDestroyTensorDescriptor(outDesc)); + checkCudnnError(cudnnDestroyActivationDescriptor(actDesc)); + checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc)); + checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc)); + checkCudnnError(cudnnDestroyFilterDescriptor(knDesc)); + checkCudnnError(cudnnDestroyTensorDescriptor(inDesc)); + return true; + } + + void compute(const Operator &op, const RuntimeObj *context) const override { + // with paramters in default ctor + auto record = make_ref(); + compute(op, record, context); + } + + PerfRecord tune(const Operator &_op, + const RuntimeObj *_context) const override { + ConvTransposedCuDnnPerfRecordObj ret; + ret.time = std::numeric_limits::max(); + auto context = dynamic_cast(_context); + auto op = as(_op); + // Both modes have the same performance. Only run + // cross-correlation. + for (int mode = 1; mode < 2; mode++) { + // Try every possible algorithm of convolution + for (int algo = 0; algo < N_ALGO; algo++) { + ConvTransposedCuDnnPerfRecordObj record; + record.mode = mode; + record.algo = algo; + cudnnStatus_t stat; + const auto &[inData, knData, outData, inDesc, knDesc, biasDesc, + convDesc, actDesc, outDesc] = + createCuDNNDescriptor(op, record); + + // get workspace + stat = cudnnGetConvolutionBackwardDataWorkspaceSize( + context->cudnnHandle(), knDesc, inDesc, convDesc, outDesc, + ALGOS[record.algo], &record.workspaceSize); + if (stat != CUDNN_STATUS_SUCCESS) + continue; + + CudaPtr wsData = context->getWorkspace(record.workspaceSize); + float alpha = 1.f, beta = 0.f; + + stat = cudnnConvolutionBackwardData( + context->cudnnHandle(), &alpha, knDesc, knData, inDesc, + inData, convDesc, ALGOS[record.algo], wsData, + record.workspaceSize, &beta, outDesc, outData); + if (stat != CUDNN_STATUS_SUCCESS) + continue; + record.time = timeit( + [&]() { + cudnnConvolutionBackwardData( + context->cudnnHandle(), &alpha, knDesc, knData, + inDesc, inData, convDesc, ALGOS[record.algo], + wsData, record.workspaceSize, &beta, outDesc, + outData); + }, + [&]() { context->sync(); }); + // printf("mode:%d algo:%d :%.8lf\n", mode, algo, + // record.time); + + // Update the tune result + if (ret.time > record.time) + ret = record; + checkCudnnError(cudnnDestroyTensorDescriptor(outDesc)); + checkCudnnError(cudnnDestroyActivationDescriptor(actDesc)); + checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc)); + checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc)); + checkCudnnError(cudnnDestroyFilterDescriptor(knDesc)); + checkCudnnError(cudnnDestroyTensorDescriptor(inDesc)); + } + } + // printf("the best algo is %d, the best conv mode is %d\n", + // ret.algo, + // ret.mode); + IT_ASSERT(ret.time < std::numeric_limits::max(), "No valid " + "algorithm " + "found"); + return make_ref(ret); + } + + void compute(const Operator &_op, const PerfRecord &_record, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto record = as(_record); + auto context = dynamic_cast(_context); + bool success = cuDNNUnfused(op, *record, context); + IT_ASSERT(success); + } +}; + +REGISTER_KERNEL(Device::CUDA, OpType::ConvTrans, DataType::Float32, + convBackwardDataCudnn, "ConvTranposed_cuDNN_CUDA_Float32"); + +} // namespace infini \ No newline at end of file diff --git a/src/operators/conv.cc b/src/operators/conv.cc index 8c63f241..e8571f06 100644 --- a/src/operators/conv.cc +++ b/src/operators/conv.cc @@ -2,28 +2,24 @@ namespace infini { -ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, - int ph, int pw, int sh, int sw, int dh, int dw, - [[maybe_unused]] Tensor bias, ActType act) - : OperatorObj(OpType::Conv, {input, weight}, {output}), ph(ph), pw(pw), - sh(sh), sw(sw), dh(dh), dw(dw), act(act), padding(PaddingMode::Other) { - setAuxilaryAttributes(PaddingMode::Other); - IT_ASSERT(checkValid(graph)); -} - -ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, - PaddingMode mode, int sh, int sw, int dh, int dw, - [[maybe_unused]] Tensor bias, ActType act) - : OperatorObj(OpType::Conv, {input, weight}, {output}), ph(-1), pw(-1), - sh(sh), sw(sw), dh(dh), dw(dw), act(act), padding(mode) { +ConvBaseObj::ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output, + int ph, int pw, int sh, int sw, int dh, int dw, + const Tensor &inputInConvFWD, + const Tensor &weightInConvFWD) + : OperatorObj(opType, inputs, {output}), ph(ph), pw(pw), sh(sh), sw(sw), + dh(dh), dw(dw), padding(PaddingMode::Other) {} +ConvBaseObj::ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output, + PaddingMode mode, int sh, int sw, int dh, int dw, + const Tensor &inputInConvFWD, + const Tensor &weightInConvFWD) + : OperatorObj(opType, inputs, {output}), ph(-1), pw(-1), sh(sh), sw(sw), + dh(dh), dw(dw), padding(mode) { IT_ASSERT(mode != PaddingMode::Other); - setAuxilaryAttributes(mode); - IT_ASSERT(checkValid(graph)); } -string ConvObj::toString() const { +string ConvBaseObj::toString() const { std::ostringstream os; - os << "Conv[" << getGuid() << "]"; + os << OpRegistry::getOpName(getOpType()) << "[" << getGuid() << "]"; os << "("; if (inputs.size() == 2) { os << vecToString(inputs[0]->getDims()) << ","; @@ -32,13 +28,63 @@ string ConvObj::toString() const { os << "p=[" << ph << "," << pw << "],"; os << "s=[" << sh << "," << sw << "],"; os << "d=[" << dh << "," << dw << "],"; - os << "act=" << enum_to_underlying(act) << ","; + // os << "act=" << enum_to_underlying(act) << ","; os << "input=" << inputs[0]->getGuid() << ","; os << "weight=" << inputs[1]->getGuid() << ","; os << "output=" << outputs[0]->getGuid() << ")"; return os.str(); } +vector ConvBaseObj::getWorkloadVector() const { + return { + enum_to_underlying(type), n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw}; +} + +vector ConvBaseObj::getOpAttrVector() const { + IT_TODO_HALT(); // should padding mode / ph+pw be in attrs? + return {enum_to_underlying(type), c, f, r, s, ph, pw, sh, sw, dh, dw}; +} + +void ConvObj::setAuxilaryAttributes(PaddingMode mode) { + const Tensor &input = inputs[0]; + const Tensor &weight = inputs[1]; + n = input->getDims()[0], c = input->getDims()[1], h = input->getDims()[2], + w = input->getDims()[3], f = weight->getDims()[0], r = weight->getDims()[2], + s = weight->getDims()[3]; + if (mode == PaddingMode::Same) { + int oh = h / sh; + int ow = w / sw; + ph = (h - oh * sh + (r - sh) * dh) / 2; + pw = (w - ow * sw + (s - sw) * dw) / 2; + } else if (mode == PaddingMode::Valid) { + ph = pw = 0; + } +} + +ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, + int ph, int pw, int sh, int sw, int dh, int dw, Tensor bias, + ActType act) + : ConvBaseObj(OpType::Conv, {input, weight}, output, ph, pw, sh, sw, dh, dw, + input, weight), + act(act) { + if (bias) + IT_TODO_HALT(); + setAuxilaryAttributes(PaddingMode::Other); + IT_ASSERT(checkValid(graph)); +} + +ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, + PaddingMode mode, int sh, int sw, int dh, int dw, Tensor bias, + ActType act) + : ConvBaseObj(OpType::Conv, {input, weight}, output, mode, sh, sw, dh, dw, + input, weight), + act(act) { + if (bias) + IT_TODO_HALT(); + setAuxilaryAttributes(mode); + IT_ASSERT(checkValid(graph)); +} + optional> ConvObj::inferShape(const TensorVec &inputs) const { const auto &input = inputs[0], &weight = inputs[1]; auto n = input->getDims()[0]; @@ -70,23 +116,60 @@ optional> ConvObj::inferShape(const TensorVec &inputs) const { return {{{on, oc, oh, ow}}}; } -vector ConvObj::getWorkloadVector() const { - return { - enum_to_underlying(type), n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, - enum_to_underlying(act)}; +ConvTransposed2dObj::ConvTransposed2dObj(GraphObj *graph, Tensor input, + Tensor weight, Tensor output, int ph, + int pw, int sh, int sw, int dh, int dw, + int oph, int opw, int group, + Tensor bias, ActType act) + : ConvBaseObj(OpType::ConvTrans, {input, weight}, output, ph, pw, sh, sw, + dh, dw, output, weight), + oph(oph), opw(opw), group(group), act(act) { + if (bias) + IT_TODO_HALT(); + setAuxilaryAttributes(PaddingMode::Other); + IT_ASSERT(checkValid(graph)); } -vector ConvObj::getOpAttrVector() const { - IT_TODO_HALT(); // should padding mode / ph+pw be in attrs? - return {enum_to_underlying(type), c, f, r, s, ph, pw, sh, sw, dh, dw, - enum_to_underlying(act)}; +ConvTransposed2dObj::ConvTransposed2dObj(GraphObj *graph, Tensor input, + Tensor weight, Tensor output, + PaddingMode mode, int sh, int sw, + int dh, int dw, int oph, int opw, + int group, Tensor bias, ActType act) + : ConvBaseObj(OpType::ConvTrans, {input, weight}, output, mode, sh, sw, dh, + dw, output, weight), + oph(oph), opw(opw), group(group), act(act) { + if (bias) + IT_TODO_HALT(); + setAuxilaryAttributes(mode); + IT_ASSERT(checkValid(graph)); } -void ConvObj::setAuxilaryAttributes(PaddingMode mode) { - n = inputs[0]->getDims()[0], c = inputs[0]->getDims()[1], - h = inputs[0]->getDims()[2], w = inputs[0]->getDims()[3], - f = inputs[1]->getDims()[0], r = inputs[1]->getDims()[2], - s = inputs[1]->getDims()[3]; +optional> +ConvTransposed2dObj::inferShape(const TensorVec &inputs) const { + const Tensor &input = inputs[0], &weight = inputs[1]; + auto n = input->getDims()[0]; + auto f = input->getDims()[1]; + auto h = input->getDims()[2]; + auto w = input->getDims()[3]; + auto c = weight->getDims()[1]; + auto r = weight->getDims()[2]; + auto s = weight->getDims()[3]; + if (f != weight->getDims()[0]) + return {}; + + int on = n, oc = c * group; + int oh = 0, ow = 0; + oh = (h - 1) * sh - 2 * ph + dh * (r - 1) + oph + 1; + ow = (w - 1) * sw - 2 * pw + dw * (s - 1) + opw + 1; + return {{{on, oc, oh, ow}}}; +} + +void ConvTransposed2dObj::setAuxilaryAttributes(PaddingMode mode) { + const Tensor &input = inputs[0]; + const Tensor &weight = inputs[1]; + n = input->getDims()[0], f = input->getDims()[1], h = input->getDims()[2], + w = input->getDims()[3], c = weight->getDims()[0], r = weight->getDims()[2], + s = weight->getDims()[3]; if (mode == PaddingMode::Same) { int oh = h / sh; int ow = w / sw; diff --git a/test/operators/test_conv_transposed_2d.cc b/test/operators/test_conv_transposed_2d.cc new file mode 100644 index 00000000..1806e3a2 --- /dev/null +++ b/test/operators/test_conv_transposed_2d.cc @@ -0,0 +1,115 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/perf_engine.h" +#include "core/runtime.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/conv.h" + +#include "test.h" + +namespace infini { + +TEST(ConvTransposed, ShapeInference) { + Runtime runtime = CpuRuntimeObj::getInstance(); + { // No pad: InfoGAN ConvTranspose_0 + Graph g = make_ref(runtime); + Tensor i0 = g->addTensor({1, 228, 1, 1}); + Tensor w0 = g->addTensor({228, 448, 2, 2}); + auto conv = g->addOp(i0, w0, nullptr, 0, 0); + EXPECT_EQ(conv->getOutput()->getDims(), (Shape{1, 448, 2, 2})); + } + { // Padded, Strided: InfoGAN ConvTranspose_3 + Graph g = make_ref(runtime); + Tensor i0 = g->addTensor({1, 448, 2, 2}); + Tensor w0 = g->addTensor({448, 256, 4, 4}); + auto conv = g->addOp(i0, w0, nullptr, 1, 1, 2, 2); + EXPECT_EQ(conv->getOutput()->getDims(), (Shape{1, 256, 4, 4})); + } + { // With output padding: GCN ConvTranspose_224 + Graph g = make_ref(runtime); + Tensor i0 = g->addTensor({1, 21, 7, 7}); + Tensor w0 = g->addTensor({21, 21, 3, 3}); + auto conv = g->addOp(i0, w0, nullptr, 1, 1, 2, 2, + 1, 1, 1, 1); + EXPECT_EQ(conv->getOutput()->getDims(), (Shape{1, 21, 14, 14})); + } +} + +void testConvTransposedCudnn( + const std::function &generator, + vector ansVec) { + const auto &[N, C, H, W, F, R, S] = tuple{1, 1, 2, 2, 1, 4, 4}; + const int stride = 1, padding = 0, dilation = 1; + // Construct Runtime and graph for CPU and CUDA + Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton + Graph gCpu = make_ref(cpu); + Runtime cuda = make_ref(); + Graph gCuda = make_ref(cuda); + // Set input data on CPU in a CPU Graph + Tensor i0Cpu = gCpu->addTensor({N, F, H, H}, DataType::Float32); + Tensor w0Cpu = gCpu->addTensor({F, C, R, S}, DataType::Float32); + // Malloc data for all tensors in a graph. Do we need implicit allocation? + gCpu->dataMalloc(); + i0Cpu->setData(generator); + w0Cpu->setData(generator); + + // Copy input tensors from CPU to CUDA + Tensor i0Cuda = gCuda->cloneTensor(i0Cpu); + Tensor w0Cuda = gCuda->cloneTensor(w0Cpu); + // Build CUDA graph + auto conv = gCuda->addOp(i0Cuda, w0Cuda, nullptr, + padding, padding, stride, + stride, dilation, dilation); + gCuda->dataMalloc(); + // Execute on CUDA + cuda->run(gCuda); + // copy output from CUDA to CPU + auto o0Cpu = gCpu->cloneTensor(conv->getOutput()); + // check results on CPU + EXPECT_TRUE(o0Cpu->equalData(ansVec)); +} + +TEST(ConvTransposed, cuDNN) { + testConvTransposedCudnn(IncrementalGenerator(), + vector{0., 0., 1., 2., 3., 0., 6., + 12., 18., 16., 8., 30., 36., 42., + 32., 16., 54., 60., 66., 48., 24., + 62., 67., 72., 45.}); +} + +TEST(ConvTransposed, tune) { + Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton + Graph gCpu = make_ref(cpu); + Runtime cuda = make_ref(); + Graph gCuda = make_ref(cuda); + // Set input data on CPU in a CPU Graph + Tensor i0Cpu = gCpu->addTensor({1, 448, 2, 2}, DataType::Float32); + Tensor w0Cpu = gCpu->addTensor({448, 256, 4, 4}, DataType::Float32); + // Malloc data for all tensors in a graph. Do we need implicit allocation? + gCpu->dataMalloc(); + i0Cpu->setData(IncrementalGenerator()); + w0Cpu->setData(IncrementalGenerator()); + + // Copy input tensors from CPU to CUDA + Tensor i0Cuda = gCuda->cloneTensor(i0Cpu); + Tensor w0Cuda = gCuda->cloneTensor(w0Cpu); + // Build CUDA graph + auto conv = gCuda->addOp(i0Cuda, w0Cuda, nullptr); + // allocate CUDA memory + gCuda->dataMalloc(); + // Execute on CUDA + bool tune = true; + cuda->run(gCuda, tune); + // print a tensor/operator/graph by print() + gCuda->print(); + // check record + auto kernelAttrs = + KernelAttrs{Device::CUDA, conv->getOpType(), DataType::Float32}; + auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()}; + std::optional perfData = + PerfEngine::getInstance().getPerfData(perfKey); + ASSERT_TRUE(perfData.has_value()); +} + +} // namespace infini \ No newline at end of file