cuDNN conv tuning (#16)

* Function tune and corresponding testcase.

*Add: Tune function in /src/kernel/cuda/conv.cc and corresponding testcase in test_conv.

*Fix: A little bug of perfRecord using in /src/core/runtime.cc.

* Tune part debug

*Add: recover the code, fixed the commit error.

*Add: some anotations in tune function

* clang formmat test

* Fix: mem leak in CUDA Runtime and Conv

* Fix: sync in conv and default sync in timeit

* Change the way to tune operator conv.

Timeit function cudNNUnfused -> Timeit function cudnnConvolutionForward.

* Change: merge the common part of cudnnunfused&tune into cudnndescriptoraccess

* clang test

* clang-format

* clang-format bash.

* Chore: remove print and blank lines

Co-authored-by: wcz112 <wcz19@mails.tsinghua.edu.cn>
Co-authored-by: Liyan Zheng <liyan-zheng@outlook.com>
This commit is contained in:
Anmuliar 2022-08-29 21:37:07 +08:00 committed by GitHub
parent e076991f2f
commit bd63f738dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 162 additions and 36 deletions

View File

@ -71,8 +71,9 @@ template <typename T> std::string vecToString(const std::vector<T> &vec) {
return ret;
}
double timeit(const std::function<void()> &func, int warmupRounds = 200,
int timingRounds = 200,
const std::function<void(void)> &sync = {});
double timeit(
const std::function<void()> &func,
const std::function<void(void)> &sync = []() {}, int warmupRounds = 200,
int timingRounds = 200);
} // namespace infini

View File

@ -21,6 +21,7 @@ class CudaRuntimeObj : public RuntimeObj {
workspace = alloc(workspaceSize);
}
virtual ~CudaRuntimeObj() {
dealloc(workspace);
checkCudnnError(cudnnDestroy(cudnn));
checkCublasError(cublasDestroy(cublas));
}
@ -57,6 +58,6 @@ class CudaRuntimeObj : public RuntimeObj {
}
private:
void runWithoutSync(const Graph &graph) const;
void runWithoutSync(const Graph &graph, bool tune, bool profiling) const;
};
} // namespace infini

View File

@ -4,8 +4,9 @@
namespace infini {
double timeit(const std::function<void()> &func, int warmupRounds,
int timingRounds, const std::function<void(void)> &sync) {
double timeit(const std::function<void()> &func,
const std::function<void(void)> &sync, int warmupRounds,
int timingRounds) {
for (int i = 0; i < warmupRounds; ++i)
func();
if (sync)

View File

@ -40,17 +40,18 @@ void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
// Tune the kernel if there is no record
if (!perfData) {
// TODO: record is not used
// printf("no record data\n");
record = kernel->tune(op, this);
perfEngine.setPerfData(perfKey, record);
} else
record = *perfData;
if (!profiling) {
kernel->compute(op, *perfData, this);
kernel->compute(op, record, this);
continue;
} else {
double t =
timeit([&]() { kernel->compute(op, *perfData, this); }, 1, 1);
double t = timeit([&]() { kernel->compute(op, record, this); },
[]() {}, 1, 1);
op->print();
printf(" op_time %lf\n", t);
totalTime += t;

View File

@ -4,10 +4,13 @@
namespace infini {
void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
void CudaRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
bool profiling = false) const {
const auto &kernelRegistry = KernelRegistry::getInstance();
auto perfEngine = PerfEngine::getInstance();
double totalTime = 0;
std::map<OpType, double> opTime;
std::map<OpType, int> opCnt;
for (auto &op : graph->getOperators()) {
// HACK: set correct data type
auto kernelAttrs =
@ -15,17 +18,38 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey);
if (perfData)
kernel->compute(op, *perfData, this);
else
if (!perfData && !tune) {
kernel->compute(op, this);
continue;
}
PerfRecord record;
if (!perfData) {
record = kernel->tune(op, this);
perfEngine.setPerfData(perfKey, record);
} else
record = *perfData;
double t = record.time;
totalTime += t;
if (profiling) {
double t = timeit([&]() { kernel->compute(op, record, this); },
[&]() { sync(); }, 1, 1);
op->print();
printf(" op_time on cuda %lf\n", t);
totalTime += t;
opTime[op->getOpType()] += t;
opCnt[op->getOpType()]++;
}
}
}
void CudaRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
if (tune || profiling)
if (profiling)
IT_TODO_HALT();
runWithoutSync(graph);
runWithoutSync(graph, tune);
sync();
}

View File

@ -1,7 +1,10 @@
#include "operators/conv.h"
#include "core/kernel.h"
#include "cuda/cuda_runtime.h"
#include <chrono>
#include <functional>
#include <limits>
#include <tuple>
namespace infini {
static constexpr int N_ALGO = 8;
@ -14,18 +17,25 @@ static constexpr cudnnConvolutionFwdAlgo_t ALGOS[N_ALGO] = {
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED};
static constexpr int N_MODE = 2;
static constexpr cudnnConvolutionMode_t MODES[N_MODE] = {
CUDNN_CONVOLUTION, CUDNN_CROSS_CORRELATION};
struct ConvCuDnnPerfRecord : public PerfRecord {
int algo = 0; // cudnnConvolutionFwdAlgo_t
int mode = 1;
size_t workspaceSize = 100000;
bool fuseAct = false;
};
class convCudnn : public Kernel {
bool cuDNNUnfused(const Ref<ConvObj> &op, const ConvCuDnnPerfRecord &record,
const CudaRuntimeObj *context) const {
cudnnStatus_t stat;
std::tuple<void *, void *, void *, cudnnTensorDescriptor_t,
cudnnFilterDescriptor_t, cudnnTensorDescriptor_t,
cudnnConvolutionDescriptor_t, cudnnActivationDescriptor_t,
cudnnTensorDescriptor_t>
createCuDNNDescriptor(const Ref<ConvObj> &op,
const ConvCuDnnPerfRecord &record) const {
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const knData = (op->getInputs(1)->getRawDataPtr<void *>());
if (op->getInputs().size() > 2) // Bias is not supported yet
@ -52,7 +62,6 @@ class convCudnn : public Kernel {
checkCudnnError(cudnnSetFilter4dDescriptor(knDesc, CUDNN_DATA_FLOAT,
CUDNN_TENSOR_NCHW, f,
channelsPerGrp, r, s));
// get bias
cudnnTensorDescriptor_t biasDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc));
@ -64,7 +73,7 @@ class convCudnn : public Kernel {
checkCudnnError(cudnnCreateConvolutionDescriptor(&convDesc));
// TODO: CUDNN_CONVOLUTION is a tunable argument
checkCudnnError(cudnnSetConvolution2dDescriptor(
convDesc, ph, pw, sh, sw, dh, dw, CUDNN_CROSS_CORRELATION,
convDesc, ph, pw, sh, sw, dh, dw, MODES[record.mode],
CUDNN_DATA_FLOAT));
if (g > 1) {
checkCudnnError(cudnnSetConvolutionGroupCount(convDesc, g));
@ -93,7 +102,6 @@ class convCudnn : public Kernel {
assert(false);
}
// get outputs
int outn, outc, outh, outw;
checkCudnnError(cudnnGetConvolution2dForwardOutputDim(
convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw));
@ -106,16 +114,18 @@ class convCudnn : public Kernel {
op->getOutput()->getDims(),
"cuDNN output shape mismatches with OP output shape");
// get workspace
return tuple(inData, knData, outData, inDesc, knDesc, biasDesc,
convDesc, actDesc, outDesc);
}
bool cuDNNUnfused(const Ref<ConvObj> &op, const ConvCuDnnPerfRecord &record,
const CudaRuntimeObj *context) const {
cudnnStatus_t stat;
const auto &[inData, knData, outData, inDesc, knDesc, biasDesc,
convDesc, actDesc, outDesc] =
createCuDNNDescriptor(op, record);
size_t wsSize = record.workspaceSize;
stat = cudnnGetConvolutionForwardWorkspaceSize(
context->cudnnHandle(), inDesc, knDesc, convDesc, outDesc,
ALGOS[record.algo], &wsSize);
if (stat != CUDNN_STATUS_SUCCESS)
return false;
// assert(wsSize < (size_t)3 * 1024 * 1024 * 1024);
// if (wsSize >= (size_t)10 * 1024 * 1024 * 1024)
// continue;
CudaPtr wsData = context->getWorkspace(wsSize);
float alpha = 1.f, beta = 0.f;
@ -172,10 +182,12 @@ class convCudnn : public Kernel {
// Destories in CUDA does not require sync. But cuDNN does not state
// whether sync is required before destories.
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
return true;
}
@ -186,9 +198,66 @@ class convCudnn : public Kernel {
PerfRecord tune(const Operator &_op,
const RuntimeObj *_context) const override {
// TODO: real tuning
ConvCuDnnPerfRecord ret;
ret.time = timeit([&]() { compute(_op, _context); });
ret.time = std::numeric_limits<double>::max();
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
auto op = as<ConvObj>(_op);
// Both modes have the same performance. Only run cross-correlation.
for (int mode = 1; mode < 2; mode++) {
// Try every possible algorithm of convolution
for (int algo = 0; algo < N_ALGO; algo++) {
ConvCuDnnPerfRecord record;
record.mode = mode;
record.algo = algo;
cudnnStatus_t stat;
const auto &[inData, knData, outData, inDesc, knDesc, biasDesc,
convDesc, actDesc, outDesc] =
createCuDNNDescriptor(op, record);
// get workspace
stat = cudnnGetConvolutionForwardWorkspaceSize(
context->cudnnHandle(), inDesc, knDesc, convDesc, outDesc,
ALGOS[record.algo], &record.workspaceSize);
if (stat != CUDNN_STATUS_SUCCESS)
continue;
CudaPtr wsData = context->getWorkspace(record.workspaceSize);
float alpha = 1.f, beta = 0.f;
stat = cudnnConvolutionForward(
context->cudnnHandle(), &alpha, inDesc, inData, knDesc,
knData, convDesc, ALGOS[record.algo], wsData,
record.workspaceSize, &beta, outDesc, outData);
if (stat != CUDNN_STATUS_SUCCESS)
continue;
record.time = timeit(
[&]() {
cudnnConvolutionForward(context->cudnnHandle(), &alpha,
inDesc, inData, knDesc, knData,
convDesc, ALGOS[record.algo],
wsData, record.workspaceSize,
&beta, outDesc, outData);
},
[&]() { context->sync(); });
// printf("mode:%d algo:%d :%.8lf\n", mode, algo, record.time);
// Update the tune result
if (ret.time > record.time)
ret = record;
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
}
}
// printf("the best algo is %d, the best conv mode is %d\n", ret.algo,
// ret.mode);
IT_ASSERT(ret.time < std::numeric_limits<double>::max(), "No valid "
"algorithm "
"found");
return ret;
}

View File

@ -1,8 +1,10 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/conv.h"
#include "test.h"
namespace infini {
@ -107,4 +109,30 @@ TEST(Conv, cuDNN) {
IncrementalGenerator(),
vector<float>{4794, 4386, 8199, 7506, 11274, 10542, 20835, 19656});
}
TEST(Conv, tune) {
Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
Graph gCpu = make_ref<GraphObj>(cpu);
Runtime cuda = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cuda);
// Set input data on CPU in a CPU Graph
Tensor i0Cpu = gCpu->addTensor({1, 3, 224, 224}, DataType::Float32);
Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float32);
// Malloc data for all tensors in a graph. Do we need implicit allocation?
gCpu->dataMalloc();
i0Cpu->setData(IncrementalGenerator());
w0Cpu->setData(IncrementalGenerator());
// Copy input tensors from CPU to CUDA
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
// Build CUDA graph
auto conv =
gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, 1, 1, 1, 1, 1, 1);
// allocate CUDA memory
gCuda->dataMalloc();
// Execute on CUDA
bool tune = true;
cuda->run(gCuda, tune);
}
} // namespace infini

View File

@ -1,2 +1,3 @@
. /home/spack/spack/share/spack/setup-env.sh
spack load cuda@11.0.2 cudnn@8.0.3.33-11.0
export CUDAHOSTCXX=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-9.4.0/gcc-9.4.0-st36klijpsnquihiy463hmedsyhoc3g6/bin/gcc