From 1aefc1b27ef7b846768877b7106227e71e96363b Mon Sep 17 00:00:00 2001 From: zhengly123 Date: Tue, 27 Sep 2022 10:41:12 +0800 Subject: [PATCH] Add python interface for CUDA operator evaluation (#42) * Refactor: seperate data generator * Add: python bindings for opTimer * Fix: test_perfengine Co-authored-by: Liyan Zheng --- CMakeLists.txt | 15 +++--- include/cuda/operator_timer.h | 11 ++++ include/test.h | 56 +------------------- include/utils/data_generator.h | 57 ++++++++++++++++++++ python/infinitensor/operator_timer.py | 13 +++++ src/cuda/operator_timer.cc | 76 +++++++++++++++++++++++++++ src/ffi/ffi_infinitensor.cc | 22 ++++++++ test/kernels/cuda/test_perfengine.cc | 47 +++++++---------- 8 files changed, 206 insertions(+), 91 deletions(-) create mode 100644 include/cuda/operator_timer.h create mode 100644 include/utils/data_generator.h create mode 100644 python/infinitensor/operator_timer.py create mode 100644 src/cuda/operator_timer.cc create mode 100644 src/ffi/ffi_infinitensor.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 63762f62..f760f934 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -94,6 +94,11 @@ endif() target_link_libraries(InfiniTensor pybind11::embed) +# Python bindings +file(GLOB_RECURSE FFIS src/ffi/ffi_infinitensor.cc) +pybind11_add_module(pyinfinitensor MODULE ${FFIS}) +target_link_libraries(pyinfinitensor PRIVATE InfiniTensor) + if(USE_BACKTRACE) add_definitions(-D BACKWARD_TRACE) add_subdirectory(3rd-party/backward-cpp) @@ -103,6 +108,7 @@ if(USE_BACKTRACE) endif() if(USE_CUDA) + add_compile_definitions(USE_CUDA=1) # Since enable_language only executes once, rerun cmake is required if CMAKE_CUDA_HOST_COMPILER is wrong set(CMAKE_CUDA_HOST_COMPILER ${CMAKE_CXX_COMPILER} @@ -111,14 +117,7 @@ if(USE_CUDA) set_target_properties(InfiniTensor PROPERTIES CUDA_ARCHITECTURES "70;80") enable_language(CUDA) find_package(CUDAToolkit) # For nvrtc and cuda driver - target_link_libraries( - InfiniTensor - cudnn - CUDA::curand - CUDA::cublas - CUDA::nvrtc - CUDA::cudart - CUDA::cuda_driver) + target_link_libraries(InfiniTensor cudnn CUDA::curand CUDA::cublas CUDA::nvrtc CUDA::cudart CUDA::cuda_driver) endif() if(USE_BANG) diff --git a/include/cuda/operator_timer.h b/include/cuda/operator_timer.h new file mode 100644 index 00000000..1b9f4c1a --- /dev/null +++ b/include/cuda/operator_timer.h @@ -0,0 +1,11 @@ +#pragma once +namespace infini { +namespace opTimer { +double getPerfConvCudnn(int n, int c, int h, int w, int f, int r, int s, + int padh, int padw, int strideh, int stridew, + int dilationh, int dilationw, int group, + const char *name); + +double getPerfMatmulCublas(int b, int m, int n, int k, const char *name); +} // namespace opTimer +} // namespace infini \ No newline at end of file diff --git a/include/test.h b/include/test.h index 052b6abd..762abfa4 100644 --- a/include/test.h +++ b/include/test.h @@ -1,59 +1,5 @@ #pragma once #include "core/common.h" #include "core/tensor_base.h" +#include "utils/data_generator.h" #include "gtest/gtest.h" - -namespace infini { - -// TODO: isolate these class -class DataGenerator { - private: - virtual void fill(uint32_t *data, size_t size) { IT_TODO_HALT(); } - virtual void fill(float *data, size_t size) { IT_TODO_HALT(); } - - public: - virtual ~DataGenerator() {} - void operator()(void *data, size_t size, DataType dataType) { - if (dataType == DataType::UInt32) - fill(reinterpret_cast(data), size); - else if (dataType == DataType::Float32) - fill(reinterpret_cast(data), size); - else - IT_TODO_HALT(); - } -}; - -class 
IncrementalGenerator : public DataGenerator { - public: - virtual ~IncrementalGenerator() {} - - private: - template void fill(T *data, size_t size) { - for (size_t i = 0; i < size; i++) { - data[i] = i; - } - } - - void fill(uint32_t *data, size_t size) override { - fill(data, size); - } - void fill(float *data, size_t size) override { fill(data, size); } -}; - -class OneGenerator : public DataGenerator { - public: - virtual ~OneGenerator() {} - - private: - template void fill(T *data, size_t size) { - for (size_t i = 0; i < size; i++) { - data[i] = 1; - } - } - - void fill(uint32_t *data, size_t size) override { - fill(data, size); - } - void fill(float *data, size_t size) override { fill(data, size); } -}; -} // namespace infini \ No newline at end of file diff --git a/include/utils/data_generator.h b/include/utils/data_generator.h new file mode 100644 index 00000000..6ae8266e --- /dev/null +++ b/include/utils/data_generator.h @@ -0,0 +1,57 @@ +#include "core/common.h" +#include "core/tensor_base.h" + +namespace infini { + +// TODO: isolate these class +class DataGenerator { + private: + virtual void fill(uint32_t *data, size_t size) { IT_TODO_HALT(); } + virtual void fill(float *data, size_t size) { IT_TODO_HALT(); } + + public: + virtual ~DataGenerator() {} + void operator()(void *data, size_t size, DataType dataType) { + if (dataType == DataType::UInt32) + fill(reinterpret_cast(data), size); + else if (dataType == DataType::Float32) + fill(reinterpret_cast(data), size); + else + IT_TODO_HALT(); + } +}; + +class IncrementalGenerator : public DataGenerator { + public: + virtual ~IncrementalGenerator() {} + + private: + template void fill(T *data, size_t size) { + for (size_t i = 0; i < size; i++) { + data[i] = i; + } + } + + void fill(uint32_t *data, size_t size) override { + fill(data, size); + } + void fill(float *data, size_t size) override { fill(data, size); } +}; + +class OneGenerator : public DataGenerator { + public: + virtual ~OneGenerator() {} + + private: + template void fill(T *data, size_t size) { + for (size_t i = 0; i < size; i++) { + data[i] = 1; + } + } + + void fill(uint32_t *data, size_t size) override { + fill(data, size); + } + void fill(float *data, size_t size) override { fill(data, size); } +}; +} // namespace infini \ No newline at end of file diff --git a/python/infinitensor/operator_timer.py b/python/infinitensor/operator_timer.py new file mode 100644 index 00000000..ce338bc4 --- /dev/null +++ b/python/infinitensor/operator_timer.py @@ -0,0 +1,13 @@ +from tokenize import Double +import pyinfinitensor # import getPerfConv, getPerfMatmul + + +def getPerfConv(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, group, name): + return pyinfinitensor.getPerfConvCudnn(n, c, h, w, f, r, s, padh, padw, + strideh, stridew, dilationh, dilationw, group, name) + + +def getPerfMatmul(b, m, n, k, name): + return pyinfinitensor.getPerfMatmulCublas(b, m, n, k, name) + + diff --git a/src/cuda/operator_timer.cc b/src/cuda/operator_timer.cc new file mode 100644 index 00000000..8fbcaae9 --- /dev/null +++ b/src/cuda/operator_timer.cc @@ -0,0 +1,76 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/conv.h" +#include "operators/matmul.h" +#include "utils/data_generator.h" + +namespace infini { +namespace opTimer { + +double getPerfConvCudnn(int n, int c, int h, int w, int f, int r, int s, + int padh, int padw, int strideh, int stridew, + int 
dilationh, int dilationw, int group, + const char *name) { + // const auto &[n, c, h, w, f, r, s, padh, padw, strideh, stridew, + // dilationh, dilationw, group] = + // tuple{1, 512, 14, 14, 512, 3, 3, 2, 2, 1, 1, 2, 2, 1}; + Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton + Graph gCpu = make_ref(cpu); + Runtime cuda = make_ref(); + Graph gCuda = make_ref(cuda); + // Set input data on CPU in a CPU Graph + Tensor i0Cpu = gCpu->addTensor({n, c, h, w}, DataType::Float32); + Tensor w0Cpu = gCpu->addTensor({f, c, r, s}, DataType::Float32); + // Malloc data for all tensors in a graph. Do we need implicit allocation? + gCpu->dataMalloc(); + i0Cpu->setData(IncrementalGenerator()); + w0Cpu->setData(IncrementalGenerator()); + + // Copy input tensors from CPU to CUDA + Tensor i0Cuda = gCuda->cloneTensor(i0Cpu); + Tensor w0Cuda = gCuda->cloneTensor(w0Cpu); + // Build CUDA graph + auto conv = gCuda->addOp(i0Cuda, w0Cuda, nullptr, padh, padw, + strideh, stridew, dilationh, dilationw); + // allocate CUDA memory + gCuda->dataMalloc(); + // Execute on CUDA + bool tune = true; + cuda->run(gCuda, tune); + return cuda->getPerfTime(gCuda); +} + +double getPerfMatmulCublas(int b, int m, int n, int k, const char *name) { + // const auto &[n, c, h, w, f, r, s, padh, padw, strideh, stridew, + // dilationh, dilationw, group] = + // tuple{1, 512, 14, 14, 512, 3, 3, 2, 2, 1, 1, 2, 2, 1}; + Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton + Graph gCpu = make_ref(cpu); + Runtime cuda = make_ref(); + Graph gCuda = make_ref(cuda); + // Set input data on CPU in a CPU Graph + Tensor i0Cpu = gCpu->addTensor({b, m, k}, DataType::Float32); + Tensor w0Cpu = gCpu->addTensor({b, k, n}, DataType::Float32); + // Malloc data for all tensors in a graph. Do we need implicit allocation? 
+ gCpu->dataMalloc(); + i0Cpu->setData(IncrementalGenerator()); + w0Cpu->setData(IncrementalGenerator()); + + // Copy input tensors from CPU to CUDA + Tensor i0Cuda = gCuda->cloneTensor(i0Cpu); + Tensor w0Cuda = gCuda->cloneTensor(w0Cpu); + // Build CUDA graph + auto conv = gCuda->addOp(i0Cuda, w0Cuda, nullptr); + // allocate CUDA memory + gCuda->dataMalloc(); + // Execute on CUDA + bool tune = true; + cuda->run(gCuda, tune); + return cuda->getPerfTime(gCuda); +} + +} // namespace opTimer +} // namespace infini \ No newline at end of file diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc new file mode 100644 index 00000000..87338afa --- /dev/null +++ b/src/ffi/ffi_infinitensor.cc @@ -0,0 +1,22 @@ +#include +#ifdef USE_CUDA +#include "cuda/operator_timer.h" +#endif +namespace py = pybind11; + +namespace infini { + +using namespace py::literals; +using policy = py::return_value_policy; + +void register_operator_timer(py::module &m) { +#ifdef USE_CUDA + using namespace opTimer; + m.def("getPerfConvCudnn", &getPerfConvCudnn); + m.def("getPerfMatmulCublas", &getPerfMatmulCublas); +#endif +} + +} // namespace infini + +PYBIND11_MODULE(pyinfinitensor, m) { infini::register_operator_timer(m); } \ No newline at end of file diff --git a/test/kernels/cuda/test_perfengine.cc b/test/kernels/cuda/test_perfengine.cc index 99559734..52b9e3b3 100644 --- a/test/kernels/cuda/test_perfengine.cc +++ b/test/kernels/cuda/test_perfengine.cc @@ -14,36 +14,27 @@ TEST(PerfEngine, save_and_load) { Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton Graph gCpu = make_ref(cpu); Runtime cuda = make_ref(); - Graph gCuda = make_ref(cuda); - // Set input data on CPU in a CPU Graph - Tensor i0Cpu = gCpu->addTensor({1, 3, 224, 224}, DataType::Float32); - Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float32); - // Malloc data for all tensors in a graph. Do we need implicit allocation? - gCpu->dataMalloc(); - i0Cpu->setData(IncrementalGenerator()); - w0Cpu->setData(IncrementalGenerator()); + { // Conv + Graph gCuda = make_ref(cuda); - // Copy input tensors from CPU to CUDA - Tensor i0Cuda = gCuda->cloneTensor(i0Cpu); - Tensor w0Cuda = gCuda->cloneTensor(w0Cpu); - // Build CUDA graph - auto conv = - gCuda->addOp(i0Cuda, w0Cuda, nullptr, 1, 1, 1, 1, 1, 1); + // Copy input tensors from CPU to CUDA + Tensor i0Cuda = gCuda->addTensor({1, 3, 224, 224}, DataType::Float32); + Tensor w0Cuda = gCuda->addTensor({2, 3, 3, 3}, DataType::Float32); + // Build CUDA graph + auto conv = + gCuda->addOp(i0Cuda, w0Cuda, nullptr, 1, 1, 1, 1, 1, 1); + gCuda->dataMalloc(); + cuda->run(gCuda, true); + } - auto ACpu = gCpu->addTensor(Shape{1, 3, 5}, DataType::Float32); - auto BCpu = gCpu->addTensor(Shape{1, 5, 2}, DataType::Float32); - gCpu->dataMalloc(); - ACpu->setData(IncrementalGenerator()); - BCpu->setData(IncrementalGenerator()); - - auto cudaRuntime = make_ref(); - - auto ACuda = gCuda->cloneTensor(ACpu); - auto BCuda = gCuda->cloneTensor(BCpu); - auto matmul = gCuda->addOp(ACuda, BCuda, nullptr); - - gCuda->dataMalloc(); - cudaRuntime->run(gCuda, true); + { // Matmul + Graph gCuda = make_ref(cuda); + auto ACuda = gCuda->addTensor(Shape{1, 3, 5}, DataType::Float32); + auto BCuda = gCuda->addTensor(Shape{1, 5, 2}, DataType::Float32); + auto matmul = gCuda->addOp(ACuda, BCuda, nullptr); + gCuda->dataMalloc(); + cuda->run(gCuda, true); + } auto &perfEngine = PerfEngine::getInstance(); json j0 = perfEngine;
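
For reference, a minimal usage sketch of the new Python interface added by this patch (not part of the diff itself). It assumes pyinfinitensor was built with USE_CUDA enabled and that both the pyinfinitensor extension and the infinitensor package are importable from the current environment; the convolution configuration mirrors the commented-out example in src/cuda/operator_timer.cc, and the trailing name strings are placeholder labels.

    # Hypothetical usage sketch, assuming the package layout in python/infinitensor/
    from infinitensor import operator_timer

    # Conv config from the commented-out tuple in operator_timer.cc:
    # n=1, c=512, h=14, w=14, f=512, r=3, s=3, pads=2, strides=1, dilations=2, group=1
    conv_time = operator_timer.getPerfConv(1, 512, 14, 14, 512, 3, 3,
                                           2, 2, 1, 1, 2, 2, 1, "conv_example")

    # Matmul shapes matching the test case: {1, 3, 5} x {1, 5, 2} -> b=1, m=3, n=2, k=5
    matmul_time = operator_timer.getPerfMatmul(1, 3, 2, 5, "matmul_example")

    print(conv_time, matmul_time)

Each call builds the operator on the CUDA runtime, tunes it, and returns the value of getPerfTime for that graph, so the two wrappers can be driven directly from Python scripts without writing a C++ test.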