Add: invoke TVM through pipe

This commit is contained in:
Liyan Zheng 2023-04-13 21:43:56 +08:00
parent e72fe79168
commit d2d49c5d4f
2 changed files with 122 additions and 14 deletions

View File

@ -1,12 +1,13 @@
import os
import sys
import json
from contextlib import redirect_stdout from contextlib import redirect_stdout
import time import time
import logging
import numpy as np import numpy as np
import tvm import tvm
from tvm import te, tir, auto_scheduler, topi from tvm import te, tir, auto_scheduler, topi
import os
import json
import logging
USE_CACHE = True USE_CACHE = True
logging.basicConfig() logging.basicConfig()
@ -16,7 +17,7 @@ logger.setLevel(logging.INFO)
def gen_ansor_so(input_tensors, input_dtypes, output_tensor, output_dtype, def gen_ansor_so(input_tensors, input_dtypes, output_tensor, output_dtype,
tvm_code, func_name, nnet_expression: str, tvm_code, func_name, nnet_expression: str,
nnet_simplified_expression: str, hash_code=None): nnet_simplified_expression: str, hash_code: str = None):
assert len(input_tensors) == len(input_dtypes) assert len(input_tensors) == len(input_dtypes)
logger.debug(f'Work on hash {hash_code}') logger.debug(f'Work on hash {hash_code}')
@ -117,3 +118,14 @@ def gen_ansor_so(input_tensors, input_dtypes, output_tensor, output_dtype,
}, ensure_ascii=False, indent=2)) }, ensure_ascii=False, indent=2))
return so_fn, conv_time return so_fn, conv_time
def pipe_gen(fd: int):
    """Run `gen_ansor_so` with arguments received over a pipe.

    The keyword arguments are read as a single JSON document from stdin
    (the caller redirects the read end of a pipe to stdin). The first
    element of the value returned by `gen_ansor_so` (the generated library
    path) is written, without a trailing newline, to the file descriptor
    `fd`, which is closed afterwards.

    Args:
        fd: Writable file descriptor used to send the result back.
    """
    kwargs = json.load(sys.stdin)  # request arrives on the redirected stdin
    result = gen_ansor_so(**kwargs)
    library_path = result[0]
    with os.fdopen(fd, 'w') as reply:
        # No newline: the reader takes the raw bytes as the path.
        reply.write(str(library_path))

View File

@ -8,6 +8,12 @@
#include "operators/pooling.h" #include "operators/pooling.h"
#include "tvm/runtime/module.h" #include "tvm/runtime/module.h"
#include "tvm/runtime/packed_func.h" #include "tvm/runtime/packed_func.h"
#include <nlohmann/json.hpp>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>
using json = nlohmann::json;
namespace py = pybind11; namespace py = pybind11;
@ -124,7 +130,8 @@ class MemboundTVMPackedFunction : public Kernel {
return std::dynamic_pointer_cast<PerfRecordObj>(ret); return std::dynamic_pointer_cast<PerfRecordObj>(ret);
} }
/// @brief /// @brief Invoke TVM with pybind11. This approach is deprecated since it
/// ruins the python global scope and somehow breaks the cuBLAS runtime.
/// @param inDims /// @param inDims
/// @param inDTypes /// @param inDTypes
/// @param outDims /// @param outDims
@ -137,15 +144,15 @@ class MemboundTVMPackedFunction : public Kernel {
/// @param hashCode (optional) Hash code of the input expression for kernel /// @param hashCode (optional) Hash code of the input expression for kernel
/// cache. /// cache.
/// @return /// @return
std::string getAnsorDLL(const std::vector<std::vector<int>> &inDims, std::string getAnsorDLLPybind11(const std::vector<std::vector<int>> &inDims,
const std::vector<std::string> &inDTypes, const std::vector<std::string> &inDTypes,
const std::vector<int> &outDims, const std::vector<int> &outDims,
const std::string &outDType, const std::string &outDType,
const std::string &lambda, const std::string &lambda,
const std::string &funcName, const std::string &funcName,
const std::string &nnetExprString, const std::string &nnetExprString,
const std::string &nnetSimplifiedExprString, const std::string &nnetSimplifiedExprString,
const HashType hashCode) const { const HashType hashCode) const {
std::string dllPath; std::string dllPath;
try { try {
start_interpreter(); start_interpreter();
@ -171,6 +178,95 @@ class MemboundTVMPackedFunction : public Kernel {
return dllPath; return dllPath;
} }
std::string serializeTVMArgs(const std::vector<std::vector<int>> &inDims,
const std::vector<std::string> &inDTypes,
const std::vector<int> &outDims,
const std::string &outDType,
const std::string &lambda,
const std::string &funcName,
const std::string &nnetExprString,
const std::string &nnetSimplifiedExprString,
const HashType hashCode) const {
json j;
// Consistant with python API interface
j["input_tensors"] = inDims;
j["input_dtypes"] = inDTypes;
j["output_tensor"] = outDims;
j["output_dtype"] = outDType;
j["tvm_code"] = lambda;
j["func_name"] = funcName;
j["nnet_expression"] = nnetExprString;
j["nnet_simplified_expression"] = nnetSimplifiedExprString;
j["hash_code"] = std::to_string(hashCode);
return j.dump();
}
std::string getAnsorDLL(const std::vector<std::vector<int>> &inDims,
const std::vector<std::string> &inDTypes,
const std::vector<int> &outDims,
const std::string &outDType,
const std::string &lambda,
const std::string &funcName,
const std::string &nnetExprString,
const std::string &nnetSimplifiedExprString,
const HashType hashCode) const {
int fdP2C[2], fdC2P[2];
for (auto fd : {fdP2C, fdC2P}) {
int status = pipe(fd);
IT_ASSERT(status == 0, "pipe failed");
}
pid_t pid = fork();
IT_ASSERT(pid >= 0, "fork failed");
if (pid == 0) { // Child process
close(fdP2C[1]);
close(fdC2P[0]);
dup2(fdP2C[0], STDIN_FILENO);
close(fdP2C[0]);
string cmd =
"from cpp_plugin.gen_ansor_so import pipe_gen; pipe_gen(+" +
std::to_string(fdC2P[1]) + ")";
const char *const argv[] = {"python3", "-c", cmd.data(), NULL};
execvp("python3", const_cast<char *const *>(argv));
} else { // Parent process
close(fdP2C[0]);
close(fdC2P[1]);
// Write to pipe
string serializedArgs = serializeTVMArgs(
inDims, inDTypes, outDims, outDType, lambda, funcName,
nnetExprString, nnetSimplifiedExprString, hashCode);
int status = -1;
status =
write(fdP2C[1], serializedArgs.data(), serializedArgs.size());
IT_ASSERT(status == serializedArgs.size(),
"Failed to write to pipe");
close(fdP2C[1]);
// Wait for TVM
waitpid(pid, &status, 0);
IT_ASSERT(WIFEXITED(status), "TVM process was terminated");
const int es = WEXITSTATUS(status);
IT_ASSERT(es == 0,
"TVM process exit with code " + std::to_string(es));
// Read from pipe
FILE *stream;
int c;
stream = fdopen(fdC2P[0], "r");
char buf_read[257] = {0};
status = std::fscanf(stream, "%256c", buf_read);
IT_ASSERT(status == 1, "Failed to read from pipe");
IT_ASSERT(buf_read[256] == 0, "Pipe buffer overflow");
fclose(stream);
close(fdC2P[0]);
return buf_read;
}
IT_ASSERT(false, "Should not reach here");
return "";
}
tvm::runtime::PackedFunc getPackedFunction(string path, tvm::runtime::PackedFunc getPackedFunction(string path,
string functionName) const { string functionName) const {
tvm::runtime::Module mod = tvm::runtime::Module::LoadFromFile(path); tvm::runtime::Module mod = tvm::runtime::Module::LoadFromFile(path);