Chore: simplify type names

This commit is contained in:
Liyan Zheng 2023-04-15 13:58:18 +08:00
parent 09293730ea
commit a6b8f344d4
2 changed files with 43 additions and 45 deletions

View File

@ -4,8 +4,6 @@
#include "graph.h"
#include "mutator.h"
#include <unordered_map>
namespace infini {
class SearchEngine {
private:
@ -14,13 +12,7 @@ class SearchEngine {
std::function<bool(const Graph &, const Graph &)> graphTimeComparer;
public:
SearchEngine(Runtime runtime, Ref<Mutator> mutator)
: runtimeExec(runtime), mutator(mutator) {
// Compare graph with estimated time
graphTimeComparer = [this](const Graph &a, const Graph &b) -> bool {
return getEstimatedGraphPerf(a) < getEstimatedGraphPerf(b);
};
}
SearchEngine(Runtime runtime, Ref<Mutator> mutator);
~SearchEngine() {}
private: // Configurations
@ -34,11 +26,11 @@ class SearchEngine {
GroupEdge() = delete;
};
struct Candidate { // a graph with perf
std::shared_ptr<Graph> graph;
double perf = INFINITY;
};
struct MetaGraph { // a graph of subgraphs, for searching.
// struct Candidate { // a graph with perf
// Graph graph;
// double perf = INFINITY;
// };
struct MetaGraphObj { // a graph of subgraphs, for searching.
struct Node {
Graph graph;
std::vector<int> suc;
@ -47,35 +39,32 @@ class SearchEngine {
};
std::vector<Node> nodes;
};
using MetaGraph = Ref<MetaGraphObj>;
Graph run(const Graph graph); // entrance to search engine.
std::vector<Graph> search(const Graph &graph); // search for a partition.
private:
std::vector<Graph> partitionGraph(const Graph graph);
std::shared_ptr<MetaGraph> buildMetaGraphWithGraph(const Graph graph);
std::shared_ptr<MetaGraph>
buildMetaGraphWithPlan(const std::shared_ptr<MetaGraph> metaGraph,
const std::vector<int> &plan);
MetaGraph buildMetaGraphWithGraph(const Graph graph);
MetaGraph buildMetaGraphWithPlan(const MetaGraph metaGraph,
const std::vector<int> &plan);
// search horizontal merges
std::vector<std::shared_ptr<MetaGraph>>
searchMerge(std::shared_ptr<MetaGraph> &metaGraph);
void searchMergeDfs(std::shared_ptr<MetaGraph> &metaGraph,
std::vector<int> &plan, std::vector<int> &frontier,
std::vector<MetaGraph> searchMerge(MetaGraph &metaGraph);
void searchMergeDfs(MetaGraph &metaGraph, std::vector<int> &plan,
std::vector<int> &frontier,
std::vector<std::vector<int>> &plans,
std::unordered_set<uint64_t> &planSet);
std::vector<Graph>
searchMutation(const std::shared_ptr<MetaGraph> &metaGraph);
std::vector<Graph> searchMutation(const MetaGraph &metaGraph);
void printMetaGraph(Ref<SearchEngine::MetaGraph> metaGraph);
void printMetaGraph(MetaGraph metaGraph);
/**
* @brief Check whether a multi-brach graph can be merged into a single
* branch.
*/
bool isMultiBranchMergable(const Graph graph);
double getEstimatedGraphPerf(Graph graph) {
return runtimeExec->getPerfTime(graph, false, true);
}
double getEstimatedGraphPerf(Graph graph);
};
} // namespace infini

View File

@ -8,7 +8,17 @@
namespace infini {
void SearchEngine::printMetaGraph(Ref<SearchEngine::MetaGraph> metaGraph) {
using MetaGraph = SearchEngine::MetaGraph;
SearchEngine::SearchEngine(Runtime runtime, Ref<Mutator> mutator)
: runtimeExec(runtime), mutator(mutator) {
// Compare graph with estimated time
graphTimeComparer = [this](const Graph &a, const Graph &b) -> bool {
return getEstimatedGraphPerf(a) < getEstimatedGraphPerf(b);
};
}
void SearchEngine::printMetaGraph(MetaGraph metaGraph) {
for (size_t i = 0; i < metaGraph->nodes.size(); i++) {
auto &node = metaGraph->nodes[i];
std::cout << "id: " << i << std::endl;
@ -108,9 +118,8 @@ std::vector<Graph> SearchEngine::search(const Graph &graph) {
}
// Build metagraph with a graph, each operator is a node.
std::shared_ptr<SearchEngine::MetaGraph>
SearchEngine::buildMetaGraphWithGraph(const Graph graph) {
auto metaGraph = std::make_shared<MetaGraph>();
MetaGraph SearchEngine::buildMetaGraphWithGraph(const Graph graph) {
auto metaGraph = make_ref<MetaGraphObj>();
int numOps = graph->getOperators().size();
std::vector<int> cnt(numOps, 0);
@ -119,7 +128,7 @@ SearchEngine::buildMetaGraphWithGraph(const Graph graph) {
std::vector<int> q(0);
for (size_t i = 0; i < graph->getOperators().size(); i++) {
auto &op = graph->getOperators()[i];
MetaGraph::Node node;
MetaGraphObj::Node node;
std::vector<Operator> ops;
ops.emplace_back(op);
node.graph = make_ref<GraphObj>(runtimeExec, ops);
@ -153,9 +162,8 @@ SearchEngine::buildMetaGraphWithGraph(const Graph graph) {
// Build a metagraph with graph and a plan, a plan is which ops should be a
// node.
std::shared_ptr<SearchEngine::MetaGraph> SearchEngine::buildMetaGraphWithPlan(
const std::shared_ptr<SearchEngine::MetaGraph> metaGraph,
const std::vector<int> &plan) {
MetaGraph SearchEngine::buildMetaGraphWithPlan(const MetaGraph metaGraph,
const std::vector<int> &plan) {
int numGroups = 0;
for (auto i : plan) {
if (i > numGroups) {
@ -168,12 +176,12 @@ std::shared_ptr<SearchEngine::MetaGraph> SearchEngine::buildMetaGraphWithPlan(
groups[plan[i]].emplace_back(i);
}
auto resultMetaGraph = make_ref<MetaGraph>();
auto resultMetaGraph = make_ref<MetaGraphObj>();
for (auto &group : groups) {
std::vector<Operator> ops;
std::unordered_set<int> preSet, sucSet;
for (auto id : group) {
MetaGraph::Node node;
MetaGraphObj::Node node;
for (auto op : metaGraph->nodes[id].graph->getOperators()) {
ops.emplace_back(op);
}
@ -200,8 +208,7 @@ std::shared_ptr<SearchEngine::MetaGraph> SearchEngine::buildMetaGraphWithPlan(
}
// Search how to merge multiple ops.
std::vector<std::shared_ptr<SearchEngine::MetaGraph>>
SearchEngine::searchMerge(std::shared_ptr<SearchEngine::MetaGraph> &metaGraph) {
vector<MetaGraph> SearchEngine::searchMerge(MetaGraph &metaGraph) {
IT_ASSERT(metaGraph != nullptr);
std::vector<int> plan(metaGraph->nodes.size());
for (size_t i = 0; i < plan.size(); i++) {
@ -218,7 +225,7 @@ SearchEngine::searchMerge(std::shared_ptr<SearchEngine::MetaGraph> &metaGraph) {
std::unordered_set<HashType> planSet;
searchMergeDfs(metaGraph, plan, frontier, plans, planSet);
std::vector<std::shared_ptr<SearchEngine::MetaGraph>> metaGraphs;
vector<MetaGraph> metaGraphs;
for (auto &curPlan : plans) {
metaGraphs.emplace_back(buildMetaGraphWithPlan(metaGraph, curPlan));
}
@ -226,8 +233,7 @@ SearchEngine::searchMerge(std::shared_ptr<SearchEngine::MetaGraph> &metaGraph) {
}
// DFS impl for search merge.
void SearchEngine::searchMergeDfs(std::shared_ptr<MetaGraph> &metaGraph,
std::vector<int> &plan,
void SearchEngine::searchMergeDfs(MetaGraph &metaGraph, std::vector<int> &plan,
std::vector<int> &frontier,
std::vector<std::vector<int>> &plans,
std::unordered_set<uint64_t> &planSet) {
@ -316,8 +322,7 @@ void SearchEngine::searchMergeDfs(std::shared_ptr<MetaGraph> &metaGraph,
}
// Search mutation for each compute op.
std::vector<Graph> SearchEngine::searchMutation(
const std::shared_ptr<SearchEngine::MetaGraph> &metaGraph) {
std::vector<Graph> SearchEngine::searchMutation(const MetaGraph &metaGraph) {
std::vector<Graph> graphs = {nullptr};
// Append a node to all existing candidates
for (auto &node : metaGraph->nodes) {
@ -432,4 +437,8 @@ std::vector<Graph> SearchEngine::partitionGraph(const Graph graph) {
return partitions;
}
double SearchEngine::getEstimatedGraphPerf(Graph graph) {
return runtimeExec->getPerfTime(graph, false, true);
}
} // namespace infini