From 11d5aa1ccce54de4695a824e8c2b0d6379ab93b4 Mon Sep 17 00:00:00 2001 From: deathwings602 <59163199+deathwings602@users.noreply.github.com> Date: Thu, 22 Sep 2022 18:06:45 +0800 Subject: [PATCH] Add TVM codegen for MemboundOp (#35) * Add: interface for membound TVM kernel and test * add getAnsorCode * add evaluation, but link failed * add evaluation of kernel, but link failed * Fix: link libcuda and nvrtc * add print * Add: const for source of copy * compile and evaluate the kernel * add compute * fix gen_ansor_op.py * fix membound_TVM * format and fix CMakeLists.txt * fix memory leak Co-authored-by: Liyan Zheng Co-authored-by: huangshuhong --- CMakeLists.txt | 26 ++-- include/core/runtime.h | 16 ++- include/core/tensor.h | 9 +- include/cuda/cuda_common.h | 12 ++ include/cuda/cuda_runtime.h | 17 ++- include/ffi/ffi_embed.h | 8 ++ python/cpp_plugin/__init__.py | 1 + python/cpp_plugin/gen_ansor_op.py | 94 +++++++++++++++ src/core/runtime.cc | 8 +- src/ffi/ffi_embed.cc | 11 ++ src/kernels/cuda/membound_TVM.cc | 191 ++++++++++++++++++++++++++++++ test/nnet/test_memboundOp.cc | 39 ++++++ 12 files changed, 402 insertions(+), 30 deletions(-) create mode 100644 include/ffi/ffi_embed.h create mode 100644 python/cpp_plugin/__init__.py create mode 100644 python/cpp_plugin/gen_ansor_op.py create mode 100644 src/ffi/ffi_embed.cc create mode 100644 src/kernels/cuda/membound_TVM.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 3a09a5e0..63762f62 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.10) # Required by CMAKE_CUDA_HOST_COMPILER +cmake_minimum_required(VERSION 3.17) # FindCUDAToolkit include(CMakeDependentOption) project(InfiniTensor C CXX) @@ -52,12 +52,11 @@ if(USE_PROTOBUF) target_link_libraries(tensor_proto PUBLIC ${PROTOBUF_LIBRARIES}) endif() - include_directories(include) -# # Pybind11 -# add_subdirectory(3rd-party/pybind11) -# include_directories(3rd-party/pybind11/include) +# Pybind11 +add_subdirectory(3rd-party/pybind11) +include_directories(3rd-party/pybind11/include) # nlohmann_json add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent) @@ -75,7 +74,7 @@ if(BUILD_TEST) endif() # Source files -file(GLOB_RECURSE SRC src/core/*.cc src/kernels/cpu/*.cc src/nnet/*.cc src/operators/*.cc src/utils/*.cc) +file(GLOB_RECURSE SRC src/ffi/*.cc src/core/*.cc src/kernels/cpu/*.cc src/nnet/*.cc src/operators/*.cc src/utils/*.cc) if(USE_CUDA) file(GLOB_RECURSE SRC_CUDA src/cuda/*.cc src/cuda/*.cu src/kernels/cuda/*.cc src/kernels/cuda/*.cu) @@ -93,6 +92,8 @@ if(USE_PROTOBUF) target_link_libraries(InfiniTensor tensor_proto) endif() +target_link_libraries(InfiniTensor pybind11::embed) + if(USE_BACKTRACE) add_definitions(-D BACKWARD_TRACE) add_subdirectory(3rd-party/backward-cpp) @@ -109,10 +110,15 @@ if(USE_CUDA) # CMP0104 requires CUDA_ARCHITECTURES set_target_properties(InfiniTensor PROPERTIES CUDA_ARCHITECTURES "70;80") enable_language(CUDA) - # TODO: find_package seems unnecessary for CMake >= 3.8 - find_package(CUDA REQUIRED) - # message("CUBLAS_LIBRARIES: ${CUDA_LIBRARIES}") - target_link_libraries(InfiniTensor cudnn curand cublas ${CUDA_LIBRARIES}) + find_package(CUDAToolkit) # For nvrtc and cuda driver + target_link_libraries( + InfiniTensor + cudnn + CUDA::curand + CUDA::cublas + CUDA::nvrtc + CUDA::cudart + CUDA::cuda_driver) endif() if(USE_BANG) diff --git a/include/core/runtime.h b/include/core/runtime.h index d6f05ade..b47b6587 100644 --- a/include/core/runtime.h +++ b/include/core/runtime.h @@ -66,14 +66,17 @@ class RuntimeObj : public std::enable_shared_from_this { bool isCuda() const { return device == Device::CUDA; } bool isBang() const { return device == Device::BANG; } void copyBlob(const TensorObj *dst, const TensorObj *src) const; + // TODO: unify these copy APIs + virtual void copyBlobFromCPU(void *dst, const void *src, + size_t bytes) const = 0; protected: void printProfilingData(double totTime, const std::map &opTime, const std::map &opCnt) const; - virtual void copyBlobFromCPU(void *dst, void *src, size_t bytes) const = 0; - virtual void copyBlobToCPU(void *dst, void *src, size_t bytes) const = 0; - virtual void copyBlobInsideRuntime(void *dst, void *src, + virtual void copyBlobToCPU(void *dst, const void *src, + size_t bytes) const = 0; + virtual void copyBlobInsideRuntime(void *dst, const void *src, size_t bytes) const = 0; }; @@ -94,9 +97,10 @@ class CpuRuntimeObj : public RuntimeObj { sizeof(uint64_t)); }; - void copyBlobFromCPU(void *dst, void *src, size_t bytes) const override; - void copyBlobToCPU(void *dst, void *src, size_t bytes) const override; - void copyBlobInsideRuntime(void *dst, void *src, + void copyBlobFromCPU(void *dst, const void *src, + size_t bytes) const override; + void copyBlobToCPU(void *dst, const void *src, size_t bytes) const override; + void copyBlobInsideRuntime(void *dst, const void *src, size_t bytes) const override; }; diff --git a/include/core/tensor.h b/include/core/tensor.h index 031c4f8d..9b9b4237 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -32,14 +32,7 @@ class TensorObj : public TensorBaseObj { template void copyData(const T *dptr) { IT_ASSERT(DataType::get() == dtype); IT_ASSERT(data != nullptr); - if (!runtime->isCpu()) - IT_TODO_HALT(); - auto ptr = data->getPtr(); - size_t sz = size(); -#pragma omp parallel for - for (size_t i = 0; i < sz; ++i) { - ptr[i] = dptr[i]; - } + runtime->copyBlobFromCPU(getRawDataPtr(), dptr, getBytes()); } template void copyData(vector dataVector) { diff --git a/include/cuda/cuda_common.h b/include/cuda/cuda_common.h index 9b4680e2..c24ecdca 100644 --- a/include/cuda/cuda_common.h +++ b/include/cuda/cuda_common.h @@ -17,6 +17,18 @@ } \ } +#define checkCUresult(call) \ + { \ + auto err = call; \ + const char *errName; \ + if (CUDA_SUCCESS != err) { \ + cuGetErrorString(err, &errName); \ + fprintf(stderr, "Cuda error in %s:%i : %s.\n", __FILE__, __LINE__, \ + errName); \ + exit(EXIT_FAILURE); \ + } \ + } + #define checkCublasError(call) \ { \ auto err = call; \ diff --git a/include/cuda/cuda_runtime.h b/include/cuda/cuda_runtime.h index ac2c3092..305efc62 100644 --- a/include/cuda/cuda_runtime.h +++ b/include/cuda/cuda_runtime.h @@ -11,6 +11,10 @@ class CudaRuntimeObj : public RuntimeObj { CudaPtr workspace; size_t workspaceSize; + public: + CUdevice cuDevice; + CUcontext newContext; + public: CudaRuntimeObj() : RuntimeObj(Device::CUDA) { checkCudnnError(cudnnCreate(&cudnn)); @@ -19,11 +23,16 @@ class CudaRuntimeObj : public RuntimeObj { // size_t longformerNum = 3lu * (1 << 30); workspaceSize = 7ll << 30; // 7 GB workspace = alloc(workspaceSize); + + checkCUresult(cuInit(0)); + checkCUresult(cuDeviceGet(&cuDevice, 0)); + checkCUresult(cuCtxCreate(&newContext, 0, cuDevice)); } virtual ~CudaRuntimeObj() { dealloc(workspace); checkCudnnError(cudnnDestroy(cudnn)); checkCublasError(cublasDestroy(cublas)); + checkCUresult(cuCtxDestroy(newContext)); } void run(const Graph &graph, bool tune = false, @@ -44,15 +53,17 @@ class CudaRuntimeObj : public RuntimeObj { return workspace; } - void copyBlobFromCPU(void *dst, void *src, size_t bytes) const override { + void copyBlobFromCPU(void *dst, const void *src, + size_t bytes) const override { checkCudaError(cudaMemcpy(dst, src, bytes, cudaMemcpyHostToDevice)); } - void copyBlobToCPU(void *dst, void *src, size_t bytes) const override { + void copyBlobToCPU(void *dst, const void *src, + size_t bytes) const override { checkCudaError(cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToHost)); } - void copyBlobInsideRuntime(void *dst, void *src, + void copyBlobInsideRuntime(void *dst, const void *src, size_t bytes) const override { checkCudaError(cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToDevice)); } diff --git a/include/ffi/ffi_embed.h b/include/ffi/ffi_embed.h new file mode 100644 index 00000000..e69e8836 --- /dev/null +++ b/include/ffi/ffi_embed.h @@ -0,0 +1,8 @@ +#pragma once +#include +#include +#include + +namespace infini { +void start_interpreter(); +} // namespace infini \ No newline at end of file diff --git a/python/cpp_plugin/__init__.py b/python/cpp_plugin/__init__.py new file mode 100644 index 00000000..811587de --- /dev/null +++ b/python/cpp_plugin/__init__.py @@ -0,0 +1 @@ +from .gen_ansor_op import gen_ansor_op diff --git a/python/cpp_plugin/gen_ansor_op.py b/python/cpp_plugin/gen_ansor_op.py new file mode 100644 index 00000000..ead0033a --- /dev/null +++ b/python/cpp_plugin/gen_ansor_op.py @@ -0,0 +1,94 @@ +import re + +import numpy as np +import tvm +from tvm import te, tir, auto_scheduler, topi + + +def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f, func_name, input_names, output_name): + assert len(input_tensors) == len(input_dtypes) + assert len(input_tensors) == len(input_names) + + print("Generating Ansor op: ") + print(f) + + @auto_scheduler.register_workload(func_name) + def compute(): + _locals = locals() + exec(f, {'tvm': tvm, 'te': te, 'tir': tir}, _locals) + return _locals['ret'] + + target = tvm.target.Target("cuda") + + task = auto_scheduler.SearchTask(func=func_name, args=(), target=target) + + # Inspect the computational graph + print("Computational DAG:") + print(task.compute_dag) + + log_file = f"ansor_{func_name}_log.json" + measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300) + tune_option = auto_scheduler.TuningOptions( + num_measure_trials=10, + runner=measure_ctx.runner, + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + verbose=2, + ) + + # Run auto-tuning (search) + task.tune(tune_option) + # Apply the best schedule + sch, args = task.apply_best(log_file) + + # Kill the measurement process + del measure_ctx + + ir = str(tvm.lower(sch, args, simple_mode=True)) + thread_dim = [1, 1, 1] + block_dim = [1, 1, 1] + p = re.compile('"thread_extent" = (\d+)') + for line in ir.splitlines(): + if "thread_extent" in line: + ext = int(p.search(line).group(1)) + if "threadIdx.x" in line: + thread_dim[0] = ext + elif "threadIdx.y" in line: + thread_dim[1] = ext + elif "threadIdx.z" in line: + thread_dim[2] = ext + elif "blockIdx.x" in line: + block_dim[0] = ext + elif "blockIdx.y" in line: + block_dim[1] = ext + elif "blockIdx.z" in line: + block_dim[2] = ext + + func = tvm.build(sch, args, target, name=func_name) + + func_code = func.imported_modules[0].get_source() + invoke_code = "%s_kernel0<<>>(%s, %s);" % ( + func_name, ", ".join(map(str, block_dim)), ", ".join( + map(str, thread_dim)), + output_name, ", ".join(input_names)) + invoke_params = block_dim + thread_dim + + ctx = tvm.cuda(0) + input_a = [] + for i, (shape, dtype) in enumerate(zip(input_tensors, input_dtypes)): + a_np = np.random.uniform(size=shape).astype(dtype) + input_a.append(tvm.nd.array(a_np, ctx)) + a_out = tvm.nd.array(np.zeros(output_tensor, dtype=output_dtype), ctx) + func(a_out, *input_a) + evaluator = func.time_evaluator(func.entry_name, ctx, number=100) + conv_time = evaluator(a_out, *input_a).mean * 1e3 + + print("Func Code") + # Attach TVM code behind func_code + func_code += "\n/* " + f + "*/" + print(func_code) + print("Invoke Code") + print(invoke_code) + print("Time") + print(conv_time) + + return func_code, invoke_code, conv_time, invoke_params # ms diff --git a/src/core/runtime.cc b/src/core/runtime.cc index 4d02f5ba..28f243fd 100644 --- a/src/core/runtime.cc +++ b/src/core/runtime.cc @@ -124,15 +124,17 @@ void RuntimeObj::copyBlob(const TensorObj *dst, const TensorObj *src) const { IT_TODO_HALT(); } -void CpuRuntimeObj::copyBlobFromCPU(void *dst, void *src, size_t bytes) const { +void CpuRuntimeObj::copyBlobFromCPU(void *dst, const void *src, + size_t bytes) const { copyBlobInsideRuntime(dst, src, bytes); } -void CpuRuntimeObj::copyBlobToCPU(void *dst, void *src, size_t bytes) const { +void CpuRuntimeObj::copyBlobToCPU(void *dst, const void *src, + size_t bytes) const { copyBlobInsideRuntime(dst, src, bytes); } -void CpuRuntimeObj::copyBlobInsideRuntime(void *dst, void *src, +void CpuRuntimeObj::copyBlobInsideRuntime(void *dst, const void *src, size_t bytes) const { memcpy(dst, src, bytes); } diff --git a/src/ffi/ffi_embed.cc b/src/ffi/ffi_embed.cc new file mode 100644 index 00000000..1cf70d3b --- /dev/null +++ b/src/ffi/ffi_embed.cc @@ -0,0 +1,11 @@ +#include "ffi/ffi_embed.h" +namespace py = pybind11; + +namespace infini { +void start_interpreter() { + static py::scoped_interpreter *interpretor = nullptr; + if (interpretor == nullptr) { + interpretor = new py::scoped_interpreter(); // guard + } +} +} // namespace infini diff --git a/src/kernels/cuda/membound_TVM.cc b/src/kernels/cuda/membound_TVM.cc new file mode 100644 index 00000000..6bf4c4b2 --- /dev/null +++ b/src/kernels/cuda/membound_TVM.cc @@ -0,0 +1,191 @@ +#include "core/kernel.h" +#include "cuda/cuda_runtime.h" +#include "ffi/ffi_embed.h" +#include "nnet/Visitor/AsTVMVisitor.h" +#include "nvrtc.h" +#include "operators/membound.h" +#include "operators/pooling.h" + +namespace py = pybind11; + +namespace infini { + +class TVMRecordObj : public PerfRecordObj { + // TODO: Add more attrs + public: + size_t logSize, ptxSize; + std::string log, ptx; + std::vector invokeParams; + std::string kernelName; +}; + +using TVMRecord = Ref; + +class MemboundTVM : public Kernel { + public: + void compute(const Operator &_op, const PerfRecord &record, + const RuntimeObj *_context) const override { + auto op = as(_op); + // auto context = dynamic_cast(_context); + auto tvmRecord = std::dynamic_pointer_cast(record); + + // prepare for evaluation + CUmodule module; + CUfunction kernel; + checkCUresult(cuModuleLoadDataEx(&module, tvmRecord->ptx.data(), 0, + nullptr, nullptr)); + checkCUresult(cuModuleGetFunction(&kernel, module, + tvmRecord->kernelName.c_str())); + std::vector args; + for (auto &&in : op->getInputs()) { + args.push_back(in->getRawDataPtr()); + } + args.push_back(op->getOutput()->getRawDataPtr()); + std::vector argsPtr; + for (auto &arg : args) { + argsPtr.push_back(&arg); + } + auto invokeParams = tvmRecord->invokeParams; + + // begin evaluation + cuLaunchKernel(kernel, invokeParams[0], invokeParams[1], + invokeParams[2], invokeParams[3], invokeParams[4], + invokeParams[5], 0, NULL, argsPtr.data(), 0); + + // free module + checkCUresult(cuModuleUnload(module)); + } + + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + IT_ASSERT(false, "A TVM record is required for membound kernel."); + } + + std::string getVarName(const Tensor &t) const { + return "var_" + std::to_string(t->getGuid()); + } + + // Premise: op is idempotent since it is called multiple times. + PerfRecord tune(const Operator &_op, + const RuntimeObj *_context) const override { + TVMRecord ret = std::make_shared(); + auto op = as(_op); + auto context = dynamic_cast(_context); + + // invoke Ansor to tune a membound kernel + std::string func = "mem_bound_" + std::to_string(op->getGuid()); + std::string kernelName = func + "_kernel0"; + nnet::AsTVMVisitor visitor; + visitor.dispatch(op->getNnetExpr()); + auto &&stmts = visitor.getStmts(); + auto &&inShapes = visitor.getInputShapes(); + auto &&outShape = visitor.getOutputShape(); + + std::vector inputs; + for (auto &&in : op->getInputs()) { + inputs.emplace_back(getVarName(in)); + } + std::string output = getVarName(op->getOutput()); + auto res = getAnsorCode( + inShapes, std::vector(inShapes.size(), "float32"), + outShape, "float32", stmts, func, inputs, output); + + // compile the kernel + auto funcCode = res.first; + auto invokeParams = res.second; + std::string fileName = func + ".cu"; + nvrtcProgram prog; + nvrtcCreateProgram(&prog, // prog + funcCode.c_str(), // buffer + fileName.c_str(), // name + 0, // numHeaders + NULL, // headers + NULL); // includeNames + const char *opts[] = {"--gpu-architecture=compute_80", "--fmad=false"}; + nvrtcCompileProgram(prog, // prog + 2, // numOptions + opts); // options + + // copy ptx and log to ret + size_t logSize; + nvrtcGetProgramLogSize(prog, &logSize); + size_t ptxSize; + nvrtcGetPTXSize(prog, &ptxSize); + ret->logSize = logSize; + ret->ptxSize = ptxSize; + ret->log = std::string(logSize, ' '); + ret->ptx = std::string(ptxSize, ' '); + nvrtcGetProgramLog(prog, ret->log.data()); + nvrtcGetPTX(prog, ret->ptx.data()); + ret->invokeParams = invokeParams; + ret->kernelName = kernelName; + + // prepare for evaluation + CUmodule module; + CUfunction kernel; + checkCUresult( + cuModuleLoadDataEx(&module, ret->ptx.data(), 0, nullptr, nullptr)); + checkCUresult(cuModuleGetFunction(&kernel, module, kernelName.c_str())); + std::vector args; + for (auto &&in : op->getInputs()) { + args.push_back(in->getRawDataPtr()); + } + args.push_back(op->getOutput()->getRawDataPtr()); + std::vector argsPtr; + for (auto &arg : args) { + argsPtr.push_back(&arg); + } + + // Evaluate the kernel + ret->time = timeit( + [&]() { + // TODO: run the kernel + cuLaunchKernel(kernel, invokeParams[0], invokeParams[1], + invokeParams[2], invokeParams[3], + invokeParams[4], invokeParams[5], 0, NULL, + argsPtr.data(), 0); + }, + [&]() { context->sync(); }); + + // free module + checkCUresult(cuModuleUnload(module)); + nvrtcDestroyProgram(&prog); + + return std::dynamic_pointer_cast(ret); + } + + std::pair> + getAnsorCode(const std::vector> &inDims, + const std::vector &inDTypes, + const std::vector &outDims, const std::string &outDType, + const std::string &lambda, const std::string &funcName, + const std::vector &inputNames, + const std::string &outputName) const { + std::string funcCode; + std::vector invokeParams; + try { + start_interpreter(); + auto func = py::module::import("cpp_plugin").attr("gen_ansor_op"); + py::tuple code = func(inDims, inDTypes, outDims, outDType, lambda, + funcName, inputNames, outputName); + funcCode = py::str(code[0]); + auto temp = py::list(code[3]); + for (int i = 0; i < 6; ++i) { + invokeParams.push_back(temp[i].cast()); + } + } catch (py::error_already_set &e) { + if (e.matches(PyExc_ImportError)) { + std::cerr << "Import Error. Don't forget to set environment " + "variable PYTHONPATH to contain " + "/python" + << std::endl; + } + throw; + } + return std::make_pair(funcCode, invokeParams); + } +}; + +REGISTER_KERNEL(Device::CUDA, OpType::MemBound, DataType::Float32, MemboundTVM, + "Memobund_TVM_Ansor"); +}; // namespace infini \ No newline at end of file diff --git a/test/nnet/test_memboundOp.cc b/test/nnet/test_memboundOp.cc index af43770b..29c0d7a7 100644 --- a/test/nnet/test_memboundOp.cc +++ b/test/nnet/test_memboundOp.cc @@ -1,5 +1,6 @@ #include "core/graph.h" #include "core/runtime.h" +#include "cuda/cuda_runtime.h" #include "nnet/Visitor/MatchReshapeVisitor.h" #include "nnet/expr.h" #include "nnet/nmutator.h" @@ -37,4 +38,42 @@ TEST(nnet, MemboundOpInterpretation) { ans->dataMalloc(); ans->copyData(vector{38, 44, 50, 56, 83, 98, 113, 128}); EXPECT_TRUE(membound->getOutput()->equalData(ans)); +} + +TEST(nnet, MemboundOp_Ansor_Codegen) { + auto runtime = make_ref(); + Runtime cpu = CpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(cpu); + Graph g = make_ref(runtime); + Tensor i0 = g->addTensor({1, 2, 3}, DataType::Float32); + Tensor w0 = g->addTensor({1, 3, 4}, DataType::Float32); + Tensor o0 = g->addTensor({1, 2, 4}, DataType::Float32); + g->dataMalloc(); + i0->copyData(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + w0->copyData(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + g->addOpWithOutputs(i0, w0, o0); + NMutator nmutator(NMutator::Mode::ToNaiveMembound); + auto mutations = nmutator.run(g); + ASSERT_EQ(mutations.size(), 1u); + Graph gNew = mutations[0]; + gNew->print(); + gNew->dataMalloc(); + runtime->run(gNew, true); // tune kernels + + // check answer + auto ops = gNew->getOperators(); + EXPECT_EQ(ops.size(), 1u); + auto membound = ops[0]; + EXPECT_EQ(membound->getOpType(), OpType::MemBound); + auto ans = make_ref(Shape{1, 2, 4}, DataType::Float32, cpu); + ans->dataMalloc(); + ans->copyData(vector{38, 44, 50, 56, 83, 98, 113, 128}); + + auto oCpu = gCpu->cloneTensor(membound->getOutput()); + oCpu->printData(); + EXPECT_TRUE(oCpu->equalData(ans)); + + // Timing + // double time = timeit([&]() { runtime->run(gNew, false); }); // tune + // kernels std::cout << "Time (ms):" << time << std::endl; } \ No newline at end of file