From 1152adc94a7159929bd2a12a9c0f8f5fa966e6d7 Mon Sep 17 00:00:00 2001
From: zhengly123
Date: Fri, 7 Oct 2022 16:03:11 +0800
Subject: [PATCH] Add: python API for timing ConvTranspose (#46)

* Add: python interfaces for timing operators

* Fix: CUDA Runtime run

Co-authored-by: Liyan Zheng
---
 include/core/common.h                 |  2 +-
 include/core/graph.h                  |  2 +-
 include/core/kernel.h                 |  6 ++++-
 include/core/runtime.h                |  2 ++
 include/cuda/cuda_runtime.h           |  5 +++-
 include/cuda/operator_timer.h         |  5 ++++
 python/infinitensor/operator_timer.py | 12 +++++---
 src/core/runtime.cc                   |  2 ++
 src/cuda/cuda_runtime.cc              | 36 ++++++++++++++++++-------
 src/cuda/operator_timer.cc            | 38 ++++++++++++++++++++++++++-
 src/ffi/ffi_infinitensor.cc           |  1 +
 src/kernels/cuda/conv_transposed.cc   |  3 +--
 12 files changed, 93 insertions(+), 21 deletions(-)

diff --git a/include/core/common.h b/include/core/common.h
index 6bdb92a3..222b8060 100644
--- a/include/core/common.h
+++ b/include/core/common.h
@@ -44,7 +44,7 @@ using HashType = uint64_t; // compatible with std::hash
          ? void(0)                                                            \
          : throw ::infini::Exception(                                         \
                std::string("[") + __FILE__ + ":" + std::to_string(__LINE__) + \
-               "] Assertion failed (" + #name + "): " + #info))
+               "] Assertion failed (" + #name + "): " + info))
 #define _IT_ASSERT_1(name) _IT_ASSERT_2(name, "");
 #define IT_ASSERT(...) _VA_SELECT(_IT_ASSERT, __VA_ARGS__)
 
diff --git a/include/core/graph.h b/include/core/graph.h
index e0948539..f5ae6fa7 100644
--- a/include/core/graph.h
+++ b/include/core/graph.h
@@ -16,7 +16,7 @@ class GraphObj : public Object {
     GraphObj(Runtime runtime) : runtime(runtime){};
     string toString() const override;
 
-    Tensor addTensor(Shape dim, DataType dtype = DataType::UInt32);
+    Tensor addTensor(Shape dim, DataType dtype = DataType::Float32);
     Tensor cloneTensor(const Tensor &tensor) {
         auto ret = addTensor(tensor->getDims(), tensor->getDType());
         ret->dataMalloc();
diff --git a/include/core/kernel.h b/include/core/kernel.h
index 9415d496..8268792a 100644
--- a/include/core/kernel.h
+++ b/include/core/kernel.h
@@ -102,7 +102,11 @@ class KernelRegistry {
     }
     Kernel *getKernel(const KernelAttrs &kernelAttrs) const {
         auto it = kernels.find(kernelAttrs);
-        IT_ASSERT(it != kernels.end(), "Kernel not found.");
+        IT_ASSERT(it != kernels.end(),
+                  "Kernel not found for key {" +
+                      to_string(enum_to_underlying(std::get<0>(kernelAttrs))) +
+                      ", " + OpRegistry::getOpName(std::get<1>(kernelAttrs)) +
+                      ", " + std::get<2>(kernelAttrs).toString());
         return std::get<0>(it->second);
     }
     const KernelRecord &getKernelItem(const KernelAttrs &kernelAttrs) const {
diff --git a/include/core/runtime.h b/include/core/runtime.h
index f36f4ac8..6aa3c4fe 100644
--- a/include/core/runtime.h
+++ b/include/core/runtime.h
@@ -71,6 +71,7 @@ class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
                                  size_t bytes) const = 0;
     virtual void copyBlobToCPU(void *dst, const void *src,
                                size_t bytes) const = 0;
+    virtual string toString() const = 0;
 
   protected:
     void printProfilingData(double totTime,
@@ -102,6 +103,7 @@ class CpuRuntimeObj : public RuntimeObj {
     void copyBlobToCPU(void *dst, const void *src, size_t bytes) const override;
     void copyBlobInsideRuntime(void *dst, const void *src,
                                size_t bytes) const override;
+    string toString() const override;
 };
 
 } // namespace infini
diff --git a/include/cuda/cuda_runtime.h b/include/cuda/cuda_runtime.h
index 305efc62..efb9a9e2 100644
--- a/include/cuda/cuda_runtime.h
+++ b/include/cuda/cuda_runtime.h
@@ -34,6 +34,7 @@ class CudaRuntimeObj : public RuntimeObj {
         checkCublasError(cublasDestroy(cublas));
         checkCUresult(cuCtxDestroy(newContext));
     }
+    string toString() const override;
 
     void run(const Graph &graph, bool tune = false,
              bool profiling = false) const;
@@ -68,7 +69,9 @@ class CudaRuntimeObj : public RuntimeObj {
         checkCudaError(cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToDevice));
     }
 
+    void runWithoutSync(const Graph &graph) const;
+
   private:
-    void runWithoutSync(const Graph &graph, bool tune, bool profiling) const;
+    void tune(const Graph &graph, bool profiling) const;
 };
 } // namespace infini
\ No newline at end of file
diff --git a/include/cuda/operator_timer.h b/include/cuda/operator_timer.h
index 1b9f4c1a..b9d49e13 100644
--- a/include/cuda/operator_timer.h
+++ b/include/cuda/operator_timer.h
@@ -6,6 +6,11 @@ double getPerfConvCudnn(int n, int c, int h, int w, int f, int r, int s,
                         int dilationh, int dilationw, int group,
                         const char *name);
 
+double getPerfConvTransposed2dCudnn(int n, int c, int h, int w, int f, int r,
+                                    int s, int padh, int padw, int strideh,
+                                    int stridew, int dilationh, int dilationw,
+                                    int oph, int opw, int group);
+
 double getPerfMatmulCublas(int b, int m, int n, int k, const char *name);
 } // namespace opTimer
 } // namespace infini
\ No newline at end of file
diff --git a/python/infinitensor/operator_timer.py b/python/infinitensor/operator_timer.py
index ce338bc4..52c776fa 100644
--- a/python/infinitensor/operator_timer.py
+++ b/python/infinitensor/operator_timer.py
@@ -2,12 +2,14 @@ from tokenize import Double
 import pyinfinitensor # import getPerfConv, getPerfMatmul
 
 
-def getPerfConv(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, group, name):
+def getPerfConv(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, group, name=""):
     return pyinfinitensor.getPerfConvCudnn(n, c, h, w, f, r, s, padh, padw,
-        strideh, stridew, dilationh, dilationw, group, name)
+                                           strideh, stridew, dilationh, dilationw, group, name)
 
 
-def getPerfMatmul(b, m, n, k, name):
+def getPerfConvTransposed2dCudnn(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, oph, opw, group):
+    return pyinfinitensor.getPerfConvTransposed2dCudnn(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, oph, opw, group)
+
+
+def getPerfMatmul(b, m, n, k, name=""):
     return pyinfinitensor.getPerfMatmulCublas(b, m, n, k, name)
-
-
diff --git a/src/core/runtime.cc b/src/core/runtime.cc
index 28f243fd..0e531771 100644
--- a/src/core/runtime.cc
+++ b/src/core/runtime.cc
@@ -139,4 +139,6 @@ void CpuRuntimeObj::copyBlobInsideRuntime(void *dst, const void *src,
     memcpy(dst, src, bytes);
 }
 
+string CpuRuntimeObj::toString() const { return "CPU Runtime"; }
+
 } // namespace infini
diff --git a/src/cuda/cuda_runtime.cc b/src/cuda/cuda_runtime.cc
index e8369535..ad6616bb 100644
--- a/src/cuda/cuda_runtime.cc
+++ b/src/cuda/cuda_runtime.cc
@@ -5,8 +5,25 @@
 #include "operators/matmul.h"
 
 namespace infini {
-void CudaRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
-                                    bool profiling = false) const {
+void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
+    const auto &kernelRegistry = KernelRegistry::getInstance();
+    auto &perfEngine = PerfEngine::getInstance();
+    for (auto &op : graph->getOperators()) {
+        // HACK: set correct data type
+        auto kernelAttrs =
+            KernelAttrs{device, op->getOpType(), DataType::Float32};
+        Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
+        auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
+        auto perfData = perfEngine.getPerfData(perfKey);
+        // IT_ASSERT(perfData, "No perf data for OP " + op->toString());
+        if (perfData)
+            kernel->compute(op, perfData, this);
+        else
+            kernel->compute(op, this);
+    }
+}
+
+void CudaRuntimeObj::tune(const Graph &graph, bool profiling = false) const {
     const auto &kernelRegistry = KernelRegistry::getInstance();
     auto &perfEngine = PerfEngine::getInstance();
     double totalTime = 0;
@@ -19,11 +36,6 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
         Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
         auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
         auto perfData = perfEngine.getPerfData(perfKey);
-        if (!perfData && !tune) {
-            kernel->compute(op, this);
-            continue;
-        }
-
         PerfRecord record;
         if (!perfData) {
             record = kernel->tune(op, this);
@@ -46,13 +58,19 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
     }
 }
 
-void CudaRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
+void CudaRuntimeObj::run(const Graph &graph, bool runTune,
+                         bool profiling) const {
     if (profiling)
         IT_TODO_HALT();
-    runWithoutSync(graph, tune, profiling);
+    if (runTune)
+        tune(graph, profiling);
+    else
+        runWithoutSync(graph);
     sync();
 }
 
 void CudaRuntimeObj::sync() const { cudaDeviceSynchronize(); }
 
+string CudaRuntimeObj::toString() const { return "CUDA Runtime"; }
+
 } // namespace infini
\ No newline at end of file
diff --git a/src/cuda/operator_timer.cc b/src/cuda/operator_timer.cc
index 8fbcaae9..34241e27 100644
--- a/src/cuda/operator_timer.cc
+++ b/src/cuda/operator_timer.cc
@@ -22,8 +22,9 @@ double getPerfConvCudnn(int n, int c, int h, int w, int f, int r, int s,
     Runtime cuda = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cuda);
     // Set input data on CPU in a CPU Graph
+    IT_ASSERT(c % group == 0);
     Tensor i0Cpu = gCpu->addTensor({n, c, h, w}, DataType::Float32);
-    Tensor w0Cpu = gCpu->addTensor({f, c, r, s}, DataType::Float32);
+    Tensor w0Cpu = gCpu->addTensor({f, c / group, r, s}, DataType::Float32);
     // Malloc data for all tensors in a graph. Do we need implicit allocation?
     gCpu->dataMalloc();
     i0Cpu->setData(IncrementalGenerator());
@@ -43,6 +44,41 @@ double getPerfConvCudnn(int n, int c, int h, int w, int f, int r, int s,
     return cuda->getPerfTime(gCuda);
 }
 
+double getPerfConvTransposed2dCudnn(int n, int c, int h, int w, int f, int r,
+                                    int s, int padh, int padw, int strideh,
+                                    int stridew, int dilationh, int dilationw,
+                                    int oph, int opw, int group) {
+    // const auto &[n, c, h, w, f, r, s, padh, padw, strideh, stridew,
+    //             dilationh, dilationw, group] =
+    //     tuple{1, 512, 14, 14, 512, 3, 3, 2, 2, 1, 1, 2, 2, 1};
+    Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
+    Graph gCpu = make_ref<GraphObj>(cpu);
+    Runtime cuda = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cuda);
+    // Set input data on CPU in a CPU Graph
+    IT_ASSERT(c % group == 0);
+    Tensor i0Cpu = gCpu->addTensor({n, f, h, w}, DataType::Float32);
+    Tensor w0Cpu = gCpu->addTensor({f, c / group, r, s}, DataType::Float32);
+    // Malloc data for all tensors in a graph. Do we need implicit allocation?
+    gCpu->dataMalloc();
+    i0Cpu->setData(IncrementalGenerator());
+    w0Cpu->setData(IncrementalGenerator());
+
+    // Copy input tensors from CPU to CUDA
+    Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
+    Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
+    // Build CUDA graph
+    auto conv = gCuda->addOp<ConvTransposed2dObj>(
+        i0Cuda, w0Cuda, nullptr, padh, padw, strideh, stridew, dilationh,
+        dilationw, oph, opw, group);
+    // allocate CUDA memory
+    gCuda->dataMalloc();
+    // Execute on CUDA
+    bool tune = true;
+    cuda->run(gCuda, tune);
+    return cuda->getPerfTime(gCuda);
+}
+
 double getPerfMatmulCublas(int b, int m, int n, int k, const char *name) {
     // const auto &[n, c, h, w, f, r, s, padh, padw, strideh, stridew,
     //             dilationh, dilationw, group] =
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index 87338afa..6df1895a 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -13,6 +13,7 @@ void register_operator_timer(py::module &m) {
 #ifdef USE_CUDA
     using namespace opTimer;
     m.def("getPerfConvCudnn", &getPerfConvCudnn);
+    m.def("getPerfConvTransposed2dCudnn", &getPerfConvTransposed2dCudnn);
     m.def("getPerfMatmulCublas", &getPerfMatmulCublas);
 #endif
 }
diff --git a/src/kernels/cuda/conv_transposed.cc b/src/kernels/cuda/conv_transposed.cc
index 6f379006..c20eb3fa 100644
--- a/src/kernels/cuda/conv_transposed.cc
+++ b/src/kernels/cuda/conv_transposed.cc
@@ -250,8 +250,7 @@ class convBackwardDataCudnn : public Kernel {
                     outData);
             },
             [&]() { context->sync(); });
-        // printf("mode:%d algo:%d :%.8lf\n", mode, algo,
-        // record.time);
+        // printf("mode:%d algo:%d :%.8lf\n", mode, algo, record.time);
         // Update the tune result
         if (ret.time > record.time)
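
Usage note (reviewer aid, not part of the patch): a minimal sketch of how the
Python wrappers added above might be invoked. It assumes a CUDA-enabled build
(USE_CUDA) with the compiled pyinfinitensor extension importable and the
python/ directory on PYTHONPATH; the shapes and hyper-parameters below are
illustrative values, not taken from this patch.

    from infinitensor import operator_timer

    # Argument order follows the new binding:
    # (n, c, h, w, f, r, s, padh, padw, strideh, stridew,
    #  dilationh, dilationw, oph, opw, group)
    t_convt = operator_timer.getPerfConvTransposed2dCudnn(
        1, 448, 2, 2, 256, 4, 4,  # n, c, h, w, f, r, s
        1, 1,                     # padh, padw
        2, 2,                     # strideh, stridew
        1, 1,                     # dilationh, dilationw
        0, 0,                     # oph, opw
        1)                        # group

    # The existing helpers now default the cuDNN/cuBLAS name argument to "".
    t_mm = operator_timer.getPerfMatmul(1, 1024, 1024, 1024)

    # getPerfConvTransposed2dCudnn builds a one-op CUDA graph, tunes it, and
    # returns cuda->getPerfTime(gCuda) for that graph.
    print("ConvTransposed2d:", t_convt, "Matmul:", t_mm)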