forked from jiuyuan/InfiniTensor
Chore: simplify type names
This commit is contained in:
parent
09293730ea
commit
a6b8f344d4
|
@ -4,8 +4,6 @@
|
|||
#include "graph.h"
|
||||
#include "mutator.h"
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
namespace infini {
|
||||
class SearchEngine {
|
||||
private:
|
||||
|
@ -14,13 +12,7 @@ class SearchEngine {
|
|||
std::function<bool(const Graph &, const Graph &)> graphTimeComparer;
|
||||
|
||||
public:
|
||||
SearchEngine(Runtime runtime, Ref<Mutator> mutator)
|
||||
: runtimeExec(runtime), mutator(mutator) {
|
||||
// Compare graph with estimated time
|
||||
graphTimeComparer = [this](const Graph &a, const Graph &b) -> bool {
|
||||
return getEstimatedGraphPerf(a) < getEstimatedGraphPerf(b);
|
||||
};
|
||||
}
|
||||
SearchEngine(Runtime runtime, Ref<Mutator> mutator);
|
||||
~SearchEngine() {}
|
||||
|
||||
private: // Configurations
|
||||
|
@ -34,11 +26,11 @@ class SearchEngine {
|
|||
GroupEdge() = delete;
|
||||
};
|
||||
|
||||
struct Candidate { // a graph with perf
|
||||
std::shared_ptr<Graph> graph;
|
||||
double perf = INFINITY;
|
||||
};
|
||||
struct MetaGraph { // a graph of subgraphs, for searching.
|
||||
// struct Candidate { // a graph with perf
|
||||
// Graph graph;
|
||||
// double perf = INFINITY;
|
||||
// };
|
||||
struct MetaGraphObj { // a graph of subgraphs, for searching.
|
||||
struct Node {
|
||||
Graph graph;
|
||||
std::vector<int> suc;
|
||||
|
@ -47,35 +39,32 @@ class SearchEngine {
|
|||
};
|
||||
std::vector<Node> nodes;
|
||||
};
|
||||
using MetaGraph = Ref<MetaGraphObj>;
|
||||
|
||||
Graph run(const Graph graph); // entrance to search engine.
|
||||
std::vector<Graph> search(const Graph &graph); // search for a partition.
|
||||
|
||||
private:
|
||||
std::vector<Graph> partitionGraph(const Graph graph);
|
||||
std::shared_ptr<MetaGraph> buildMetaGraphWithGraph(const Graph graph);
|
||||
std::shared_ptr<MetaGraph>
|
||||
buildMetaGraphWithPlan(const std::shared_ptr<MetaGraph> metaGraph,
|
||||
const std::vector<int> &plan);
|
||||
MetaGraph buildMetaGraphWithGraph(const Graph graph);
|
||||
MetaGraph buildMetaGraphWithPlan(const MetaGraph metaGraph,
|
||||
const std::vector<int> &plan);
|
||||
// search horizontal merges
|
||||
std::vector<std::shared_ptr<MetaGraph>>
|
||||
searchMerge(std::shared_ptr<MetaGraph> &metaGraph);
|
||||
void searchMergeDfs(std::shared_ptr<MetaGraph> &metaGraph,
|
||||
std::vector<int> &plan, std::vector<int> &frontier,
|
||||
std::vector<MetaGraph> searchMerge(MetaGraph &metaGraph);
|
||||
void searchMergeDfs(MetaGraph &metaGraph, std::vector<int> &plan,
|
||||
std::vector<int> &frontier,
|
||||
std::vector<std::vector<int>> &plans,
|
||||
std::unordered_set<uint64_t> &planSet);
|
||||
std::vector<Graph>
|
||||
searchMutation(const std::shared_ptr<MetaGraph> &metaGraph);
|
||||
std::vector<Graph> searchMutation(const MetaGraph &metaGraph);
|
||||
|
||||
void printMetaGraph(Ref<SearchEngine::MetaGraph> metaGraph);
|
||||
void printMetaGraph(MetaGraph metaGraph);
|
||||
/**
|
||||
* @brief Check whether a multi-brach graph can be merged into a single
|
||||
* branch.
|
||||
*/
|
||||
bool isMultiBranchMergable(const Graph graph);
|
||||
|
||||
double getEstimatedGraphPerf(Graph graph) {
|
||||
return runtimeExec->getPerfTime(graph, false, true);
|
||||
}
|
||||
double getEstimatedGraphPerf(Graph graph);
|
||||
};
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -8,7 +8,17 @@
|
|||
|
||||
namespace infini {
|
||||
|
||||
void SearchEngine::printMetaGraph(Ref<SearchEngine::MetaGraph> metaGraph) {
|
||||
using MetaGraph = SearchEngine::MetaGraph;
|
||||
|
||||
SearchEngine::SearchEngine(Runtime runtime, Ref<Mutator> mutator)
|
||||
: runtimeExec(runtime), mutator(mutator) {
|
||||
// Compare graph with estimated time
|
||||
graphTimeComparer = [this](const Graph &a, const Graph &b) -> bool {
|
||||
return getEstimatedGraphPerf(a) < getEstimatedGraphPerf(b);
|
||||
};
|
||||
}
|
||||
|
||||
void SearchEngine::printMetaGraph(MetaGraph metaGraph) {
|
||||
for (size_t i = 0; i < metaGraph->nodes.size(); i++) {
|
||||
auto &node = metaGraph->nodes[i];
|
||||
std::cout << "id: " << i << std::endl;
|
||||
|
@ -108,9 +118,8 @@ std::vector<Graph> SearchEngine::search(const Graph &graph) {
|
|||
}
|
||||
|
||||
// Build metagraph with a graph, each operator is a node.
|
||||
std::shared_ptr<SearchEngine::MetaGraph>
|
||||
SearchEngine::buildMetaGraphWithGraph(const Graph graph) {
|
||||
auto metaGraph = std::make_shared<MetaGraph>();
|
||||
MetaGraph SearchEngine::buildMetaGraphWithGraph(const Graph graph) {
|
||||
auto metaGraph = make_ref<MetaGraphObj>();
|
||||
|
||||
int numOps = graph->getOperators().size();
|
||||
std::vector<int> cnt(numOps, 0);
|
||||
|
@ -119,7 +128,7 @@ SearchEngine::buildMetaGraphWithGraph(const Graph graph) {
|
|||
std::vector<int> q(0);
|
||||
for (size_t i = 0; i < graph->getOperators().size(); i++) {
|
||||
auto &op = graph->getOperators()[i];
|
||||
MetaGraph::Node node;
|
||||
MetaGraphObj::Node node;
|
||||
std::vector<Operator> ops;
|
||||
ops.emplace_back(op);
|
||||
node.graph = make_ref<GraphObj>(runtimeExec, ops);
|
||||
|
@ -153,9 +162,8 @@ SearchEngine::buildMetaGraphWithGraph(const Graph graph) {
|
|||
|
||||
// Build a metagraph with graph and a plan, a plan is which ops should be a
|
||||
// node.
|
||||
std::shared_ptr<SearchEngine::MetaGraph> SearchEngine::buildMetaGraphWithPlan(
|
||||
const std::shared_ptr<SearchEngine::MetaGraph> metaGraph,
|
||||
const std::vector<int> &plan) {
|
||||
MetaGraph SearchEngine::buildMetaGraphWithPlan(const MetaGraph metaGraph,
|
||||
const std::vector<int> &plan) {
|
||||
int numGroups = 0;
|
||||
for (auto i : plan) {
|
||||
if (i > numGroups) {
|
||||
|
@ -168,12 +176,12 @@ std::shared_ptr<SearchEngine::MetaGraph> SearchEngine::buildMetaGraphWithPlan(
|
|||
groups[plan[i]].emplace_back(i);
|
||||
}
|
||||
|
||||
auto resultMetaGraph = make_ref<MetaGraph>();
|
||||
auto resultMetaGraph = make_ref<MetaGraphObj>();
|
||||
for (auto &group : groups) {
|
||||
std::vector<Operator> ops;
|
||||
std::unordered_set<int> preSet, sucSet;
|
||||
for (auto id : group) {
|
||||
MetaGraph::Node node;
|
||||
MetaGraphObj::Node node;
|
||||
for (auto op : metaGraph->nodes[id].graph->getOperators()) {
|
||||
ops.emplace_back(op);
|
||||
}
|
||||
|
@ -200,8 +208,7 @@ std::shared_ptr<SearchEngine::MetaGraph> SearchEngine::buildMetaGraphWithPlan(
|
|||
}
|
||||
|
||||
// Search how to merge multiple ops.
|
||||
std::vector<std::shared_ptr<SearchEngine::MetaGraph>>
|
||||
SearchEngine::searchMerge(std::shared_ptr<SearchEngine::MetaGraph> &metaGraph) {
|
||||
vector<MetaGraph> SearchEngine::searchMerge(MetaGraph &metaGraph) {
|
||||
IT_ASSERT(metaGraph != nullptr);
|
||||
std::vector<int> plan(metaGraph->nodes.size());
|
||||
for (size_t i = 0; i < plan.size(); i++) {
|
||||
|
@ -218,7 +225,7 @@ SearchEngine::searchMerge(std::shared_ptr<SearchEngine::MetaGraph> &metaGraph) {
|
|||
std::unordered_set<HashType> planSet;
|
||||
searchMergeDfs(metaGraph, plan, frontier, plans, planSet);
|
||||
|
||||
std::vector<std::shared_ptr<SearchEngine::MetaGraph>> metaGraphs;
|
||||
vector<MetaGraph> metaGraphs;
|
||||
for (auto &curPlan : plans) {
|
||||
metaGraphs.emplace_back(buildMetaGraphWithPlan(metaGraph, curPlan));
|
||||
}
|
||||
|
@ -226,8 +233,7 @@ SearchEngine::searchMerge(std::shared_ptr<SearchEngine::MetaGraph> &metaGraph) {
|
|||
}
|
||||
|
||||
// DFS impl for search merge.
|
||||
void SearchEngine::searchMergeDfs(std::shared_ptr<MetaGraph> &metaGraph,
|
||||
std::vector<int> &plan,
|
||||
void SearchEngine::searchMergeDfs(MetaGraph &metaGraph, std::vector<int> &plan,
|
||||
std::vector<int> &frontier,
|
||||
std::vector<std::vector<int>> &plans,
|
||||
std::unordered_set<uint64_t> &planSet) {
|
||||
|
@ -316,8 +322,7 @@ void SearchEngine::searchMergeDfs(std::shared_ptr<MetaGraph> &metaGraph,
|
|||
}
|
||||
|
||||
// Search mutation for each compute op.
|
||||
std::vector<Graph> SearchEngine::searchMutation(
|
||||
const std::shared_ptr<SearchEngine::MetaGraph> &metaGraph) {
|
||||
std::vector<Graph> SearchEngine::searchMutation(const MetaGraph &metaGraph) {
|
||||
std::vector<Graph> graphs = {nullptr};
|
||||
// Append a node to all existing candidates
|
||||
for (auto &node : metaGraph->nodes) {
|
||||
|
@ -432,4 +437,8 @@ std::vector<Graph> SearchEngine::partitionGraph(const Graph graph) {
|
|||
return partitions;
|
||||
}
|
||||
|
||||
double SearchEngine::getEstimatedGraphPerf(Graph graph) {
|
||||
return runtimeExec->getPerfTime(graph, false, true);
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
Loading…
Reference in New Issue