forked from jiuyuan/InfiniTensor
Add: invoke TVM through pipe
This commit is contained in:
parent
e72fe79168
commit
d2d49c5d4f
|
@ -1,12 +1,13 @@
|
|||
import os
|
||||
import sys
|
||||
import json
|
||||
from contextlib import redirect_stdout
|
||||
import time
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import tvm
|
||||
from tvm import te, tir, auto_scheduler, topi
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
|
||||
USE_CACHE = True
|
||||
logging.basicConfig()
|
||||
|
@ -16,7 +17,7 @@ logger.setLevel(logging.INFO)
|
|||
|
||||
def gen_ansor_so(input_tensors, input_dtypes, output_tensor, output_dtype,
|
||||
tvm_code, func_name, nnet_expression: str,
|
||||
nnet_simplified_expression: str, hash_code=None):
|
||||
nnet_simplified_expression: str, hash_code: str = None):
|
||||
assert len(input_tensors) == len(input_dtypes)
|
||||
|
||||
logger.debug(f'Work on hash {hash_code}')
|
||||
|
@ -117,3 +118,14 @@ def gen_ansor_so(input_tensors, input_dtypes, output_tensor, output_dtype,
|
|||
}, ensure_ascii=False, indent=2))
|
||||
|
||||
return so_fn, conv_time
|
||||
|
||||
# Read arguments from pipe, which is redirected to stdin.
|
||||
# Write generated library path to pipe.
|
||||
|
||||
|
||||
def pipe_gen(fd: int):
|
||||
args = json.load(sys.stdin) # read from pipe
|
||||
# print(args, f'fd={fd}')
|
||||
ret = gen_ansor_so(**args)
|
||||
with os.fdopen(fd, 'w') as f:
|
||||
print(ret[0], file=f, end='') # write to pipe
|
||||
|
|
|
@ -8,6 +8,12 @@
|
|||
#include "operators/pooling.h"
|
||||
#include "tvm/runtime/module.h"
|
||||
#include "tvm/runtime/packed_func.h"
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <sys/wait.h>
|
||||
#include <unistd.h>
|
||||
using json = nlohmann::json;
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
@ -124,7 +130,8 @@ class MemboundTVMPackedFunction : public Kernel {
|
|||
return std::dynamic_pointer_cast<PerfRecordObj>(ret);
|
||||
}
|
||||
|
||||
/// @brief
|
||||
/// @brief Invoke TVM with pybind11. This approach is deprecated since it
|
||||
/// ruins the python global scope and somehow breaks the cuBLAS runtime.
|
||||
/// @param inDims
|
||||
/// @param inDTypes
|
||||
/// @param outDims
|
||||
|
@ -137,7 +144,7 @@ class MemboundTVMPackedFunction : public Kernel {
|
|||
/// @param hashCode (optional) Hash code of the input expression for kernel
|
||||
/// cache.
|
||||
/// @return
|
||||
std::string getAnsorDLL(const std::vector<std::vector<int>> &inDims,
|
||||
std::string getAnsorDLLPybind11(const std::vector<std::vector<int>> &inDims,
|
||||
const std::vector<std::string> &inDTypes,
|
||||
const std::vector<int> &outDims,
|
||||
const std::string &outDType,
|
||||
|
@ -171,6 +178,95 @@ class MemboundTVMPackedFunction : public Kernel {
|
|||
return dllPath;
|
||||
}
|
||||
|
||||
std::string serializeTVMArgs(const std::vector<std::vector<int>> &inDims,
|
||||
const std::vector<std::string> &inDTypes,
|
||||
const std::vector<int> &outDims,
|
||||
const std::string &outDType,
|
||||
const std::string &lambda,
|
||||
const std::string &funcName,
|
||||
const std::string &nnetExprString,
|
||||
const std::string &nnetSimplifiedExprString,
|
||||
const HashType hashCode) const {
|
||||
json j;
|
||||
// Consistant with python API interface
|
||||
j["input_tensors"] = inDims;
|
||||
j["input_dtypes"] = inDTypes;
|
||||
j["output_tensor"] = outDims;
|
||||
j["output_dtype"] = outDType;
|
||||
j["tvm_code"] = lambda;
|
||||
j["func_name"] = funcName;
|
||||
j["nnet_expression"] = nnetExprString;
|
||||
j["nnet_simplified_expression"] = nnetSimplifiedExprString;
|
||||
j["hash_code"] = std::to_string(hashCode);
|
||||
return j.dump();
|
||||
}
|
||||
|
||||
std::string getAnsorDLL(const std::vector<std::vector<int>> &inDims,
|
||||
const std::vector<std::string> &inDTypes,
|
||||
const std::vector<int> &outDims,
|
||||
const std::string &outDType,
|
||||
const std::string &lambda,
|
||||
const std::string &funcName,
|
||||
const std::string &nnetExprString,
|
||||
const std::string &nnetSimplifiedExprString,
|
||||
const HashType hashCode) const {
|
||||
int fdP2C[2], fdC2P[2];
|
||||
for (auto fd : {fdP2C, fdC2P}) {
|
||||
int status = pipe(fd);
|
||||
IT_ASSERT(status == 0, "pipe failed");
|
||||
}
|
||||
pid_t pid = fork();
|
||||
IT_ASSERT(pid >= 0, "fork failed");
|
||||
if (pid == 0) { // Child process
|
||||
close(fdP2C[1]);
|
||||
close(fdC2P[0]);
|
||||
|
||||
dup2(fdP2C[0], STDIN_FILENO);
|
||||
close(fdP2C[0]);
|
||||
|
||||
string cmd =
|
||||
"from cpp_plugin.gen_ansor_so import pipe_gen; pipe_gen(+" +
|
||||
std::to_string(fdC2P[1]) + ")";
|
||||
const char *const argv[] = {"python3", "-c", cmd.data(), NULL};
|
||||
execvp("python3", const_cast<char *const *>(argv));
|
||||
} else { // Parent process
|
||||
close(fdP2C[0]);
|
||||
close(fdC2P[1]);
|
||||
|
||||
// Write to pipe
|
||||
string serializedArgs = serializeTVMArgs(
|
||||
inDims, inDTypes, outDims, outDType, lambda, funcName,
|
||||
nnetExprString, nnetSimplifiedExprString, hashCode);
|
||||
int status = -1;
|
||||
status =
|
||||
write(fdP2C[1], serializedArgs.data(), serializedArgs.size());
|
||||
IT_ASSERT(status == serializedArgs.size(),
|
||||
"Failed to write to pipe");
|
||||
close(fdP2C[1]);
|
||||
|
||||
// Wait for TVM
|
||||
waitpid(pid, &status, 0);
|
||||
IT_ASSERT(WIFEXITED(status), "TVM process was terminated");
|
||||
const int es = WEXITSTATUS(status);
|
||||
IT_ASSERT(es == 0,
|
||||
"TVM process exit with code " + std::to_string(es));
|
||||
|
||||
// Read from pipe
|
||||
FILE *stream;
|
||||
int c;
|
||||
stream = fdopen(fdC2P[0], "r");
|
||||
char buf_read[257] = {0};
|
||||
status = std::fscanf(stream, "%256c", buf_read);
|
||||
IT_ASSERT(status == 1, "Failed to read from pipe");
|
||||
IT_ASSERT(buf_read[256] == 0, "Pipe buffer overflow");
|
||||
fclose(stream);
|
||||
close(fdC2P[0]);
|
||||
return buf_read;
|
||||
}
|
||||
IT_ASSERT(false, "Should not reach here");
|
||||
return "";
|
||||
}
|
||||
|
||||
tvm::runtime::PackedFunc getPackedFunction(string path,
|
||||
string functionName) const {
|
||||
tvm::runtime::Module mod = tvm::runtime::Module::LoadFromFile(path);
|
||||
|
|
Loading…
Reference in New Issue