All models E2E

Liyan Zheng 2023-04-25 04:24:43 +08:00
parent 350fc01d39
commit b13b799fbe
11 changed files with 122 additions and 33 deletions

View File

@@ -16,6 +16,7 @@ class Mutator {
Runtime runtime = NativeCpuRuntimeObj::getInstance())
: candidatesLimit(candidatesLimit), runtime(runtime){};
virtual ~Mutator(){};
bool hasTunedKernel = false;
virtual vector<Graph> run(const Graph &in_graph) = 0;
/**

View File

@@ -15,6 +15,7 @@ class SearchEngine {
SearchEngine(Runtime runtime, Ref<Mutator> mutator);
~SearchEngine() {}
int searchFilter = 0;
bool chooseBestMutation = true;
private: // Configurations
size_t partitionThreshold =

View File

@@ -15,6 +15,7 @@ Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning, NMutator::Mode mode,
void initializeGraphTensors(Graph g, double l, double r, bool useInt);
Graph convertNCHWtoNHWCModel(Runtime runtime, Graph inG);
Graph optimizeWithDepthConstraint(Graph g, Runtime _runtime, int maxDepth);
Graph optimizeModel(Graph g, Runtime _runtime, string name);
} // namespace infini

View File

@@ -345,7 +345,8 @@ std::vector<Graph> SearchEngine::searchMutation(const MetaGraph &metaGraph) {
std::vector<Graph> nextGraphs;
if (node.type == 1) { // If it has computing OPs
auto mutatedGraphs = mutator->run(node.graph);
constexpr bool chooseBestMutation = false;
if (mutator->hasTunedKernel)
chooseBestMutation = false;
if (searchFilter == 1) {
std::sort(mutatedGraphs.begin(), mutatedGraphs.end(),
graphTimeComparer);
@@ -358,12 +359,13 @@ std::vector<Graph> SearchEngine::searchMutation(const MetaGraph &metaGraph) {
if (mutatedGraphs.size() >= 10)
mutatedGraphs.resize(10);
mutatedGraphs = {mutatedGraphs[0]};
} else { // avoid repeated kernel generation
if (mutatedGraphs.size() >= 2) // INFOGAN
mutatedGraphs = {mutatedGraphs[1]};
// if (mutatedGraphs.size() > 2) {
// mutatedGraphs.resize(2);
// }
}
// // HACK: only try the first one for debug
// if (mutatedGraphs.size() > 2)
// mutatedGraphs.resize(2);
// if (mutatedGraphs.size() >= 2)
// mutatedGraphs = {mutatedGraphs[1]};
for (auto graph : graphs) {
for (auto mutatedGraph : mutatedGraphs) {
@@ -494,6 +496,14 @@ Graph SearchEngine::fuseVertically(const Graph &graph) {
ops.emplace_back(op);
continue;
}
if (op->getOpType() == OpType::Relu ||
op->getOpType() == OpType::PRelu) {
if (auto p = op->getInputs()[0])
if (auto sop = p->getSource())
if (sop->getOpType() == OpType::Conv ||
sop->getOpType() == OpType::Matmul)
continue;
}
vector<Operator> chainOps;
visitTime.emplace(op->getGuid(), ++cnt);

View File

@@ -198,10 +198,11 @@ Tensor TensorObj::clone(Runtime runtime) const {
obj->freeData();
obj->targets.clear();
obj->source.reset();
if (hasData()) {
obj->dataMalloc();
obj->copyData(this);
}
// FIXME
// if (hasData()) {
// obj->dataMalloc();
// obj->copyData(this);
// }
return obj;
}

View File

@@ -416,8 +416,8 @@ void export_test_model(py::module &m) {
.def("initializeGraphTensors", &initializeGraphTensors, "g"_a,
"l"_a = -0.1, "r"_a = 0.1, "useInt"_a = false)
.def("convertNCHWtoNHWCModel", &convertNCHWtoNHWCModel)
.def("optimizeWithDepthConstraint", &optimizeWithDepthConstraint);
.def("optimizeWithDepthConstraint", &optimizeWithDepthConstraint)
.def("optimizeModel", &optimizeModel);
#endif
}
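For context, a minimal usage sketch of the newly exposed optimizeModel binding, assuming the pyinfinitensor import conventions used in the evaluation scripts further down; the ONNX path and the "model.bs1" name tag are illustrative placeholders, not part of this commit:

import onnx
from pyinfinitensor import backend as ft
from pyinfinitensor.onnx import OnnxStub

runtime = ft.cuda_runtime()
# "model.onnx" and the "model.bs1" tag are placeholders for illustration.
stub = OnnxStub.from_onnx(onnx.load("model.onnx"), runtime, False)
g = ft.optimizeModel(stub.handler.getGraph(), runtime, "model.bs1")
ft.initializeGraphTensors(g)
print(f'getPerfTime = {runtime.getPerfTime(g, True, False, False)}')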

View File

@@ -27,6 +27,8 @@ using GANConfigs = vector<tuple<int, int, int, int, bool>>;
using DetailedConfigs =
vector<tuple<int, int, int, int, int, int, int, int, int, int, bool>>;
static const vector<int> metaRules = {3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90};
DetailedConfigs getGANConfigs(int id, int batch) {
// The first conv can be transformed into gemm without reduction
// n, f, h, w, c, r, s, stride,
@@ -356,6 +358,18 @@ Graph convertNCHWtoNHWCModel(Runtime runtime, Graph inG) {
return g;
}
Graph optimizeModel(Graph g, Runtime _runtime, string name) {
auto runtime = as<CudaRuntimeObj>(_runtime);
Runtime cpu = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(cpu);
Ref<NMutator> mutator =
make_ref<NMutator>(NMutator::Mode::RuleBased, metaRules, runtime);
vector<Graph> bestGraphs;
SearchEngine searchEngine(runtime, mutator);
g->dataFree();
return searchEngine.run(g);
}
Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning, NMutator::Mode mode,
vector<int> rules) {
auto runtime = as<CudaRuntimeObj>(_runtime);

View File

@@ -146,6 +146,7 @@ void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
// dbg(nnet::FullPrinterVisitor().print(candidate.root));
if (auto g = expressionToGraph(candidate.root, in_graph)) {
out_graphs.emplace_back(g);
hasTunedKernel = true;
}
// break; // HACK:Debug only for the first subgraph
}
@@ -415,7 +416,7 @@ infini::Graph NMutator::expressionToGraph(nnet::Expr expr, Graph in_graph) {
auto input =
nameNToTensorT.at(op->getInputs().at(0)->getName());
auto output = nameNToTensorT.at(outputNameN);
if (input->size() != output->size())
if (input->size() != output->size())
return nullptr;
g->addOpWithOutputs<ReshapeObj>(input, output,
output->getDims());
@@ -848,7 +849,6 @@ Graph NMutator::constructGraphByOperatorChain(vector<Operator> ops,
auto output = (i + 1 == ops.size())
? inputGraph->getOutputs()[0]
: g->addTensor(ops[i]->getOutput()->getDims());
dbg(input->getDims(), output->getDims());
input = g->cloneOperator(ops[i], {input}, {output})->getOutput();
}
return g;

View File

@@ -0,0 +1,39 @@
import onnx
from pyinfinitensor import backend as ft
from pyinfinitensor.onnx import OnnxStub
def load_onnx(runtime, filename: str) -> ft.Graph:
stub = OnnxStub.from_onnx(onnx.load(filename), runtime, False)
return stub.handler.getGraph()
def run_and_evaluate(runtime, g):
ft.initializeGraphTensors(g)
runtime.run(g, True)
print(f'getPerfTime = {runtime.getPerfTime(g, True, False, False)}')
print(f'Non-ctc time = {runtime.timeNonCtcOperators(g, 10, 10)}')
print(f'Cuda graph time = {runtime.timeWithCudaGraph(g, 10)}')
def search_depth_exp():
runtime = ft.cuda_runtime()
graphs = [
(ft.getGANGraph(1, runtime, 5, 0), 'InfoGAN.bs1'),
(ft.getLongformer(runtime, 1), 'longformer.bs1'),
]
print("Figure 16")
for original_g, name in graphs:
print(f"=== Model {name}")
for i in range(1, 7):
g = ft.optimizeWithDepthConstraint(original_g, runtime, i)
ft.initializeGraphTensors(g)
print(f'{name} Depth = {i}: {runtime.getPerfTime(g, True, True, False)} ms')
def perf_test():
runtime = ft.cuda_runtime()
g = ft.getLongformer(runtime, 1)
run_and_evaluate(runtime, g)
if __name__ == "__main__":
search_depth_exp()

View File

@@ -28,9 +28,9 @@ def load_onnx(runtime, filename: str) -> ft.Graph:
def run_and_evaluate(runtime, g):
ft.initializeGraphTensors(g)
runtime.run(g, True)
print(f'getPerfTime = {runtime.getPerfTime(g, True, False, False)}')
print(f'Non-ctc time = {runtime.timeNonCtcOperators(g, 1000, 1000)}')
print(f'Cuda graph time = {runtime.timeWithCudaGraph(g, 100)}')
# print(f'getPerfTime = {runtime.getPerfTime(g, True, False, False)}')
# print(f'Non-ctc time = {runtime.timeNonCtcOperators(g, 10, 10)}')
print(f'Cuda graph time = {runtime.timeWithCudaGraph(g, 10)}')
def run_graph_get_output_as_torch_tensor(runtime, g):
@@ -161,8 +161,45 @@ def search_depth_exp():
# print(f'Non-ctc time = {runtime.timeNonCtcOperators(g, 10, 10)}')
# save_onnx(g, f"opt_{name}_depth{i}.onnx")
print(f'{name} Depth = {i}: {runtime.getPerfTime(g, True, True, False)} ms')
def model_e2e_exp():
runtime = ft.cuda_runtime()
model_evaluation =[
(lambda : ft.getGANGraph(1, runtime, 5, 0), 'InfoGAN.bs1'),
(lambda : ft.getGANGraph(16, runtime, 5, 0), 'InfoGAN.bs16'),
(lambda : ft.getGANGraph(1, runtime, 5, 1), 'DCGAN.bs1'),
(lambda : ft.getGANGraph(16, runtime, 5, 1), 'DCGAN.bs16'),
(lambda : ft.getFSRCNNGraph(1, runtime), "fsrcnn.bs1"),
(lambda : ft.getFSRCNNGraph(16, runtime), "fsrcnn.bs16"),
(lambda : load_onnx(runtime, '/mnt/auxHome/models/einnet/gcn.bs1.onnx'), 'gcn.bs1'),
(lambda : load_onnx(runtime, '/mnt/auxHome/models/einnet/gcn.bs16.onnx'), 'gcn.bs16'),
(lambda : load_onnx(runtime, '/mnt/auxHome/models/einnet/csrnet.bs1.onnx'), 'csrnet.bs1'),
(lambda : load_onnx(runtime, '/mnt/auxHome/models/einnet/csrnet.bs16.onnx'), 'csrnet.bs16'),
(lambda : ft.getLongformer(runtime, 1), 'longformer.bs1'),
(lambda : ft.getLongformer(runtime, 16), 'longformer.bs16'),
]
print("Figure 12")
for graph_ctor, name in model_evaluation:
print(f"=== {name}")
original_g = graph_ctor()
g = ft.optimizeModel(original_g, runtime, name)
# g = ft.optimizeGraph(original_g, runtime, False, ft.NMutatorMode.RuleBased,
# [3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90]) # Convtranspose2gemm
# save_onnx(g, f"opt_{name}.onnx")
run_and_evaluate(runtime, g)
def perf_test():
# wrong time 26.6 ms
# correct time 15 ms
runtime = ft.cuda_runtime()
g = ft.getLongformer(runtime, 1)
run_and_evaluate(runtime, g)
if __name__ == "__main__":
# perf_test()
model_e2e_exp()
exit()
runtime = ft.cuda_runtime()
graphs = [
# (construct_conv(runtime, 16, 56, 32, 32, 12, 1, 1, 0, 1, 1), 'conv1x1'), # FSRCNN Conv_2 1x1
@@ -188,22 +225,6 @@ if __name__ == "__main__":
# (load_onnx(runtime, '/mnt/auxHome/models/einnet/resnet18.bs16.onnx'), 'resnet18.bs16'),
# (ft.getGANGraph(1, runtime, 5, 0), 'InfoGAN.bs1'),
]
# model_evaluation =[
# (ft.getGANGraph(1, runtime, 5, 0), 'InfoGAN.bs1'),
# (ft.getGANGraph(16, runtime, 5, 0), 'InfoGAN.bs16'),
# (ft.getGANGraph(1, runtime, 5, 1), 'DCGAN.bs16'),
# (ft.getGANGraph(16, runtime, 5, 1), 'DCGAN.bs16'),
# (ft.getFSRCNNGraph(1, runtime), "fsrcnn.bs1"),
# (ft.getFSRCNNGraph(16, runtime), "fsrcnn.bs16"),
# (load_onnx(runtime, '/mnt/auxHome/models/einnet/gcn.bs1.onnx'), 'gcn.bs1'),
# (load_onnx(runtime, '/mnt/auxHome/models/einnet/gcn.bs16.onnx'), 'gcn.bs16'),
# (load_onnx(runtime, '/mnt/auxHome/models/einnet/resnet18.bs1.onnx'), 'resnet18.bs1'),
# (load_onnx(runtime, '/mnt/auxHome/models/einnet/resnet18.bs16.onnx'), 'resnet18.bs16'),
# (load_onnx(runtime, '/mnt/auxHome/models/einnet/csrnet.bs1.onnx'), 'csrnet.bs1'),
# (load_onnx(runtime, '/mnt/auxHome/models/einnet/csrnet.bs16.onnx'), 'csrnet.bs16'),
# (ft.getLongformer(runtime, 1), 'longformer.bs1'),
# (ft.getLongformer(runtime, 16), 'longformer.bs16'),
# ]
for original_g, name in graphs:

View File

@@ -5,6 +5,7 @@ if [ "$#" == 0 ] || [ "$1" == "cuda" ]
then
echo "Load CUDA environment."
spack load cuda@11.0.2 cudnn@8.0.3.33-11.0
spack load /2fcgebh # python3.7
export CUDAHOSTCXX=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-9.4.0/gcc-9.4.0-st36klijpsnquihiy463hmedsyhoc3g6/bin/gcc
elif [ "$1" == "intelcpu" ]
then