From d2d49c5d4f0a8fa7c6b7159d21881223a58d9b23 Mon Sep 17 00:00:00 2001 From: Liyan Zheng Date: Thu, 13 Apr 2023 21:43:56 +0800 Subject: [PATCH] Add: invoke TVM through pipe --- python/cpp_plugin/gen_ansor_so.py | 20 ++- .../cuda/membound_tvm_packed_function.cc | 116 ++++++++++++++++-- 2 files changed, 122 insertions(+), 14 deletions(-) diff --git a/python/cpp_plugin/gen_ansor_so.py b/python/cpp_plugin/gen_ansor_so.py index 9580b3c7..0c67bb02 100644 --- a/python/cpp_plugin/gen_ansor_so.py +++ b/python/cpp_plugin/gen_ansor_so.py @@ -1,12 +1,13 @@ +import os +import sys +import json from contextlib import redirect_stdout import time +import logging import numpy as np import tvm from tvm import te, tir, auto_scheduler, topi -import os -import json -import logging USE_CACHE = True logging.basicConfig() @@ -16,7 +17,7 @@ logger.setLevel(logging.INFO) def gen_ansor_so(input_tensors, input_dtypes, output_tensor, output_dtype, tvm_code, func_name, nnet_expression: str, - nnet_simplified_expression: str, hash_code=None): + nnet_simplified_expression: str, hash_code: str = None): assert len(input_tensors) == len(input_dtypes) logger.debug(f'Work on hash {hash_code}') @@ -117,3 +118,14 @@ def gen_ansor_so(input_tensors, input_dtypes, output_tensor, output_dtype, }, ensure_ascii=False, indent=2)) return so_fn, conv_time + +# Read arguments from pipe, which is redirected to stdin. +# Write generated library path to pipe. + + +def pipe_gen(fd: int): + args = json.load(sys.stdin) # read from pipe + # print(args, f'fd={fd}') + ret = gen_ansor_so(**args) + with os.fdopen(fd, 'w') as f: + print(ret[0], file=f, end='') # write to pipe diff --git a/src/kernels/cuda/membound_tvm_packed_function.cc b/src/kernels/cuda/membound_tvm_packed_function.cc index 8086518d..83154c2c 100644 --- a/src/kernels/cuda/membound_tvm_packed_function.cc +++ b/src/kernels/cuda/membound_tvm_packed_function.cc @@ -8,6 +8,12 @@ #include "operators/pooling.h" #include "tvm/runtime/module.h" #include "tvm/runtime/packed_func.h" +#include +#include +#include +#include +#include +using json = nlohmann::json; namespace py = pybind11; @@ -124,7 +130,8 @@ class MemboundTVMPackedFunction : public Kernel { return std::dynamic_pointer_cast(ret); } - /// @brief + /// @brief Invoke TVM with pybind11. This approach is deprecated since it + /// ruins the python global scope and somehow breaks the cuBLAS runtime. /// @param inDims /// @param inDTypes /// @param outDims @@ -137,15 +144,15 @@ class MemboundTVMPackedFunction : public Kernel { /// @param hashCode (optional) Hash code of the input expression for kernel /// cache. /// @return - std::string getAnsorDLL(const std::vector> &inDims, - const std::vector &inDTypes, - const std::vector &outDims, - const std::string &outDType, - const std::string &lambda, - const std::string &funcName, - const std::string &nnetExprString, - const std::string &nnetSimplifiedExprString, - const HashType hashCode) const { + std::string getAnsorDLLPybind11(const std::vector> &inDims, + const std::vector &inDTypes, + const std::vector &outDims, + const std::string &outDType, + const std::string &lambda, + const std::string &funcName, + const std::string &nnetExprString, + const std::string &nnetSimplifiedExprString, + const HashType hashCode) const { std::string dllPath; try { start_interpreter(); @@ -171,6 +178,95 @@ class MemboundTVMPackedFunction : public Kernel { return dllPath; } + std::string serializeTVMArgs(const std::vector> &inDims, + const std::vector &inDTypes, + const std::vector &outDims, + const std::string &outDType, + const std::string &lambda, + const std::string &funcName, + const std::string &nnetExprString, + const std::string &nnetSimplifiedExprString, + const HashType hashCode) const { + json j; + // Consistant with python API interface + j["input_tensors"] = inDims; + j["input_dtypes"] = inDTypes; + j["output_tensor"] = outDims; + j["output_dtype"] = outDType; + j["tvm_code"] = lambda; + j["func_name"] = funcName; + j["nnet_expression"] = nnetExprString; + j["nnet_simplified_expression"] = nnetSimplifiedExprString; + j["hash_code"] = std::to_string(hashCode); + return j.dump(); + } + + std::string getAnsorDLL(const std::vector> &inDims, + const std::vector &inDTypes, + const std::vector &outDims, + const std::string &outDType, + const std::string &lambda, + const std::string &funcName, + const std::string &nnetExprString, + const std::string &nnetSimplifiedExprString, + const HashType hashCode) const { + int fdP2C[2], fdC2P[2]; + for (auto fd : {fdP2C, fdC2P}) { + int status = pipe(fd); + IT_ASSERT(status == 0, "pipe failed"); + } + pid_t pid = fork(); + IT_ASSERT(pid >= 0, "fork failed"); + if (pid == 0) { // Child process + close(fdP2C[1]); + close(fdC2P[0]); + + dup2(fdP2C[0], STDIN_FILENO); + close(fdP2C[0]); + + string cmd = + "from cpp_plugin.gen_ansor_so import pipe_gen; pipe_gen(+" + + std::to_string(fdC2P[1]) + ")"; + const char *const argv[] = {"python3", "-c", cmd.data(), NULL}; + execvp("python3", const_cast(argv)); + } else { // Parent process + close(fdP2C[0]); + close(fdC2P[1]); + + // Write to pipe + string serializedArgs = serializeTVMArgs( + inDims, inDTypes, outDims, outDType, lambda, funcName, + nnetExprString, nnetSimplifiedExprString, hashCode); + int status = -1; + status = + write(fdP2C[1], serializedArgs.data(), serializedArgs.size()); + IT_ASSERT(status == serializedArgs.size(), + "Failed to write to pipe"); + close(fdP2C[1]); + + // Wait for TVM + waitpid(pid, &status, 0); + IT_ASSERT(WIFEXITED(status), "TVM process was terminated"); + const int es = WEXITSTATUS(status); + IT_ASSERT(es == 0, + "TVM process exit with code " + std::to_string(es)); + + // Read from pipe + FILE *stream; + int c; + stream = fdopen(fdC2P[0], "r"); + char buf_read[257] = {0}; + status = std::fscanf(stream, "%256c", buf_read); + IT_ASSERT(status == 1, "Failed to read from pipe"); + IT_ASSERT(buf_read[256] == 0, "Pipe buffer overflow"); + fclose(stream); + close(fdC2P[0]); + return buf_read; + } + IT_ASSERT(false, "Should not reach here"); + return ""; + } + tvm::runtime::PackedFunc getPackedFunction(string path, string functionName) const { tvm::runtime::Module mod = tvm::runtime::Module::LoadFromFile(path);