diff --git a/3rd-party/backward-cpp b/3rd-party/backward-cpp index 3bb9240c..f30744bc 160000 --- a/3rd-party/backward-cpp +++ b/3rd-party/backward-cpp @@ -1 +1 @@ -Subproject commit 3bb9240cb15459768adb3e7d963a20e1523a6294 +Subproject commit f30744bcf726ea3735df7ecf9e9de9ddac540283 diff --git a/3rd-party/googletest b/3rd-party/googletest index b796f7d4..e2239ee6 160000 --- a/3rd-party/googletest +++ b/3rd-party/googletest @@ -1 +1 @@ -Subproject commit b796f7d44681514f58a683a3a71ff17c94edb0c1 +Subproject commit e2239ee6043f73722e7aa812a459f54a28552929 diff --git a/3rd-party/nlohmann_json_cmake_fetchcontent b/3rd-party/nlohmann_json_cmake_fetchcontent index 13132dd3..6aebf092 160000 --- a/3rd-party/nlohmann_json_cmake_fetchcontent +++ b/3rd-party/nlohmann_json_cmake_fetchcontent @@ -1 +1 @@ -Subproject commit 13132dd361c8c5b5753983d5186cf54f689d90f9 +Subproject commit 6aebf09233951e4ce30a63919186a70b2b195756 diff --git a/3rd-party/pybind11 b/3rd-party/pybind11 index 0bd8896a..1e3400b6 160000 --- a/3rd-party/pybind11 +++ b/3rd-party/pybind11 @@ -1 +1 @@ -Subproject commit 0bd8896a4010f2d91b2340570c24fa08606ec406 +Subproject commit 1e3400b6742288429f2069aaf5febf92d0662dae diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index 3cf9c38b..a1859520 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -57,6 +57,11 @@ class GraphHandlerObj { Tensor convTransposed2d(Tensor input, Tensor weight, Tensor output, int ph, int pw, int sh, int sw, int dh, int dw, int oph, int opw); + Tensor convNHWC(Tensor input, Tensor weight, Tensor output, int ph, int pw, + int sh, int sw, int dh, int dw); + Tensor convTransposed2dNHWC(Tensor input, Tensor weight, Tensor output, + int ph, int pw, int sh, int sw, int dh, int dw, + int oph, int opw); Tensor matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB, Tensor bias, ActType act); Tensor batchNorm(Tensor input, Tensor output, Tensor mean, Tensor var, diff --git a/include/core/operator.h b/include/core/operator.h index 18382e7d..4cd6ae8b 100644 --- a/include/core/operator.h +++ b/include/core/operator.h @@ -11,6 +11,7 @@ enum class OpType { Matmul, ConvTrans, ConvTransNHWC, + ConvNHWC, G2BMM, GBMM, Pad, @@ -102,7 +103,10 @@ enum class OpType { Dropout, // MemBound = 300, - Any, + // + Conv2dReduce = 400, + Conv2dReduceTranspose, + Any }; using KernelAttrs = std::tuple; @@ -123,6 +127,7 @@ class OpRegistry { FOP(Matmul); FOP(ConvTrans); FOP(ConvTransNHWC); + FOP(ConvNHWC); FOP(G2BMM); FOP(GBMM); FOP(Pad); @@ -210,6 +215,9 @@ class OpRegistry { FOP(BitRightShift); // FOP(MemBound); + // + FOP(Conv2dReduce); + FOP(Conv2dReduceTranspose); FOP(Any); default: IT_ASSERT(false, "Unknown OpType " + diff --git a/include/cuda/cuda_runtime.h b/include/cuda/cuda_runtime.h index 6b9b8839..0a05444b 100644 --- a/include/cuda/cuda_runtime.h +++ b/include/cuda/cuda_runtime.h @@ -1,7 +1,6 @@ #pragma once #include "core/runtime.h" #include "cuda/cuda_common.h" -#include "nnet/dbg.h" namespace infini { diff --git a/include/nnet/nmutator.h b/include/nnet/nmutator.h index 91e1a4f5..e9946944 100644 --- a/include/nnet/nmutator.h +++ b/include/nnet/nmutator.h @@ -66,6 +66,9 @@ class NMutator : public Mutator { Graph transformGbmm(Operator op); Graph transformDialtedConv(Operator _op); Graph transformConv1xk(Operator op); + // Graph transformConv1xk(Operator op); + Graph transformConvToGEMMReduce(Operator _op); + Graph transformConvTranposeToGEMMReduce(Operator _op); Tensor splitTransposeMerge(Graph g, Tensor A, int dim, int chunkSize, Tensor output = nullptr); diff --git a/include/nnet/test_models.h b/include/nnet/test_models.h index b51b326a..c12faf51 100644 --- a/include/nnet/test_models.h +++ b/include/nnet/test_models.h @@ -6,6 +6,7 @@ namespace infini { Graph getGANGraph(int batch, Runtime runtime, int nLayers, int modelId); +Graph getFSRCNNGraph(int batch, Runtime runtime); Graph getLongformer(Runtime runtime, int bs); vector runInfoGAN(int nLayers); Graph getConvtransposedNHWC(Runtime runtime, Shape shape, int layerId); diff --git a/include/operators/conv.h b/include/operators/conv.h index 449f4334..5d485938 100644 --- a/include/operators/conv.h +++ b/include/operators/conv.h @@ -111,7 +111,7 @@ class ConvBaseObj : public OperatorObj { auto getNCHWFRS() const { return tuple(n, c, h, w, f, r, s); } auto getPadStrideDilation() const { return tuple(ph, pw, sh, sw, dh, dw); } int getChannelPerGroup() const { - if (type == OpType::ConvTransNHWC) { + if (type == OpType::ConvTransNHWC || type == OpType::ConvNHWC) { return inputs[1]->getDims()[3]; } else { return inputs[1]->getDims()[1]; @@ -149,6 +149,25 @@ class ConvObj : public ConvBaseObj { void setAuxilaryAttributes(PaddingMode mode) override; }; +class ConvNHWCObj : public ConvBaseObj { + public: + ConvNHWCObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, int ph, + int pw, int sh = 1, int sw = 1, int dh = 1, int dw = 1, + Tensor bias = nullptr, ActType act = ActType::None); + // Constructors for setting padding mode + ConvNHWCObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, + PaddingMode mode = PaddingMode::Same, int sh = 1, int sw = 1, + int dh = 1, int dw = 1, Tensor bias = nullptr, + ActType act = ActType::None); + OP_CLONE(ConvNHWCObj); + + optional> inferShape(const TensorVec &inputs) const override; + int getNumGroups() const override { return c / getChannelPerGroup(); } + + private: + void setAuxilaryAttributes(PaddingMode mode) override; +}; + class ConvBackwardFilterObj : public ConvBaseObj { private: ActType act; @@ -220,6 +239,7 @@ class ConvTransposed2dNHWCObj : public ConvBaseObj { optional> inferShape(const TensorVec &inputs) const override; int getNumGroups() const override { return group; } + std::pair getOutputPadding() const { return {oph, opw}; } private: void setAuxilaryAttributes(PaddingMode mode) override; diff --git a/include/operators/conv2dreduce.h b/include/operators/conv2dreduce.h new file mode 100644 index 00000000..4db547ef --- /dev/null +++ b/include/operators/conv2dreduce.h @@ -0,0 +1,62 @@ +#pragma once +#include "core/operator.h" + +namespace infini { + +class Conv2dReduceBase : public OperatorObj { + protected: + Tensor bias; + int ph, pw; + int sh, sw; + int dh, dw; + int n, h, w, f, r, s; // c has been reduced + bool PReLU; + float paramReLU; + + public: + Conv2dReduceBase(OpType opType, Tensor input, Tensor bias, Tensor output, + bool PReLU_, float paramReLU_, int ph_, int pw_, + int sh_ = 1, int sw_ = 1, int dh_ = 1, int dw_ = 1); + + std::string toString() const override; + int numInputs() const override { return 2; } + int numOutputs() const override { return 1; } + + int getDh() const { return dh; } + int getDw() const { return dw; } + int getPh() const { return ph; } + int getPw() const { return pw; } + int getSh() const { return sh; } + int getSw() const { return sw; } + bool getPReLU() const { return PReLU; } + float getParamReLU() const { return paramReLU; } + + Tensor getBias() const { return bias; } + + // optional> inferShape(const TensorVec &inputs) const + // override; + + private: + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; +}; + +class Conv2dReduce : public Conv2dReduceBase { + public: + Conv2dReduce(GraphObj *graph, Tensor input, Tensor bias, Tensor output, + bool PReLU_, float paramReLU_, int ph_, int pw_, int sh_ = 1, + int sw_ = 1, int dh_ = 1, int dw_ = 1); + OP_CLONE(Conv2dReduce); + optional> inferShape(const TensorVec &inputs) const override; +}; + +class Conv2dReduceTranspose : public Conv2dReduceBase { + public: + Conv2dReduceTranspose(GraphObj *graph, Tensor input, Tensor bias, + Tensor output, bool PReLU_, float paramReLU_, int ph_, + int pw_, int sh_ = 1, int sw_ = 1, int dh_ = 1, + int dw_ = 1); + OP_CLONE(Conv2dReduceTranspose); + optional> inferShape(const TensorVec &inputs) const override; +}; +} // namespace infini \ No newline at end of file diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 83a4fceb..2aeb86cf 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -709,7 +709,7 @@ class OnnxStub: ctx.push_output(f"{name}_{i}_{it.guid()}", it) for (i, it) in enumerate(op.outputs()) ] - if ty == backend.OpType.Conv: + if ty == backend.OpType.Conv or ty == backend.OpType.ConvNHWC: ph, pw, dh, dw, sh, sw = backend.conv_attrs_of(op) ctx.push_node( make_node( @@ -723,7 +723,7 @@ class OnnxStub: group=op.inputs()[0].shape()[1] // op.inputs()[1].shape()[1], ) ) - elif ty == backend.OpType.ConvTrans: + elif ty == backend.OpType.ConvTrans or ty == backend.OpType.ConvTransNHWC: ph, pw, sh, sw, dh, dw, oph, opw = backend.conv_trans_attrs_of(op) ctx.push_node( make_node( @@ -895,6 +895,26 @@ class OnnxStub: domain="nnet", ) ) + elif ty == backend.OpType.Conv2dReduce: + ctx.push_node( + make_node( + ty.name, + inputs, + outputs, + name, + domain="nnet", + ) + ) + elif ty == backend.OpType.Conv2dReduceTranspose: + ctx.push_node( + make_node( + ty.name, + inputs, + outputs, + name, + domain="nnet", + ) + ) elif ty == backend.OpType.MemBound: ctx.push_node( make_node( diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index d19d8fe7..bf6b3b2b 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -57,6 +57,38 @@ Tensor GraphHandlerObj::convTransposed2d(Tensor input, Tensor weight, } } +Tensor GraphHandlerObj::convNHWC(Tensor input, Tensor weight, Tensor output, + int ph, int pw, int sh, int sw, int dh, + int dw) { + if (output) { + g->addOpWithOutputs(std::move(input), std::move(weight), + output, ph, pw, sh, sw, dh, dw); + return output; + } else { + return g + ->addOp(std::move(input), std::move(weight), output, + ph, pw, sh, sw, dh, dw) + ->getOutput(); + } +} + +Tensor GraphHandlerObj::convTransposed2dNHWC(Tensor input, Tensor weight, + Tensor output, int ph, int pw, + int sh, int sw, int dh, int dw, + int oph, int opw) { + if (output) { + g->addOpWithOutputs( + std::move(input), std::move(weight), output, ph, pw, sh, sw, dh, dw, + oph, opw); + return output; + } else { + return g->addOp(std::move(input), + std::move(weight), output, ph, + pw, sh, sw, dh, dw, oph, opw) + ->getOutput(); + } +} + Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA, bool transB, Tensor bias, ActType act) { if (y) { diff --git a/src/core/operator.cc b/src/core/operator.cc index 47bef3df..733ed018 100644 --- a/src/core/operator.cc +++ b/src/core/operator.cc @@ -26,7 +26,8 @@ bool OperatorObj::isConcatOp() const { return type == OpType::Concat; } bool OperatorObj::isComputeOp() const { return type == OpType::Conv || type == OpType::Matmul || type == OpType::ConvTrans || type == OpType::ConvTransNHWC || - type == OpType::G2BMM || type == OpType::GBMM; + type == OpType::G2BMM || type == OpType::GBMM || + type == OpType::ConvNHWC; } bool OperatorObj::isTransposeOp() const { return type == OpType::Transpose; } diff --git a/src/core/search_engine.cc b/src/core/search_engine.cc index 03561f17..657176b2 100644 --- a/src/core/search_engine.cc +++ b/src/core/search_engine.cc @@ -1,6 +1,7 @@ #include "core/search_engine.h" #include "core/hash.h" #include "core/runtime.h" +#include "ffi/ffi_callback.h" #include "nnet/dbg.h" #include @@ -348,8 +349,8 @@ std::vector SearchEngine::searchMutation(const MetaGraph &metaGraph) { // // HACK: only try the first one for debug // if (mutatedGraphs.size() > 2) // mutatedGraphs.resize(2); - if (mutatedGraphs.size() >= 2) - mutatedGraphs = {mutatedGraphs[1]}; + // if (mutatedGraphs.size() >= 2) + // mutatedGraphs = {mutatedGraphs[1]}; for (auto graph : graphs) { for (auto mutatedGraph : mutatedGraphs) { std::vector ops; @@ -455,6 +456,9 @@ std::vector SearchEngine::partitionGraph(const Graph graph) { } double SearchEngine::getEstimatedGraphPerf(Graph graph) { + // dbg(graph); + // // hkz + // callback::exportONNX(graph, "a.onnx"); return runtimeExec->getPerfTime(graph, false, true, true); } @@ -502,6 +506,7 @@ Graph SearchEngine::fuseVertically(const Graph &graph) { auto bestGraph = make_ref(runtimeExec, chainOps); // Eliminate transpose and reshape operators + // FIXME: current Relu only support 3D and 4D tensors if (auto eliminatedGraph = mutator->eliminateVertically( make_ref(runtimeExec, chainOps))) bestGraph = eliminatedGraph; diff --git a/src/cuda/cuda_runtime.cc b/src/cuda/cuda_runtime.cc index 6d526eb9..da65104a 100644 --- a/src/cuda/cuda_runtime.cc +++ b/src/cuda/cuda_runtime.cc @@ -3,6 +3,7 @@ #include "core/perf_engine.h" #include "core/runtime.h" #include "cuda_profiler_api.h" +#include "nnet/dbg.h" #include "operators/conv.h" #include "operators/matmul.h" #ifdef INFINI_USE_TVM diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index e9be7fff..2e53af8a 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -69,6 +69,7 @@ void export_values(py::module &m) { .VALUE(OpType, Matmul) .VALUE(OpType, ConvTrans) .VALUE(OpType, ConvTransNHWC) + .VALUE(OpType, ConvNHWC) .VALUE(OpType, G2BMM) .VALUE(OpType, GBMM) .VALUE(OpType, Pad) @@ -100,6 +101,8 @@ void export_values(py::module &m) { .VALUE(OpType, Abs) .VALUE(OpType, Resize) .VALUE(OpType, Dropout) + .VALUE(OpType, Conv2dReduce) + .VALUE(OpType, Conv2dReduceTranspose) .VALUE(OpType, MemBound) .VALUE(OpType, Any) .export_values(); @@ -144,17 +147,32 @@ static Ref intelcpu_runtime() { return make_ref(); } #endif static std::tuple conv_attrs_of(Operator op) { - IT_ASSERT(op->getOpType() == OpType::Conv); - auto conv = dynamic_cast(op.get()); + IT_ASSERT(op->getOpType() == OpType::Conv || + op->getOpType() == OpType::ConvNHWC); + auto conv = dynamic_cast(op.get()); return std::make_tuple(conv->getPh(), conv->getPw(), conv->getDh(), conv->getDw(), conv->getSh(), conv->getSw()); } static std::tuple conv_trans_attrs_of(Operator op) { - IT_ASSERT(op->getOpType() == OpType::ConvTrans); - auto conv = dynamic_cast(op.get()); - auto [oph, opw] = conv->getOutputPadding(); + IT_ASSERT(op->getOpType() == OpType::ConvTrans || + op->getOpType() == OpType::ConvTransNHWC); + auto conv = dynamic_cast(op.get()); + int oph, opw; + + if (op->getOpType() == OpType::ConvTrans) { + auto _conv = dynamic_cast(op.get()); + auto output_pad = _conv->getOutputPadding(); + oph = output_pad.first; + opw = output_pad.second; + } else { + auto _conv = dynamic_cast(op.get()); + auto output_pad = _conv->getOutputPadding(); + oph = output_pad.first; + opw = output_pad.second; + } + return std::make_tuple(conv->getPh(), conv->getPw(), conv->getDh(), conv->getDw(), conv->getSh(), conv->getSw(), oph, opw); @@ -328,6 +346,9 @@ void init_graph_builder(py::module &m) { "tensor_type"_a = TensorType::Other) .def("conv", &Handler::conv, policy::move) .def("convTransposed2d", &Handler::convTransposed2d, policy::move) + .def("convNHWC", &Handler::convNHWC, policy::move) + .def("convtransposed2dNHWC", &Handler::convTransposed2dNHWC, + policy::move) .def("matmul", &Handler::matmul, policy::move) .def("batchNorm", &Handler::batchNorm, policy::move) .def("maxPool", &Handler::maxPool, policy::move) @@ -386,6 +407,7 @@ void export_test_model(py::module &m) { #ifdef USE_CUDA m.def("runInfoGAN", &runInfoGAN) .def("getGANGraph", &getGANGraph) + .def("getFSRCNNGraph", &getFSRCNNGraph) .def("getLongformer", &getLongformer) .def("getConvtransposedNHWC", &getConvtransposedNHWC) .def("optimizeGraph", &optimizeGraph, "graph"_a, "runtime"_a, diff --git a/src/kernels/cuda/conv.cc b/src/kernels/cuda/conv.cc index c020ed33..cf214cb3 100644 --- a/src/kernels/cuda/conv.cc +++ b/src/kernels/cuda/conv.cc @@ -52,7 +52,7 @@ class convCudnn : public Kernel { cudnnFilterDescriptor_t, cudnnTensorDescriptor_t, cudnnConvolutionDescriptor_t, cudnnActivationDescriptor_t, cudnnTensorDescriptor_t> - createCuDNNDescriptor(const Ref &op, + createCuDNNDescriptor(const Ref &op, const ConvCuDnnPerfRecord &record) const { void *const inData = (op->getInputs(0)->getRawDataPtr()); void *const knData = (op->getInputs(1)->getRawDataPtr()); @@ -68,15 +68,23 @@ class convCudnn : public Kernel { int channelsPerGrp = cpg, channels = c; + // set input format + cudnnTensorFormat_t tensorFormat = (op->getOpType() == OpType::ConvNHWC) + ? CUDNN_TENSOR_NHWC + : CUDNN_TENSOR_NCHW; + // get inputs cudnnTensorDescriptor_t inDesc; checkCudnnError(cudnnCreateTensorDescriptor(&inDesc)); checkCudnnError(cudnnSetTensor4dDescriptor( - inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, channels, h, w)); + inDesc, tensorFormat, CUDNN_DATA_FLOAT, n, channels, h, w)); // get kernels cudnnFilterDescriptor_t knDesc; checkCudnnError(cudnnCreateFilterDescriptor(&knDesc)); + // FIXME: filter data layout is not changed with input data layout + // since FCRS shows better performance for NHWC inputs in some cases. + // This should be tunable. checkCudnnError(cudnnSetFilter4dDescriptor(knDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, f, channelsPerGrp, r, s)); @@ -84,7 +92,7 @@ class convCudnn : public Kernel { cudnnTensorDescriptor_t biasDesc; checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc)); checkCudnnError(cudnnSetTensor4dDescriptor( - biasDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, f, 1, 1)); + biasDesc, tensorFormat, CUDNN_DATA_FLOAT, 1, f, 1, 1)); // get convlution descriptor cudnnConvolutionDescriptor_t convDesc; @@ -125,18 +133,25 @@ class convCudnn : public Kernel { convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw)); cudnnTensorDescriptor_t outDesc; checkCudnnError(cudnnCreateTensorDescriptor(&outDesc)); - checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW, - CUDNN_DATA_FLOAT, outn, outc, - outh, outw)); - IT_ASSERT((vector{outn, outc, outh, outw}) == - op->getOutput()->getDims(), - "cuDNN output shape mismatches with OP output shape"); + checkCudnnError(cudnnSetTensor4dDescriptor( + outDesc, tensorFormat, CUDNN_DATA_FLOAT, outn, outc, outh, outw)); + + if (op->getOpType() == OpType::ConvNHWC) { + IT_ASSERT((vector{outn, outh, outw, outc}) == + op->getOutput()->getDims(), + "cuDNN output shape mismatches with OP output shape"); + } else { + IT_ASSERT((vector{outn, outc, outh, outw}) == + op->getOutput()->getDims(), + "cuDNN output shape mismatches with OP output shape"); + } return tuple(inData, knData, outData, inDesc, knDesc, biasDesc, convDesc, actDesc, outDesc); } - bool cuDNNUnfused(const Ref &op, const ConvCuDnnPerfRecord &record, + bool cuDNNUnfused(const Ref &op, + const ConvCuDnnPerfRecord &record, const CudaRuntimeObj *context) const { cudnnStatus_t stat; @@ -220,11 +235,12 @@ class convCudnn : public Kernel { ConvCuDnnPerfRecordObj ret; ret.time = std::numeric_limits::max(); auto context = dynamic_cast(_context); - auto op = as(_op); + auto op = as(_op); + int try_algo = op->getOpType() == OpType::ConvNHWC ? 2 : N_ALGO; // Both modes have the same performance. Only run cross-correlation. for (int mode = 1; mode < 2; mode++) { // Try every possible algorithm of convolution - for (int algo = 0; algo < N_ALGO; algo++) { + for (int algo = 0; algo < try_algo; algo++) { auto recordRef = make_ref(); auto &record = *recordRef; record.mode = mode; @@ -283,7 +299,7 @@ class convCudnn : public Kernel { void compute(const Operator &_op, const PerfRecord &_record, const RuntimeObj *_context) const override { - auto op = as(_op); + auto op = as(_op); auto record = as(_record); auto context = dynamic_cast(_context); bool success = cuDNNUnfused(op, record, context); @@ -294,5 +310,8 @@ class convCudnn : public Kernel { REGISTER_KERNEL(Device::CUDA, OpType::Conv, DataType::Float32, convCudnn, "Conv_cuDNN_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::ConvNHWC, DataType::Float32, convCudnn, + "ConvNHWC_cuDNN_CUDA_Float32"); + REGISTER_CONSTRUCTOR(1, ConvCuDnnPerfRecordObj::from_json); } // namespace infini diff --git a/src/kernels/cuda/conv2dreduce.cc b/src/kernels/cuda/conv2dreduce.cc new file mode 100644 index 00000000..aaf1b5f6 --- /dev/null +++ b/src/kernels/cuda/conv2dreduce.cc @@ -0,0 +1,44 @@ +#include "operators/conv2dreduce.h" +#include "cuda/cuda_conv2dreduce.h" +#include "cuda/cuda_kernel_wihtout_config.h" +#include "cuda/cuda_runtime.h" + +namespace infini { + +class Conv2dReduceCuda : public CudaKernelWithoutConfig { + void compute(const Operator &_op, const RuntimeObj *_context) const { + auto op = as(_op); + float *const input = (op->getInputs(0)->getRawDataPtr()); + float *const bias = + op->getBias() ? (op->getBias()->getRawDataPtr()) : nullptr; + float *const output = (op->getOutput()->getRawDataPtr()); + + auto dim = op->getInputs(0)->getDims(); + int n = dim[0], h = dim[1], w = dim[2], f = dim[3], r = dim[4], + s = dim[5]; + int dh = op->getDh(), dw = op->getDw(); + int sh = op->getSh(), sw = op->getSw(); + int ph = op->getPh(), pw = op->getPw(); + auto odim = op->getOutput()->getDims(); + int oh = odim[1], ow = odim[2]; + bool PReLU = op->getPReLU(); + // float paramReLU = op->getParamReLU(); + + auto opType = op->getOpType(); + + if (opType == OpType::Conv2dReduce) { + conv2dreduce_kernel(input, bias, output, PReLU, n, h, w, f, r, s, + oh, ow, ph, pw, sh, sw, dh, dw); + } else { + convTranspose2dreduce_kernel(input, bias, output, PReLU, n, h, w, f, + r, s, oh, ow, ph, pw, sh, sw, dh, dw); + } + } +}; + +REGISTER_KERNEL(Device::CUDA, OpType::Conv2dReduce, DataType::Float32, + Conv2dReduceCuda, "Conv2dReduce_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Conv2dReduceTranspose, DataType::Float32, + Conv2dReduceCuda, "Conv2dReduceTranspose_CUDA_Float32"); + +} // namespace infini diff --git a/src/kernels/cuda/conv2dreduce.cu b/src/kernels/cuda/conv2dreduce.cu index 49747137..a7026c16 100644 --- a/src/kernels/cuda/conv2dreduce.cu +++ b/src/kernels/cuda/conv2dreduce.cu @@ -1,4 +1,5 @@ #include "cuda/cuda_common.h" +#include "nnet/dbg.h" using dtype = float; @@ -40,18 +41,71 @@ __global__ void conv2dreduce_kernel_(float *__restrict__ input, output[nid * (oh * ow * f) + hid * (ow * f) + wid * f + fid] = imm; } } +__global__ void convTranspose2dreduce_kernel2_( + float *__restrict__ input, float *__restrict__ bias, + float *__restrict__ output, const bool PReLU, const int n, const int f, + const int h, const int w, const int oh, const int ow, const int r, + const int s, const int ph, const int pw, const int dh, const int dw, + const int sh, const int sw) { + int warp_id = (blockDim.x / 32) * blockIdx.x + threadIdx.x / 32; + int lane = threadIdx.x % 32; + int nid = warp_id / (f * oh * ow); + int fid = (warp_id - nid * (f * oh * ow)) / (oh * ow); + int hid = (warp_id - nid * (f * oh * ow) - fid * (oh * ow)) / ow; + int wid = warp_id % ow; + if (hid >= oh || wid >= ow || nid > n || fid > f) + return; + + const int fchunck = r * s, wchunk = f * fchunck, hchunk = w * wchunk, + nchunck = h * hchunk; + float *nfinput = input + nid * nchunck + fid * fchunck; + // view as conv, the true ph and pw + int tph = r - ph - 1, tpw = s - pw - 1; + int th = (h - 1) * sh + 1, tw = (w - 1) * sw + 1; + + float imm = 0.0; + int ihst = hid - tph; + int iwst = wid - tpw; + for (int idx = lane; idx < r * s; idx += 32) { + int ri = idx / s; + int si = idx % s; + int ihid = ihst + r - ri - 1; + int iwid = iwst + s - si - 1; + if (ihid >= 0 && ihid < th && iwid >= 0 && iwid < tw && + (ihid % sh == 0) && (iwid % sw == 0)) { + imm += *(nfinput + (ihid / sh) * hchunk + (iwid / sw) * wchunk + + ri * s + si); + } + } + + for (int k = 16; k > 0; k >>= 1) { + imm += __shfl_down_sync(0xffffffff, imm, k); // sum + } + if (lane == 0) { + if (bias) { + imm += bias[fid]; + } + if (PReLU) { + imm = imm > 0.0 ? imm : 0.0; + } + output[nid * (oh * ow * f) + hid * (ow * f) + wid * f + fid] = imm; + } +} __global__ void convTranspose2dreduce_kernel_( float *__restrict__ input, float *__restrict__ bias, float *__restrict__ output, const bool PReLU, const int n, const int f, const int h, const int w, const int oh, const int ow, const int r, const int s, const int ph, const int pw, const int dh, const int dw, - const int sh, const int sw) { + const int sh, const int sw, const int block_x_num, const int block_y_num) { // assert dh = dw = 1 - int nid = blockIdx.x, fid = blockIdx.y; - int hid = threadIdx.x, wid = threadIdx.y; + int nid = blockIdx.x / block_x_num, fid = blockIdx.y / block_y_num; + int hid = (blockIdx.x % block_x_num) * blockDim.x + threadIdx.x, + wid = (blockIdx.y % block_y_num) * blockDim.y + threadIdx.y; + if (hid >= oh || wid >= ow) + return; const int fchunck = r * s, wchunk = f * fchunck, hchunk = w * wchunk, - nchunck = n * hchunk; + nchunck = h * hchunk; float *nfinput = input + nid * nchunck + fid * fchunck; // view as conv, the true ph and pw int tph = r - ph - 1, tpw = s - pw - 1; @@ -162,8 +216,22 @@ void convTranspose2dreduce_kernel(float *input, float *bias, float *output, reduce_4x4<<<(M * N + 127) / 128, 128>>>(input, output, act, n, f, oh, ow, h, w); } else { - puts("why use this conv2dreduce"); - convTranspose2dreduce_kernel_<<>>( + // puts("why use this conv2dreduce"); + // block.x = 32; + // block.y = 32; + // int block_x_num = (oh + block.x - 1) / block.x; + // int block_y_num = (ow + block.y - 1) / block.y; + // grid.x = n * (block_x_num); + // grid.y = f * (block_y_num); + // convTranspose2dreduce_kernel_<<>>( + // input, bias, output, (bool)act, n, f, h, w, oh, ow, r, s, ph, pw, + // dh, dw, sh, sw, block_x_num, block_y_num); + + block.x = 128; + block.y = 1; + grid.x = (n * f * ow * oh + block.x / 32 - 1) / (block.x / 32); + grid.y = 1; + convTranspose2dreduce_kernel2_<<>>( input, bias, output, (bool)act, n, f, h, w, oh, ow, r, s, ph, pw, dh, dw, sh, sw); } diff --git a/src/kernels/cuda/unary.cc b/src/kernels/cuda/unary.cc index 05b02e88..ee4e590b 100644 --- a/src/kernels/cuda/unary.cc +++ b/src/kernels/cuda/unary.cc @@ -31,8 +31,9 @@ class ActivationCudnn : public CudaKernelWithoutConfig { n = dim[0], c = dim[1], h = dim[2], w = dim[3]; } else if (dim.size() == 3) { n = 1, c = dim[0], h = dim[1], w = dim[2]; - } else + } else { IT_TODO_HALT(); + } // get inputs checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc)); diff --git a/src/nnet/App/test_models.cc b/src/nnet/App/test_models.cc index 0780ef4e..6b832ca2 100644 --- a/src/nnet/App/test_models.cc +++ b/src/nnet/App/test_models.cc @@ -86,6 +86,60 @@ Graph getGANGraph(int batch, Runtime runtime, int nLayers, int modelId) { return g; } +// NHWC +Graph getFSRCNNGraph(int batch, Runtime runtime) { + // n, c, h, w, f, r, s, stride, pad, dilation, has_pReLU + const DetailedConfigs fsrcnn_config = { + {batch, 1, 32, 32, 56, 5, 5, 1, 2, 1, true}, + {batch, 56, 32, 32, 12, 1, 1, 1, 0, 1, true}, + {batch, 12, 32, 32, 12, 3, 3, 1, 1, 1, false}, + {batch, 12, 32, 32, 12, 3, 3, 1, 1, 1, false}, + {batch, 12, 32, 32, 12, 3, 3, 1, 1, 1, false}, + {batch, 12, 32, 32, 12, 3, 3, 1, 1, 1, true}, + {batch, 12, 32, 32, 56, 1, 1, 1, 0, 1, true}, + {batch, 56, 32, 32, 1, 9, 9, 1, 3, 4, false} // ConvTransNHWC + // n, f, h, w, c, r, s, stride, pad, dilation, has_pReLU + }; + + Graph g = make_ref(runtime); + + Tensor input; + { + auto &[n, c, h, w, f, r, s, stride, pad, dilation, has_pReLU] = + fsrcnn_config[0]; + input = g->addTensor({batch, h, w, c}, DataType::Float32, + TensorType::Input); + } + + for (int i = 0; i < (int)fsrcnn_config.size() - 1; ++i) { + // auto [channel, kernelSize, pad, stride, tanh] = configs[i]; + auto &[n, c, h, w, f, r, s, stride, pad, dilation, has_pReLU] = + fsrcnn_config[i]; + IT_ASSERT(input->getDims()[3] == c); + auto weight = g->addTensor({f, r, s, c}, DataType::Float32, + TensorType::Initialized); // f, r, s, c + input = g->addOp(input, weight, nullptr, pad, pad, stride, + stride, 1, 1) + ->getOutput(); + if (has_pReLU) { + input = g->addOp(input, nullptr)->getOutput(); + } + } + + // last operator is a ConvTransNHWC + { + auto &[n, f, h, w, c, r, s, stride, pad, dilation, has_pReLU] = + fsrcnn_config[fsrcnn_config.size() - 1]; + IT_ASSERT(input->getDims()[3] == f); + auto weight = g->addTensor({f, r, s, c}, DataType::Float32, + TensorType::Initialized); // f, r, s, c + input = g->addOp(input, weight, nullptr, pad, + pad, stride, stride, 1, 1) + ->getOutput(); + } + return g; +} + Graph getLongformer(Runtime runtime, int bs) { const int seqlen = 10000, w = 1000, featlen = 512, heads = 8, d = 4; const int hidden = featlen, hiddenPerHead = hidden / heads; @@ -165,8 +219,8 @@ Graph getLongformer(Runtime runtime, int bs) { g->addOpWithOutputs(q0, q1); g->addOpWithOutputs(k0, k1); g->addOpWithOutputs(v0, v1); - // For example, when perm=(1, 0, 2), given an input tensor of shape (1, 2, - // 3), the output shape will be (2, 1, 3). + // For example, when perm=(1, 0, 2), given an input tensor of shape (1, + // 2, 3), the output shape will be (2, 1, 3). g->addOpWithOutputs(q1, q2, vector{0, 2, 1, 3}); g->addOpWithOutputs(k1, k2, vector{0, 2, 1, 3}); g->addOpWithOutputs(v1, v2, vector{0, 2, 1, 3}); diff --git a/src/nnet/nmutator.cc b/src/nnet/nmutator.cc index 5ede0a79..06510e50 100644 --- a/src/nnet/nmutator.cc +++ b/src/nnet/nmutator.cc @@ -11,6 +11,7 @@ #include "operators/GBMM.h" #include "operators/any.h" #include "operators/conv.h" +#include "operators/conv2dreduce.h" #include "operators/matmul.h" #include "operators/membound.h" #include "operators/reduce_mean.h" @@ -98,6 +99,10 @@ void NMutator::runSingleOp(Graph in_graph, std::vector &out_graphs) { out_graphs.emplace_back(g); return; } + if (infini::Graph g = transformConv1xk(computeOps[0])) { + out_graphs.emplace_back(g); + return; + } if (Graph g = transformG2bmm(computeOps[0])) { out_graphs.emplace_back(g); return; @@ -110,7 +115,12 @@ void NMutator::runSingleOp(Graph in_graph, std::vector &out_graphs) { out_graphs.emplace_back(g); return; } - if (infini::Graph g = transformConv1xk(computeOps[0])) { + if (infini::Graph g = transformConvToGEMMReduce(computeOps[0])) { + out_graphs.emplace_back(g); + return; + } + + if (infini::Graph g = transformConvTranposeToGEMMReduce(computeOps[0])) { out_graphs.emplace_back(g); return; } @@ -522,7 +532,9 @@ Graph NMutator::transformConvtransposed1x1(Operator _op) { if (h != 1 || w != 1) return {}; IT_ASSERT_TODO(ph == pw); - IT_ASSERT_TODO(tie(sh, sw) == tuple(1, 1)); + if (tie(sh, sw) != tuple(1, 1)) { + return nullptr; + } IT_ASSERT_TODO(tie(dh, dw) == tuple(1, 1)); auto g = make_ref(runtime); // NHWF @@ -543,6 +555,80 @@ Graph NMutator::transformConvtransposed1x1(Operator _op) { return g; } +Graph NMutator::transformConvToGEMMReduce(Operator _op) { + auto op = as(_op); + if (!op) + return nullptr; + const auto &A = op->getInputs()[0]; + const auto &W = op->getInputs()[1]; + const auto &[n, c, h, w, f, r, s] = op->getNCHWFRS(); + const auto &[ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); + const Shape inputDims = op->getInputs(0)->getDims(); + const Shape weightDims = op->getInputs(1)->getDims(); + const Shape outputDims = op->getOutput()->getDims(); + IT_ASSERT(weightDims[0] == f); + IT_ASSERT(weightDims[1] == r); + IT_ASSERT(weightDims[2] == s); + IT_ASSERT(weightDims[3] == c); + IT_ASSERT(inputDims[0] == n); + IT_ASSERT(inputDims[1] == h); + IT_ASSERT(inputDims[2] == w); + IT_ASSERT(inputDims[3] == c); + const DataType dtype = A->getDType(); + auto g = make_ref(runtime); + auto newA = g->addTensor( + {inputDims[0] * inputDims[1] * inputDims[2], inputDims[3]}, dtype); + + // // If use Matmul with transpose 0,0 + // auto newW = g->addTensor( + // {weightDims[3], weightDims[0] * weightDims[1] * weightDims[2]}, dtype); + + // If use Matmul with transpose 0, 1 + auto newW = g->addTensor( + {weightDims[0] * weightDims[1] * weightDims[2], weightDims[3]}, + dtype); + g->addOpWithOutputs(g->cloneTensor(A), newA, newA->getDims()); + g->addOpWithOutputs(g->cloneTensor(W), newW, newW->getDims()); + Tensor newO = g->addOp(newA, newW, nullptr, 0, 1)->getOutput(); + auto new1 = g->addTensor({n, h, w, f, r, s}, dtype); + g->addOpWithOutputs(newO, new1, new1->getDims()); + g->addOpWithOutputs( + new1, nullptr, g->cloneTensor(op->getOutput()), false, 0.f, ph, pw); + return g; +} + +Graph NMutator::transformConvTranposeToGEMMReduce(Operator _op) { + auto op = as(_op); + if (!op) + return nullptr; + const auto &A = op->getInputs()[0]; + const auto &W = op->getInputs()[1]; + // f is the de-facto input channel for ConvTranspose + const auto &[n, c, h, w, f, r, s] = op->getNCHWFRS(); + const auto &[ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); + const Shape inputDims = op->getInputs(0)->getDims(); + const Shape weightDims = op->getInputs(1)->getDims(); + const Shape outputDims = op->getOutput()->getDims(); + const DataType dtype = A->getDType(); + auto g = make_ref(runtime); + auto newA = g->addTensor( // [N,H,W,F] + {inputDims[0] * inputDims[1] * inputDims[2], inputDims[3]}, dtype); + auto newW = g->addTensor( // [F, CRS] + {weightDims[0], weightDims[1] * weightDims[2] * weightDims[3]}, + dtype); // HACK: this should be a transpose + + g->addOpWithOutputs(g->cloneTensor(A), newA, newA->getDims()); + g->addOpWithOutputs(g->cloneTensor(W), newW, newW->getDims()); + // newO [NHW, CRS] + Tensor newO = g->addOp(newA, newW, nullptr, 0, 0)->getOutput(); + auto new1 = g->addTensor({n, h, w, c, r, s}, dtype); + g->addOpWithOutputs(newO, new1, new1->getDims()); + // [NHW, CRS] -> [N,H,W,C] + g->addOpWithOutputs( + new1, nullptr, g->cloneTensor(op->getOutput()), false, 0.f, ph, pw); + return g; +} + // Graph NMutator::transformConvtransposed(Operator _op) { // auto op = as(_op); // if (!op) diff --git a/src/operators/conv.cc b/src/operators/conv.cc index 36d97081..45e7a873 100644 --- a/src/operators/conv.cc +++ b/src/operators/conv.cc @@ -114,6 +114,75 @@ optional> ConvObj::inferShape(const TensorVec &inputs) const { return {{{on, oc, oh, ow}}}; } +void ConvNHWCObj::setAuxilaryAttributes(PaddingMode mode) { + const Tensor &input = inputs[0]; + const Tensor &weight = inputs[1]; + n = input->getDims()[0], c = input->getDims()[3], h = input->getDims()[1], + w = input->getDims()[2], f = weight->getDims()[0], r = weight->getDims()[1], + s = weight->getDims()[2]; + if (mode == PaddingMode::Same) { + int oh = h / sh; + int ow = w / sw; + ph = (h - oh * sh + (r - sh) * dh) / 2; + pw = (w - ow * sw + (s - sw) * dw) / 2; + } else if (mode == PaddingMode::Valid) { + ph = pw = 0; + } +} + +ConvNHWCObj::ConvNHWCObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, + int ph, int pw, int sh, int sw, int dh, int dw, Tensor bias, + ActType act) + : ConvBaseObj(OpType::ConvNHWC, {input, weight}, output, ph, pw, sh, sw, dh, dw, + input, weight, act) { + if (bias) + IT_TODO_HALT(); + setAuxilaryAttributes(PaddingMode::Other); + IT_ASSERT(checkValid(graph)); +} + +ConvNHWCObj::ConvNHWCObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, + PaddingMode mode, int sh, int sw, int dh, int dw, Tensor bias, + ActType act) + : ConvBaseObj(OpType::ConvNHWC, {input, weight}, output, mode, sh, sw, dh, dw, + input, weight, act) { + if (bias) + IT_TODO_HALT(); + setAuxilaryAttributes(mode); + IT_ASSERT(checkValid(graph)); +} + +optional> ConvNHWCObj::inferShape(const TensorVec &inputs) const { + const auto &input = inputs[0], &weight = inputs[1]; + auto n = input->getDims()[0]; + auto h = input->getDims()[1]; + auto w = input->getDims()[2]; + auto f = weight->getDims()[0]; + auto r = weight->getDims()[1]; + auto s = weight->getDims()[2]; + int on = n, oc = f; + int oh = 0, ow = 0; + // For NCHW+FCRS layout, C of input is divisable by C of weight + if (input->getDims()[3] % weight->getDims()[3] != 0) + return {}; + // Set padding size + if (padding == PaddingMode::Other) { + oh = (h - (r - sh) * dh + ph * 2) / sh; + ow = (w - (s - sw) * dw + pw * 2) / sw; + } else if (padding == PaddingMode::Same) { + oh = h / sh; + ow = w / sw; + // ph = (h - oh * sh + (r - sh) * dh) / 2; + // pw = (w - ow * sw + (s - sw) * dw) / 2; + } else if (padding == PaddingMode::Valid) { + int ph = 0; + int pw = 0; + oh = (h - (r - sh) * dh + ph * 2) / sh; + ow = (w - (s - sw) * dw + pw * 2) / sw; + } + return {{{on, oh, ow, oc}}}; +} + ConvTransposed2dObj::ConvTransposed2dObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, int ph, int pw, int sh, int sw, int dh, int dw, diff --git a/src/operators/conv2dreduce.cc b/src/operators/conv2dreduce.cc new file mode 100644 index 00000000..86d3700d --- /dev/null +++ b/src/operators/conv2dreduce.cc @@ -0,0 +1,98 @@ +#include "operators/conv2dreduce.h" + +namespace infini { + +Conv2dReduceBase::Conv2dReduceBase(OpType opType, Tensor input, Tensor bias_, + Tensor output, bool PReLU_, float paramReLU_, + int ph_, int pw_, int sh_, int sw_, int dh_, + int dw_) + : OperatorObj(opType, {input}, {output}), bias(bias_), ph(ph_), pw(pw_), + sh(sh_), sw(sw_), dh(dh_), dw(dw_), PReLU(PReLU_), paramReLU(paramReLU_) { + // expect input shape is (n, h, w, f, r, s) + auto inputShape = input->getDims(); + IT_ASSERT(inputShape.size() == 6); + n = inputShape[0]; + h = inputShape[1]; + w = inputShape[2]; + f = inputShape[3]; + r = inputShape[4]; + s = inputShape[5]; + + if (bias) { + auto biasShape = bias->getDims(); + IT_ASSERT(biasShape.size() == 1); + IT_ASSERT(biasShape[0] == f); + } +} + +std::string Conv2dReduceBase::toString() const { + std::ostringstream os; + os << OpRegistry::getOpName(getOpType()) << "[" << getGuid() << "]"; + os << "("; + if (inputs.size() == 2) { + os << vecToString(inputs[0]->getDims()) << ","; + os << vecToString(inputs[1]->getDims()) << ","; + } else { + os << vecToString(inputs[0]->getDims()) << ","; + } + os << "p=[" << ph << "," << pw << "],"; + os << "s=[" << sh << "," << sw << "],"; + os << "d=[" << dh << "," << dw << "],"; + os << "PReLU=" << (PReLU ? "true" : "false") << ","; + // os << "act=" << enum_to_underlying(act) << ","; + os << "input=" << inputs[0]->getGuid() << ","; + if (bias != nullptr) { + os << "bias=" << bias->getGuid() << ","; + } + os << "output=" << outputs[0]->getGuid() << ")"; + return os.str(); +} + +std::vector Conv2dReduceBase::getWorkloadVector() const { + return {enum_to_underlying(type), n, h, w, f, r, s, ph, pw, sh, sw, dh, dw}; +} + +std::vector Conv2dReduceBase::getOpAttrVector() const { + return {enum_to_underlying(type), ph, pw, sh, sw, dh, dw}; +} + +Conv2dReduce::Conv2dReduce(GraphObj *graph, Tensor input, Tensor bias, + Tensor output, bool PReLU_, float paramReLU_, + int ph_, int pw_, int sh_, int sw_, int dh_, int dw_) + : Conv2dReduceBase(OpType::Conv2dReduce, input, bias, output, PReLU_, + paramReLU_, ph_, pw_, sh_, sw_, dh_, dw_) { + IT_ASSERT(checkValid(graph)); +} + +optional> +Conv2dReduce::inferShape(const TensorVec &inputs) const { + // const auto &input = inputs[0], &bias = inputs[1]; + int on = n, of = f; + int oh = (h + ph * 2 - dh * (r - 1) - 1) / sh + 1; + int ow = (w + pw * 2 - dw * (s - 1) - 1) / sw + 1; + + return {{{on, oh, ow, of}}}; +} + +Conv2dReduceTranspose::Conv2dReduceTranspose(GraphObj *graph, Tensor input, + Tensor bias, Tensor output, + bool PReLU_, float paramReLU_, + int ph_, int pw_, int sh_, int sw_, + int dh_, int dw_) + : Conv2dReduceBase(OpType::Conv2dReduceTranspose, input, bias, output, + PReLU_, paramReLU_, ph_, pw_, sh_, sw_, dh_, dw_) { + IT_ASSERT(dh_ == 1); + IT_ASSERT(dw_ == 1); + IT_ASSERT(checkValid(graph)); +} + +optional> +Conv2dReduceTranspose::inferShape(const TensorVec &inputs) const { + // const auto &input = inputs[0], &bias = inputs[1]; + int on = n, of = f; + int oh = (h - 1) * sh - 2 * ph + dh * (r - 1) + 1; + int ow = (w - 1) * sw - 2 * pw + dw * (s - 1) + 1; + + return {{{on, oh, ow, of}}}; +} +} // namespace infini diff --git a/test/kernels/cuda/test_cuda_conv.cc b/test/kernels/cuda/test_cuda_conv.cc index 657ecd17..8ad466e7 100644 --- a/test/kernels/cuda/test_cuda_conv.cc +++ b/test/kernels/cuda/test_cuda_conv.cc @@ -43,6 +43,42 @@ void testConvCudnn( gCuda->print(); } +void testConvNHWCCudnn( + const std::function &generator, + vector ansVec) { + // Construct Runtime and graph for CPU and CUDA + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton + Graph gCpu = make_ref(cpu); + Runtime cuda = make_ref(); + Graph gCuda = make_ref(cuda); + // Set input data on CPU in a CPU Graph + Tensor i0Cpu = gCpu->addTensor({1, 4, 4, 3}, DataType::Float32); + Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float32); + // Malloc data for all tensors in a graph. Do we need implicit allocation? + gCpu->dataMalloc(); + i0Cpu->setData(generator); + w0Cpu->setData(generator); + + // Copy input tensors from CPU to CUDA + Tensor i0Cuda = gCuda->cloneTensor(i0Cpu); + Tensor w0Cuda = gCuda->cloneTensor(w0Cpu); + // Build CUDA graph + auto conv = + gCuda->addOp(i0Cuda, w0Cuda, nullptr, 1, 1, 2, 1, 1, 2); + // allocate CUDA memory + gCuda->dataMalloc(); + // Execute on CUDA + cuda->run(gCuda); + // copy output from CUDA to CPU + auto o0Cpu = gCpu->cloneTensor(conv->getOutput()); + o0Cpu->print(); + o0Cpu->printData(); + // check results on CPU + EXPECT_TRUE(o0Cpu->equalData(ansVec)); + // print a tensor/operator/graph by print() + gCuda->print(); +} + TEST(cuDNN_Conv, run) { testConvCudnn(OneGenerator(), vector{12, 12, 18, 18, 12, 12, 18, 18}); @@ -51,6 +87,14 @@ TEST(cuDNN_Conv, run) { vector{4794, 4386, 8199, 7506, 11274, 10542, 20835, 19656}); } +TEST(cuDNN_Conv, runNHWC) { + testConvNHWCCudnn(OneGenerator(), + vector{12., 12., 12., 12., 18., 18., 18., 18.}); + testConvNHWCCudnn( + IncrementalGenerator(), + vector{3350, 7562, 2306, 5546, 9480, 24546, 7185, 20793}); +} + TEST(cuDNN_Conv, tune) { Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton Graph gCpu = make_ref(cpu); diff --git a/test/nnet/run_models_nnet.py b/test/nnet/run_models_nnet.py index 5aae2166..2a730f02 100644 --- a/test/nnet/run_models_nnet.py +++ b/test/nnet/run_models_nnet.py @@ -30,7 +30,7 @@ def run_and_evaluate(runtime, g): runtime.run(g, True) print(f'getPerfTime = {runtime.getPerfTime(g, True, False, False)}') print(f'Non-ctc time = {runtime.timeNonCtcOperators(g, 1000, 1000)}') - print(f'Cuda graph time = {runtime.timeWithCudaGraph(g)}') + print(f'Cuda graph time = {runtime.timeWithCudaGraph(g, 100)}') def run_graph_get_output_as_torch_tensor(runtime, g): @@ -111,6 +111,23 @@ def construct_conv(runtime, n, c, h, w, f, r, s, pad, stride, dilation): handler.conv(input, w, None, pad, pad, stride, stride, dilation, dilation) return handler.getGraph() +def construct_conv_nhwc(runtime, n, c, h, w, f, r, s, pad, stride, dilation): + handler = ft.GraphHandler(runtime) + # input = handler.tensor([1, 56, 32, 32], tensor_type=ft.TensorType.Input) + # w = handler.tensor([12, 56, 1, 1], tensor_type=ft.TensorType.Initialized) + # handler.conv(input, w, None, 0, 0, 1, 1, 1, 1) + input = handler.tensor([n, h, w, c], tensor_type=ft.TensorType.Input) + w = handler.tensor([f, r, s, c], tensor_type=ft.TensorType.Initialized) + handler.convNHWC(input, w, None, pad, pad, stride, stride, dilation, dilation) + return handler.getGraph() + +def construct_convtranposed_nhwc(runtime, n, c, h, w, f, r, s, pad, stride, dilation): + handler = ft.GraphHandler(runtime) + input = handler.tensor([n, h, w, c], tensor_type=ft.TensorType.Input) + w = handler.tensor([f, r, s, c], tensor_type=ft.TensorType.Initialized) + handler.convtransposed2dNHWC(input, w, None, pad, pad, stride, stride, dilation, dilation) + return handler.getGraph() + def export_op_level_onnx(runtime): graphs = [ @@ -134,9 +151,13 @@ if __name__ == "__main__": # (construct_conv(runtime, 1, 12, 32, 32, 12, 3, 3, 1, 1, 1), 'conv3x3'), # FSRCNN Conv_4 3x3 # ft.getGANGraph(batch, runtime, 5, 1) # (ft.getLongformer(runtime, 1), 'longformer.bs1'), - (ft.getLongformer(runtime, 16), 'longformer.bs16'), + # (ft.getLongformer(runtime, 16), 'longformer.bs16'), # construct_convTranspose2d(runtime) # (load_onnx(runtime, '/mnt/auxHome/models/einnet/fsrcnn.bs1.onnx'), 'fsrcnn.bs1'), + # (ft.getFSRCNNGraph(1, runtime), "fsrcnn.bs1"), + # (ft.getFSRCNNGraph(16, runtime), "fsrcnn.bs16"), + # (construct_conv_nhwc(runtime, 1, 56, 32, 32, 12, 1, 1, 0, 1, 1), 'conv1x1') + (load_onnx(runtime, '/mnt/auxHome/models/einnet/gcn.bs1.onnx'), 'gcn.bs1'), ] for original_g, name in graphs: