diff --git a/include/core/mutator.h b/include/core/mutator.h index 32a01af2..a9f19349 100644 --- a/include/core/mutator.h +++ b/include/core/mutator.h @@ -16,6 +16,7 @@ class Mutator { Runtime runtime = NativeCpuRuntimeObj::getInstance()) : candidatesLimit(candidatesLimit), runtime(runtime){}; virtual ~Mutator(){}; + bool hasTunedKernel = false; virtual vector run(const Graph &in_graph) = 0; /** diff --git a/include/core/search_engine.h b/include/core/search_engine.h index 7caa9f8c..659d8e42 100644 --- a/include/core/search_engine.h +++ b/include/core/search_engine.h @@ -15,6 +15,7 @@ class SearchEngine { SearchEngine(Runtime runtime, Ref mutator); ~SearchEngine() {} int searchFilter = 0; + bool chooseBestMutation = true; private: // Configurations size_t partitionThreshold = diff --git a/include/nnet/test_models.h b/include/nnet/test_models.h index 1ecc5b82..c8b6ffe4 100644 --- a/include/nnet/test_models.h +++ b/include/nnet/test_models.h @@ -15,6 +15,7 @@ Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning, NMutator::Mode mode, void initializeGraphTensors(Graph g, double l, double r, bool useInt); Graph convertNCHWtoNHWCModel(Runtime runtime, Graph inG); Graph optimizeWithDepthConstraint(Graph g, Runtime _runtime, int maxDepth); +Graph optimizeModel(Graph g, Runtime _runtime, string name); } // namespace infini diff --git a/src/core/search_engine.cc b/src/core/search_engine.cc index 34fa7d9c..f8c8f876 100644 --- a/src/core/search_engine.cc +++ b/src/core/search_engine.cc @@ -345,7 +345,8 @@ std::vector SearchEngine::searchMutation(const MetaGraph &metaGraph) { std::vector nextGraphs; if (node.type == 1) { // If it has computing OPs auto mutatedGraphs = mutator->run(node.graph); - constexpr bool chooseBestMutation = false; + if (mutator->hasTunedKernel) + chooseBestMutation = false; if (searchFilter == 1) { std::sort(mutatedGraphs.begin(), mutatedGraphs.end(), graphTimeComparer); @@ -358,12 +359,13 @@ std::vector SearchEngine::searchMutation(const MetaGraph &metaGraph) { if (mutatedGraphs.size() >= 10) mutatedGraphs.resize(10); mutatedGraphs = {mutatedGraphs[0]}; + } else { // avoid repeated kernel genreation + if (mutatedGraphs.size() >= 2) // INFOGAN + mutatedGraphs = {mutatedGraphs[1]}; + // if (mutatedGraphs.size() > 2) { + // mutatedGraphs.resize(2); + // } } - // // HACK: only try the first one for debug - // if (mutatedGraphs.size() > 2) - // mutatedGraphs.resize(2); - // if (mutatedGraphs.size() >= 2) - // mutatedGraphs = {mutatedGraphs[1]}; for (auto graph : graphs) { for (auto mutatedGraph : mutatedGraphs) { @@ -494,6 +496,14 @@ Graph SearchEngine::fuseVertically(const Graph &graph) { ops.emplace_back(op); continue; } + if (op->getOpType() == OpType::Relu || + op->getOpType() == OpType::PRelu) { + if (auto p = op->getInputs()[0]) + if (auto sop = p->getSource()) + if (sop->getOpType() == OpType::Conv || + sop->getOpType() == OpType::Matmul) + continue; + } vector chainOps; visitTime.emplace(op->getGuid(), ++cnt); diff --git a/src/core/tensor.cc b/src/core/tensor.cc index 627bb021..b663edfd 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -198,10 +198,11 @@ Tensor TensorObj::clone(Runtime runtime) const { obj->freeData(); obj->targets.clear(); obj->source.reset(); - if (hasData()) { - obj->dataMalloc(); - obj->copyData(this); - } + // FIXME + // if (hasData()) { + // obj->dataMalloc(); + // obj->copyData(this); + // } return obj; } diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 0fa53d6f..7596b077 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -416,8 +416,8 @@ void export_test_model(py::module &m) { .def("initializeGraphTensors", &initializeGraphTensors, "g"_a, "l"_a = -0.1, "r"_a = 0.1, "useInt"_a = false) .def("convertNCHWtoNHWCModel", &convertNCHWtoNHWCModel) - .def("optimizeWithDepthConstraint", &optimizeWithDepthConstraint); - + .def("optimizeWithDepthConstraint", &optimizeWithDepthConstraint) + .def("optimizeModel", &optimizeModel); #endif } diff --git a/src/nnet/App/test_models.cc b/src/nnet/App/test_models.cc index 3d80d381..ab2ba3fd 100644 --- a/src/nnet/App/test_models.cc +++ b/src/nnet/App/test_models.cc @@ -27,6 +27,8 @@ using GANConfigs = vector>; using DetailedConfigs = vector>; +static const vector metaRules = {3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90}; + DetailedConfigs getGANConfigs(int id, int batch) { // The first conv can be transformed into gemm without reduction // n, f, h, w, c, r, s, stride, @@ -356,6 +358,18 @@ Graph convertNCHWtoNHWCModel(Runtime runtime, Graph inG) { return g; } +Graph optimizeModel(Graph g, Runtime _runtime, string name) { + auto runtime = as(_runtime); + Runtime cpu = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(cpu); + Ref mutator = + make_ref(NMutator::Mode::RuleBased, metaRules, runtime); + vector bestGraphs; + SearchEngine searchEngine(runtime, mutator); + g->dataFree(); + return searchEngine.run(g); +} + Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning, NMutator::Mode mode, vector rules) { auto runtime = as(_runtime); diff --git a/src/nnet/nmutator.cc b/src/nnet/nmutator.cc index 0bc8be2d..b0ceb74d 100644 --- a/src/nnet/nmutator.cc +++ b/src/nnet/nmutator.cc @@ -146,6 +146,7 @@ void NMutator::runSingleOp(Graph in_graph, std::vector &out_graphs) { // dbg(nnet::FullPrinterVisitor().print(candidate.root)); if (auto g = expressionToGraph(candidate.root, in_graph)) { out_graphs.emplace_back(g); + hasTunedKernel = true; } // break; // HACK:Debug only for the first subgraph } @@ -415,7 +416,7 @@ infini::Graph NMutator::expressionToGraph(nnet::Expr expr, Graph in_graph) { auto input = nameNToTensorT.at(op->getInputs().at(0)->getName()); auto output = nameNToTensorT.at(outputNameN); - if (input->size() != output->size()) + if (input->size() != output->size()) return nullptr; g->addOpWithOutputs(input, output, output->getDims()); @@ -848,7 +849,6 @@ Graph NMutator::constructGraphByOperatorChain(vector ops, auto output = (i + 1 == ops.size()) ? inputGraph->getOutputs()[0] : g->addTensor(ops[i]->getOutput()->getDims()); - dbg(input->getDims(), output->getDims()); input = g->cloneOperator(ops[i], {input}, {output})->getOutput(); } return g; diff --git a/test/nnet/evaluate_max_depth.py b/test/nnet/evaluate_max_depth.py new file mode 100644 index 00000000..d796d621 --- /dev/null +++ b/test/nnet/evaluate_max_depth.py @@ -0,0 +1,39 @@ +import onnx +from pyinfinitensor import backend as ft +from pyinfinitensor.onnx import OnnxStub + + +def load_onnx(runtime, filename: str) -> ft.Graph: + stub = OnnxStub.from_onnx(onnx.load(filename), runtime, False) + return stub.handler.getGraph() + + +def run_and_evaluate(runtime, g): + ft.initializeGraphTensors(g) + runtime.run(g, True) + print(f'getPerfTime = {runtime.getPerfTime(g, True, False, False)}') + print(f'Non-ctc time = {runtime.timeNonCtcOperators(g, 10, 10)}') + print(f'Cuda graph time = {runtime.timeWithCudaGraph(g, 10)}') + + +def search_depth_exp(): + runtime = ft.cuda_runtime() + graphs = [ + (ft.getGANGraph(1, runtime, 5, 0), 'InfoGAN.bs1'), + (ft.getLongformer(runtime, 1), 'longformer.bs1'), + ] + print("Figure 16") + for original_g, name in graphs: + print(f"=== Model {name}") + for i in range(1, 7): + g = ft.optimizeWithDepthConstraint(original_g, runtime, i) + ft.initializeGraphTensors(g) + print(f'{name} Depth = {i}: {runtime.getPerfTime(g, True, True, False)} ms') + +def perf_test(): + runtime = ft.cuda_runtime() + g = ft.getLongformer(runtime, 1) + run_and_evaluate(runtime, g) + +if __name__ == "__main__": + search_depth_exp() diff --git a/test/nnet/run_models_nnet.py b/test/nnet/run_models_nnet.py index 0e48aedf..fd01fecb 100644 --- a/test/nnet/run_models_nnet.py +++ b/test/nnet/run_models_nnet.py @@ -28,9 +28,9 @@ def load_onnx(runtime, filename: str) -> ft.Graph: def run_and_evaluate(runtime, g): ft.initializeGraphTensors(g) runtime.run(g, True) - print(f'getPerfTime = {runtime.getPerfTime(g, True, False, False)}') - print(f'Non-ctc time = {runtime.timeNonCtcOperators(g, 1000, 1000)}') - print(f'Cuda graph time = {runtime.timeWithCudaGraph(g, 100)}') + # print(f'getPerfTime = {runtime.getPerfTime(g, True, False, False)}') + # print(f'Non-ctc time = {runtime.timeNonCtcOperators(g, 10, 10)}') + print(f'Cuda graph time = {runtime.timeWithCudaGraph(g, 10)}') def run_graph_get_output_as_torch_tensor(runtime, g): @@ -161,8 +161,45 @@ def search_depth_exp(): # print(f'Non-ctc time = {runtime.timeNonCtcOperators(g, 10, 10)}') # save_onnx(g, f"opt_{name}_depth{i}.onnx") print(f'{name} Depth = {i}: {runtime.getPerfTime(g, True, True, False)} ms') + +def model_e2e_exp(): + runtime = ft.cuda_runtime() + model_evaluation =[ + (lambda : ft.getGANGraph(1, runtime, 5, 0), 'InfoGAN.bs1'), + (lambda : ft.getGANGraph(16, runtime, 5, 0), 'InfoGAN.bs16'), + (lambda : ft.getGANGraph(1, runtime, 5, 1), 'DCGAN.bs1'), + (lambda : ft.getGANGraph(16, runtime, 5, 1), 'DCGAN.bs16'), + (lambda : ft.getFSRCNNGraph(1, runtime), "fsrcnn.bs1"), + (lambda : ft.getFSRCNNGraph(16, runtime), "fsrcnn.bs16"), + (lambda : load_onnx(runtime, '/mnt/auxHome/models/einnet/gcn.bs1.onnx'), 'gcn.bs1'), + (lambda : load_onnx(runtime, '/mnt/auxHome/models/einnet/gcn.bs16.onnx'), 'gcn.bs16'), + (lambda : load_onnx(runtime, '/mnt/auxHome/models/einnet/csrnet.bs1.onnx'), 'csrnet.bs1'), + (lambda : load_onnx(runtime, '/mnt/auxHome/models/einnet/csrnet.bs16.onnx'), 'csrnet.bs16'), + (lambda : ft.getLongformer(runtime, 1), 'longformer.bs1'), + (lambda : ft.getLongformer(runtime, 16), 'longformer.bs16'), + ] + print("Figure 12") + for graph_ctor, name in model_evaluation: + print(f"=== {name}") + original_g = graph_ctor() + g = ft.optimizeModel(original_g, runtime, name) + # g = ft.optimizeGraph(original_g, runtime, False, ft.NMutatorMode.RuleBased, + # [3, 2, 2, 2, 2, 5, 8, 8, 6, 91, 90]) # Convtranspose2gemm + # save_onnx(g, f"opt_{name}.onnx") + run_and_evaluate(runtime, g) + + +def perf_test(): + # wrong time 26.6 ms + # correct time 15 ms + runtime = ft.cuda_runtime() + g = ft.getLongformer(runtime, 1) + run_and_evaluate(runtime, g) if __name__ == "__main__": + # perf_test() + model_e2e_exp() + exit() runtime = ft.cuda_runtime() graphs = [ # (construct_conv(runtime, 16, 56, 32, 32, 12, 1, 1, 0, 1, 1), 'conv1x1'), # FSRCNN Conv_2 1x1 @@ -188,22 +225,6 @@ if __name__ == "__main__": # (load_onnx(runtime, '/mnt/auxHome/models/einnet/resnet18.bs16.onnx'), 'resnet18.bs16'), # (ft.getGANGraph(1, runtime, 5, 0), 'InfoGAN.bs1'), ] - # model_evaluation =[ - # (ft.getGANGraph(1, runtime, 5, 0), 'InfoGAN.bs1'), - # (ft.getGANGraph(16, runtime, 5, 0), 'InfoGAN.bs16'), - # (ft.getGANGraph(1, runtime, 5, 1), 'DCGAN.bs16'), - # (ft.getGANGraph(16, runtime, 5, 1), 'DCGAN.bs16'), - # (ft.getFSRCNNGraph(1, runtime), "fsrcnn.bs1"), - # (ft.getFSRCNNGraph(16, runtime), "fsrcnn.bs16"), - # (load_onnx(runtime, '/mnt/auxHome/models/einnet/gcn.bs1.onnx'), 'gcn.bs1'), - # (load_onnx(runtime, '/mnt/auxHome/models/einnet/gcn.bs16.onnx'), 'gcn.bs16'), - # (load_onnx(runtime, '/mnt/auxHome/models/einnet/resnet18.bs1.onnx'), 'resnet18.bs1'), - # (load_onnx(runtime, '/mnt/auxHome/models/einnet/resnet18.bs16.onnx'), 'resnet18.bs16'), - # (load_onnx(runtime, '/mnt/auxHome/models/einnet/csrnet.bs1.onnx'), 'csrnet.bs1'), - # (load_onnx(runtime, '/mnt/auxHome/models/einnet/csrnet.bs16.onnx'), 'csrnet.bs16'), - # (ft.getLongformer(runtime, 1), 'longformer.bs1'), - # (ft.getLongformer(runtime, 16), 'longformer.bs16'), - # ] for original_g, name in graphs: diff --git a/test/script/env_lotus.sh b/test/script/env_lotus.sh index 9fe82b98..7b77dfd0 100644 --- a/test/script/env_lotus.sh +++ b/test/script/env_lotus.sh @@ -5,6 +5,7 @@ if [ "$#" == 0 ] || [ "$1" == "cuda" ] then echo "Load CUDA environment." spack load cuda@11.0.2 cudnn@8.0.3.33-11.0 + spack load /2fcgebh # python3.7 export CUDAHOSTCXX=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-9.4.0/gcc-9.4.0-st36klijpsnquihiy463hmedsyhoc3g6/bin/gcc elif [ "$1" == "intelcpu" ] then