diff --git a/CMakeLists.txt b/CMakeLists.txt index f5fd43ce..b118cf3b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -129,7 +129,7 @@ if(BUILD_TEST_EINNET) endif() # Python bindings -file(GLOB_RECURSE FFIS src/ffi/ffi_infinitensor.cc) +file(GLOB_RECURSE FFIS src/ffi/ffi_callback.cc src/ffi/ffi_infinitensor.cc) pybind11_add_module(backend MODULE ${FFIS}) target_link_libraries(backend PRIVATE InfiniTensor) @@ -168,6 +168,7 @@ endif() if(USE_CUDA) add_compile_definitions(USE_CUDA=1) + add_compile_definitions(CUDA_API_PER_THREAD_DEFAULT_STREAM=1) # Support CUDA graph stream caputre # Since enable_language only executes once, rerun cmake is required if CMAKE_CUDA_HOST_COMPILER is wrong set(CMAKE_CUDA_HOST_COMPILER ${CMAKE_CXX_COMPILER} diff --git a/include/cuda/cuda_runtime.h b/include/cuda/cuda_runtime.h index b5830454..64a49f73 100644 --- a/include/cuda/cuda_runtime.h +++ b/include/cuda/cuda_runtime.h @@ -1,35 +1,22 @@ #pragma once #include "core/runtime.h" #include "cuda/cuda_common.h" +#include "nnet/dbg.h" namespace infini { class CudaRuntimeObj : public RuntimeObj { private: + cudaStream_t stream; cudnnHandle_t cudnn; cublasHandle_t cublas; CudaPtr workspace; size_t workspaceSize; + bool cudaGraphStatus; // Whether CUDA graph stream capture is enabled public: - CudaRuntimeObj() : RuntimeObj(Device::CUDA) { - - checkCudnnError(cudnnCreate(&cudnn)); - checkCublasError(cublasCreate(&cublas)); - // 10GB for Longformer - // size_t longformerNum = 3lu * (1 << 30); - workspaceSize = 7ll << 30; // 7 GB - workspace = alloc(workspaceSize); - } - virtual ~CudaRuntimeObj() { - try { - dealloc(workspace); - checkCudnnError(cudnnDestroy(cudnn)); - checkCublasError(cublasDestroy(cublas)); - } catch (const std::exception &e) { - std::cerr << "Error in ~CudaRuntimeObj: " << e.what() << std::endl; - } - } + CudaRuntimeObj(); + virtual ~CudaRuntimeObj(); string toString() const override; void run(const Graph &graph, bool tune = false, @@ -69,7 +56,15 @@ class CudaRuntimeObj : public RuntimeObj { void runWithoutSync(const Graph &graph) const; + bool isInCudaGraph() const { return cudaGraphStatus; } + cudaStream_t getStream() const { return stream; } + + double timeWithCudaGraph(Graph graph); + private: void tune(const Graph &graph, bool profiling) const; + + void beginCudaGraphStreamCapture(); + cudaGraphExec_t endCudaGraphStreamCapture(); }; } // namespace infini diff --git a/src/cuda/cuda_runtime.cc b/src/cuda/cuda_runtime.cc index 37b5e7cf..8e65c2e7 100644 --- a/src/cuda/cuda_runtime.cc +++ b/src/cuda/cuda_runtime.cc @@ -6,6 +6,46 @@ #include "operators/matmul.h" namespace infini { +CudaRuntimeObj::CudaRuntimeObj() + : RuntimeObj(Device::CUDA), stream(cudaStreamPerThread), + cudaGraphStatus(false) { + checkCudnnError(cudnnCreate(&cudnn)); + checkCublasError(cublasCreate(&cublas)); + checkCudnnError(cudnnSetStream(cudnn, stream)); + checkCublasError(cublasSetStream(cublas, stream)); + // 10GB for Longformer + // size_t longformerNum = 3lu * (1 << 30); + workspaceSize = 7ll << 30; // 7 GB + workspace = alloc(workspaceSize); +} + +CudaRuntimeObj::~CudaRuntimeObj() { + try { + dealloc(workspace); + checkCudnnError(cudnnDestroy(cudnn)); + checkCublasError(cublasDestroy(cublas)); + } catch (const std::exception &e) { + std::cerr << "Error in ~CudaRuntimeObj: " << e.what() << std::endl; + } +} + +void CudaRuntimeObj::beginCudaGraphStreamCapture() { + enum cudaStreamCaptureStatus pCaptureStatus; + checkCudaError(cudaStreamIsCapturing(stream, &pCaptureStatus)); + dbg(pCaptureStatus); + cudaGraphStatus = true; + checkCudaError(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal)); +} + +cudaGraphExec_t CudaRuntimeObj::endCudaGraphStreamCapture() { + cudaGraph_t cudaGraph; + cudaGraphExec_t instance; + checkCudaError(cudaStreamEndCapture(stream, &cudaGraph)); + cudaGraphStatus = false; + checkCudaError(cudaGraphInstantiate(&instance, cudaGraph, NULL, NULL, 0)); + return instance; +} + void CudaRuntimeObj::runWithoutSync(const Graph &graph) const { const auto &kernelRegistry = KernelRegistry::getInstance(); auto &perfEngine = PerfEngine::getInstance(); @@ -75,4 +115,52 @@ void CudaRuntimeObj::sync() const { checkCudaError(cudaDeviceSynchronize()); } string CudaRuntimeObj::toString() const { return "CUDA Runtime"; } +double CudaRuntimeObj::timeWithCudaGraph(Graph graph) { + const auto &kernelRegistry = KernelRegistry::getInstance(); + auto &perfEngine = PerfEngine::getInstance(); + // compile-time computable + map ctcMap = getCompileTimeComputableAttribute(graph); + vector> kernels; + bool status = graph->topo_sort(); + IT_ASSERT(status, "Topological sort failed"); + + for (auto &op : graph->getOperators()) { + // HACK: set correct data type + auto kernelAttrs = + KernelAttrs{device, op->getOpType(), DataType::Float32}; + Kernel *kernel = kernelRegistry.getKernel(kernelAttrs); + auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()}; + auto perfData = perfEngine.getPerfData(perfKey); + if (perfData) + kernel->compute(op, perfData, this); + else + kernel->compute(op, this); + // if (!ctcMap.at(op->getGuid()) && op->getOpType() != OpType::Reshape) + // if (op->getOpType() == OpType::Matmul) + // if (op->getOpType() == OpType::Matmul || + // op->getOpType() == OpType::Relu + // // || op->getOpType() == OpType::MemBound + // ) + kernels.emplace_back(op, kernel, perfData); + } + for (auto &[op, kernel, perfData] : kernels) { + dbg(op); + } + + beginCudaGraphStreamCapture(); + for (auto &[op, kernel, perfData] : kernels) { + if (perfData) + kernel->compute(op, perfData, this); + else + kernel->compute(op, this); + } + auto cudaGraphInstance = endCudaGraphStreamCapture(); + return timeit( + [&, stream = getStream()]() { + checkCudaError(cudaGraphLaunch(cudaGraphInstance, stream)); + }, + [&, stream = getStream()]() { cudaStreamSynchronize(stream); }, 1000, + 1000); +} + } // namespace infini diff --git a/src/nnet/App/test_models.cc b/src/nnet/App/test_models.cc index 13b3683f..457b9c68 100644 --- a/src/nnet/App/test_models.cc +++ b/src/nnet/App/test_models.cc @@ -82,7 +82,8 @@ void printGraph(Graph g) { } } -Graph optimizeGraph(Graph g, Runtime runtime, bool tuning) { +Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning) { + auto runtime = as(_runtime); Runtime cpu = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(cpu); @@ -145,6 +146,7 @@ Graph optimizeGraph(Graph g, Runtime runtime, bool tuning) { dbg(go0->equalData(bgo0, 1e-3)); dbg(runtime->getPerfTime(bestGraph, true)); dbg(runtime->timeNonCtcOperators(bestGraph)); + dbg(runtime->timeWithCudaGraph(bestGraph)); } dbg("Best graph"); diff --git a/test/kernels/cuda/test_cuda_runtime.cc b/test/kernels/cuda/test_cuda_runtime.cc new file mode 100644 index 00000000..a200427f --- /dev/null +++ b/test/kernels/cuda/test_cuda_runtime.cc @@ -0,0 +1,30 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/conv.h" +#include "test.h" + +namespace infini { + +TEST(TestCudaRuntime, CudaGraph) { + auto runtime = make_ref(); + Graph g = make_ref(runtime); + Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton + Graph gCpu = make_ref(cpu); + + const int n = 2, c = 256, h = 2, w = 2, f = 448, r = 3, s = 2; + auto i0 = g->addTensor({n, c, h, w}, DataType::Float32, TensorType::Input); + auto w0 = + g->addTensor({f, c, r, s}, DataType::Float32, TensorType::Initialized); + g->addOp(i0, w0, nullptr, 1, 1, 1, 1, 1, 1); + g->dataMalloc(); + runtime->run(g, true); + runtime->run(g, false); + runtime->getPerfTime(g); + + auto time = runtime->timeWithCudaGraph(g); + EXPECT_GE(time, 0.01); +} + +} // namespace infini