diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8220b3ea..ddb7f8e2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -14,7 +14,7 @@ set(DEFAULT_BUILD_TYPE "RelWithDebInfo")
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_EXTENSIONS OFF) # -std=gnu++11 when on, -std=c++11 when off
 
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror -Wno-error=deprecated-declarations")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror -Wno-error=deprecated-declarations -Wno-error=sign-compare")
 set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
 set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion
diff --git a/include/core/common.h b/include/core/common.h
index 7effff81..c9b8fe81 100644
--- a/include/core/common.h
+++ b/include/core/common.h
@@ -71,8 +71,9 @@ template <typename T> std::string vecToString(const std::vector<T> &vec) {
     return ret;
 }
 
-double timeit(const std::function<void(void)> &func, int warmupRounds = 200,
-              int timingRounds = 200,
-              const std::function<void(void)> &sync = {});
+double timeit(
+    const std::function<void(void)> &func,
+    const std::function<void(void)> &sync = []() {}, int warmupRounds = 200,
+    int timingRounds = 200);
 
 } // namespace infini
diff --git a/include/cuda/cuda_runtime.h b/include/cuda/cuda_runtime.h
index ee625699..ac2c3092 100644
--- a/include/cuda/cuda_runtime.h
+++ b/include/cuda/cuda_runtime.h
@@ -21,6 +21,7 @@ class CudaRuntimeObj : public RuntimeObj {
         workspace = alloc(workspaceSize);
     }
     virtual ~CudaRuntimeObj() {
+        dealloc(workspace);
         checkCudnnError(cudnnDestroy(cudnn));
         checkCublasError(cublasDestroy(cublas));
     }
@@ -57,6 +58,6 @@ class CudaRuntimeObj : public RuntimeObj {
     }
 
   private:
-    void runWithoutSync(const Graph &graph) const;
+    void runWithoutSync(const Graph &graph, bool tune, bool profiling) const;
 };
 } // namespace infini
\ No newline at end of file
diff --git a/src/core/common.cc b/src/core/common.cc
index 43192441..71f39244 100644
--- a/src/core/common.cc
+++ b/src/core/common.cc
@@ -4,8 +4,9 @@
 
 namespace infini {
 
-double timeit(const std::function<void(void)> &func, int warmupRounds,
-              int timingRounds, const std::function<void(void)> &sync) {
+double timeit(const std::function<void(void)> &func,
+              const std::function<void(void)> &sync, int warmupRounds,
+              int timingRounds) {
     for (int i = 0; i < warmupRounds; ++i)
         func();
     if (sync)
diff --git a/src/core/runtime.cc b/src/core/runtime.cc
index a97ef48b..93bd1ea7 100644
--- a/src/core/runtime.cc
+++ b/src/core/runtime.cc
@@ -40,17 +40,18 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
         // Tune the kernel if there is no record
         if (!perfData) {
             // TODO: record is not used
+            // printf("no record data\n");
             record = kernel->tune(op, this);
             perfEngine.setPerfData(perfKey, record);
         } else
             record = *perfData;
 
         if (!profiling) {
-            kernel->compute(op, *perfData, this);
+            kernel->compute(op, record, this);
             continue;
         } else {
-            double t =
-                timeit([&]() { kernel->compute(op, *perfData, this); }, 1, 1);
+            double t = timeit([&]() { kernel->compute(op, record, this); },
+                              []() {}, 1, 1);
             op->print();
             printf(" op_time %lf\n", t);
             totalTime += t;
diff --git a/src/cuda/cuda_runtime.cc b/src/cuda/cuda_runtime.cc
index 1e91db50..78249165 100644
--- a/src/cuda/cuda_runtime.cc
+++ b/src/cuda/cuda_runtime.cc
@@ -4,10 +4,13 @@
 
 namespace infini {
 
-void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
+void CudaRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
+                                    bool profiling = false) const {
     const auto &kernelRegistry = KernelRegistry::getInstance();
     auto perfEngine = PerfEngine::getInstance();
-
+    double totalTime = 0;
+    std::map<OpType, double> opTime;
+    std::map<OpType, int> opCnt;
     for (auto &op : graph->getOperators()) {
         // HACK: set correct data type
         auto kernelAttrs =
@@ -15,17 +18,40 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
         Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
         auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
         std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey);
-        if (perfData)
-            kernel->compute(op, *perfData, this);
-        else
+        if (!perfData && !tune) {
             kernel->compute(op, this);
+            continue;
+        }
+
+        PerfRecord record;
+
+        if (!perfData) {
+            record = kernel->tune(op, this);
+            perfEngine.setPerfData(perfKey, record);
+
+        } else
+            record = *perfData;
+
+        double t = record.time;
+        totalTime += t;
+
+        if (profiling) {
+            double t = timeit([&]() { kernel->compute(op, record, this); },
+                              [&]() { sync(); }, 1, 1);
+            op->print();
+            printf(" op_time on cuda %lf\n", t);
+            totalTime += t;
+            opTime[op->getOpType()] += t;
+            opCnt[op->getOpType()]++;
+        }
     }
 }
 
 void CudaRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
-    if (tune || profiling)
+    if (profiling)
         IT_TODO_HALT();
-    runWithoutSync(graph);
+
+    runWithoutSync(graph, tune);
     sync();
 }
diff --git a/src/kernels/cuda/conv.cc b/src/kernels/cuda/conv.cc
index c31868bb..a3c9e7ad 100644
--- a/src/kernels/cuda/conv.cc
+++ b/src/kernels/cuda/conv.cc
@@ -1,7 +1,10 @@
 #include "operators/conv.h"
 #include "core/kernel.h"
 #include "cuda/cuda_runtime.h"
-
+#include <chrono>
+#include <functional>
+#include <limits>
+#include <tuple>
 namespace infini {
 
 static constexpr int N_ALGO = 8;
@@ -14,18 +17,25 @@ static constexpr cudnnConvolutionFwdAlgo_t ALGOS[N_ALGO] = {
     CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
     CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
     CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED};
+static constexpr int N_MODE = 2;
+static constexpr cudnnConvolutionMode_t MODES[N_MODE] = {
+    CUDNN_CONVOLUTION, CUDNN_CROSS_CORRELATION};
 
 struct ConvCuDnnPerfRecord : public PerfRecord {
     int algo = 0; // cudnnConvolutionFwdAlgo_t
+    int mode = 1;
     size_t workspaceSize = 100000;
     bool fuseAct = false;
 };
 
 class convCudnn : public Kernel {
-    bool cuDNNUnfused(const Ref<ConvObj> &op, const ConvCuDnnPerfRecord &record,
-                      const CudaRuntimeObj *context) const {
-        cudnnStatus_t stat;
+    std::tuple<void *, void *, void *, cudnnTensorDescriptor_t,
+               cudnnFilterDescriptor_t, cudnnTensorDescriptor_t,
+               cudnnConvolutionDescriptor_t, cudnnActivationDescriptor_t,
+               cudnnTensorDescriptor_t>
+    cuDNNDescriptorAccess(const Ref<ConvObj> &op,
+                          const ConvCuDnnPerfRecord &record) const {
         void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const knData = (op->getInputs(1)->getRawDataPtr<void *>());
         if (op->getInputs().size() > 2) // Bias is not supported yet
@@ -52,7 +62,6 @@ class convCudnn : public Kernel {
         checkCudnnError(cudnnSetFilter4dDescriptor(knDesc, CUDNN_DATA_FLOAT,
                                                    CUDNN_TENSOR_NCHW, f,
                                                    channelsPerGrp, r, s));
-
         // get bias
         cudnnTensorDescriptor_t biasDesc;
         checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc));
@@ -64,7 +73,7 @@ class convCudnn : public Kernel {
         checkCudnnError(cudnnCreateConvolutionDescriptor(&convDesc));
         // TODO: CUDNN_CONVOLUTION is a tunable argument
         checkCudnnError(cudnnSetConvolution2dDescriptor(
-            convDesc, ph, pw, sh, sw, dh, dw, CUDNN_CROSS_CORRELATION,
+            convDesc, ph, pw, sh, sw, dh, dw, MODES[record.mode],
             CUDNN_DATA_FLOAT));
         if (g > 1) {
             checkCudnnError(cudnnSetConvolutionGroupCount(convDesc, g));
         }
@@ -93,7 +102,6 @@ class convCudnn : public Kernel {
             assert(false);
         }
-
         // get outputs
         int outn, outc, outh, outw;
         checkCudnnError(cudnnGetConvolution2dForwardOutputDim(
             convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw));
@@ -106,6 +114,15 @@ class convCudnn : public Kernel {
                       op->getOutput()->getDims(),
                   "cuDNN output shape mismatches with OP output shape");
+        return tuple(inData, knData, outData, inDesc, knDesc, biasDesc,
+                     convDesc, actDesc, outDesc);
+    }
+    bool cuDNNUnfused(const Ref<ConvObj> &op, const ConvCuDnnPerfRecord &record,
+                      const CudaRuntimeObj *context) const {
+        cudnnStatus_t stat;
+
+        auto [inData, knData, outData, inDesc, knDesc, biasDesc, convDesc,
+              actDesc, outDesc] = cuDNNDescriptorAccess(op, record);
         // get workspace
         size_t wsSize = record.workspaceSize;
         stat = cudnnGetConvolutionForwardWorkspaceSize(
@@ -113,9 +130,7 @@ class convCudnn : public Kernel {
             ALGOS[record.algo], &wsSize);
         if (stat != CUDNN_STATUS_SUCCESS)
             return false;
-        // assert(wsSize < (size_t)3 * 1024 * 1024 * 1024);
-        // if (wsSize >= (size_t)10 * 1024 * 1024 * 1024)
-        //     continue;
+
         CudaPtr wsData = context->getWorkspace(wsSize);
         float alpha = 1.f, beta = 0.f;
@@ -172,10 +187,12 @@ class convCudnn : public Kernel {
         // Destories in CUDA does not require sync. But cuDNN does not state
         // whether sync is required before destories.
-        checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
         checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
-        checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
+        checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
         checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
+        checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
+        checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
+        checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
         return true;
     }
@@ -186,9 +203,63 @@ class convCudnn : public Kernel {
     PerfRecord tune(const Operator &_op,
                     const RuntimeObj *_context) const override {
-        // TODO: real tuning
-        ConvCuDnnPerfRecord ret;
-        ret.time = timeit([&]() { compute(_op, _context); });
+        ConvCuDnnPerfRecord ret, tmp_ret;
+        ret.time = std::numeric_limits<double>::max();
+        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+        auto op = as<ConvObj>(_op);
+        // Try every possible data input mode of convolution func
+        for (int i = 0; i < N_MODE; i++) {
+            // Try every possible algorithm of convolution func
+            for (int j = 0; j < N_ALGO; j++) {
+                tmp_ret.algo = j;
+                tmp_ret.mode = i;
+                // Check if the kernel supports the op
+                cudnnStatus_t stat;
+                auto [inData, knData, outData, inDesc, knDesc, biasDesc,
+                      convDesc, actDesc, outDesc] =
+                    cuDNNDescriptorAccess(op, tmp_ret);
+
+                // get workspace
+                size_t wsSize = tmp_ret.workspaceSize;
+                stat = cudnnGetConvolutionForwardWorkspaceSize(
+                    context->cudnnHandle(), inDesc, knDesc, convDesc, outDesc,
+                    ALGOS[tmp_ret.algo], &wsSize);
+                if (stat != CUDNN_STATUS_SUCCESS)
+                    continue;
+
+                CudaPtr wsData = context->getWorkspace(wsSize);
+                float alpha = 1.f, beta = 0.f;
+
+                stat = cudnnConvolutionForward(
+                    context->cudnnHandle(), &alpha, inDesc, inData, knDesc,
+                    knData, convDesc, ALGOS[tmp_ret.algo], wsData, wsSize,
+                    &beta, outDesc, outData);
+                if (stat != CUDNN_STATUS_SUCCESS)
+                    continue;
+                tmp_ret.time = timeit(
+                    [&]() {
+                        cudnnConvolutionForward(
+                            context->cudnnHandle(), &alpha, inDesc, inData,
+                            knDesc, knData, convDesc, ALGOS[tmp_ret.algo],
+                            wsData, wsSize, &beta, outDesc, outData);
+                    },
+                    [&]() { context->sync(); });
+                printf("mode:%d algo:%d :%.8lf\n", i, j, tmp_ret.time);
+                // Update the tune result
+                if (ret.time > tmp_ret.time)
+                    ret = tmp_ret;
+
+                checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
+                checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
+                checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
+                checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
+                checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
+                checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
+            }
+        }
+        // Test information output
+        printf("the best algo is %d, the best conv mode is %d\n", ret.algo,
+               ret.mode);
         return ret;
     }
diff --git a/test/operators/test_conv.cc b/test/operators/test_conv.cc
index f6397c5c..db23cb73 100644
--- a/test/operators/test_conv.cc
+++ b/test/operators/test_conv.cc
@@ -1,8 +1,10 @@
 #include "core/graph.h"
+#include "core/kernel.h"
 #include "core/runtime.h"
 #include "cuda/cuda_runtime.h"
 #include "cuda/cuda_utility.h"
 #include "operators/conv.h"
+
 #include "test.h"
 
 namespace infini {
@@ -107,4 +109,30 @@ TEST(Conv, cuDNN) {
         IncrementalGenerator(),
         vector<float>{4794, 4386, 8199, 7506, 11274, 10542, 20835, 19656});
 }
+
+TEST(Conv, tune) {
+    Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
+    Graph gCpu = make_ref<GraphObj>(cpu);
+    Runtime cuda = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cuda);
+    // Set input data on CPU in a CPU Graph
+    Tensor i0Cpu = gCpu->addTensor({1, 3, 800, 800}, DataType::Float32);
+    Tensor w0Cpu = gCpu->addTensor({2, 3, 5, 5}, DataType::Float32);
+    // Malloc data for all tensors in a graph. Do we need implicit allocation?
+    gCpu->dataMalloc();
+    i0Cpu->setData(IncrementalGenerator());
+    w0Cpu->setData(IncrementalGenerator());
+
+    // Copy input tensors from CPU to CUDA
+    Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
+    Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
+    // Build CUDA graph
+    auto conv =
+        gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, 1, 1, 1, 1, 1, 1);
+    // allocate CUDA memory
+    gCuda->dataMalloc();
+    // Execute on CUDA
+    bool tune = true;
+    cuda->run(gCuda, tune);
+}
 } // namespace infini
\ No newline at end of file
diff --git a/test/script/env_lotus.sh b/test/script/env_lotus.sh
index a65a746d..d9c2e170 100644
--- a/test/script/env_lotus.sh
+++ b/test/script/env_lotus.sh
@@ -1,2 +1,3 @@
+. /home/spack/spack/share/spack/setup-env.sh
 spack load cuda@11.0.2 cudnn@8.0.3.33-11.0
 export CUDAHOSTCXX=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-9.4.0/gcc-9.4.0-st36klijpsnquihiy463hmedsyhoc3g6/bin/gcc
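Note on the reordered timeit interface: after this patch the synchronization callback is the second parameter (defaulting to a no-op lambda) and the warmup/timing round counts move to the end, so call sites written against the old (func, warmupRounds, timingRounds, sync) order must be updated, as the runtime.cc and conv.cc hunks above do. Below is a minimal sketch of a call against the new declaration in include/core/common.h; main(), counter, and the lambda bodies are illustrative placeholders, not code from this patch.

    #include "core/common.h"
    #include <cstdio>

    int main() {
        int counter = 0;
        // Work to be measured goes in the first lambda; the second argument is
        // the sync hook (a no-op here; the CUDA call sites above pass a
        // device-sync lambda such as [&]() { context->sync(); }).
        double t = infini::timeit([&]() { ++counter; },
                                  []() { /* nothing to sync on CPU */ },
                                  /*warmupRounds=*/10, /*timingRounds=*/100);
        std::printf("timeit returned %lf (counter = %d)\n", t, counter);
        return 0;
    }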