diff --git a/include/core/graph.h b/include/core/graph.h index eaa6f4a5..8e317a8b 100644 --- a/include/core/graph.h +++ b/include/core/graph.h @@ -11,7 +11,7 @@ class GraphObj : public Object { OpVec ops; public: - explicit GraphObj(Runtime runtime) : runtime(runtime){}; + explicit GraphObj(Runtime runtime) : runtime(runtime), sorted(false){}; GraphObj(Runtime runtime, OpVec ops_in); string toString() const override; Runtime getRuntime() const { return runtime; } @@ -27,6 +27,14 @@ class GraphObj : public Object { const OpVec &getOperators() const { return ops; } OpVec getComputeOps() const; + /** + * Sort the nodes in topological order. + * It returns true if the sorting is successful. + * Otherwise false is returned, means that there are rings in the graph, + * so the topological sorting fails. + */ + bool topo_sort(); + void dataMalloc(); /** @@ -76,6 +84,11 @@ class GraphObj : public Object { * @brief Add reverse connections and Op relationship in ctor. */ void addOperatorAndConnect(const Operator &op); + + /** + * @brief If the nodes is sorted in topological order. + */ + bool sorted; }; } // namespace infini diff --git a/src/core/graph.cc b/src/core/graph.cc index 8b2a6bbc..04ce2581 100644 --- a/src/core/graph.cc +++ b/src/core/graph.cc @@ -1,9 +1,11 @@ #include "core/graph.h" +#include #include namespace infini { -GraphObj::GraphObj(Runtime runtime, OpVec ops_in) : runtime(runtime) { +GraphObj::GraphObj(Runtime runtime, OpVec ops_in) + : runtime(runtime), sorted(false) { map tensorPool; // Clone tensors for (const auto &op : ops_in) { @@ -28,6 +30,7 @@ GraphObj::GraphObj(Runtime runtime, OpVec ops_in) : runtime(runtime) { } void GraphObj::addOperatorAndConnect(const Operator &op) { + sorted = false; ops.push_back(op); for (auto &input : op->getInputs()) { input->addInputOf(op); @@ -66,6 +69,53 @@ string GraphObj::toString() const { return oss.str(); } +bool GraphObj::topo_sort() { + if (this->sorted) + return true; + + // std::unordered_set inputs; + std::unordered_set waiting(this->ops.begin(), this->ops.end()); + std::vector sorted; + + while (!waiting.empty()) { + // Any node is move to sorted in this loop. + auto modified = false; + // Find head nodes. + for (auto it = waiting.begin(); it != waiting.end();) { + const auto &this_inputs = (*it)->getInputs(); + // If none of the input tensors is in waiting list, + // this node is a head node. + const auto is_head = std::all_of( + this_inputs.begin(), this_inputs.end(), [&](const auto &input) { + auto src = input->getOutputOf(); + return src // If the source node is in the waiting list, + // means that this node is not the head node. + ? waiting.find(src) == waiting.end() + // This tensor has no source node, + // it must be a input tensor. + : (/*inputs.insert(input),*/ true); + }); + // Moves head node to sorted. + if (is_head) { + modified = true; + sorted.emplace_back(std::move(*it)); + it = waiting.erase(it); + } else { + ++it; + } + } + // Waiting list never modifies during a pass, + // sorting fails. + if (!modified) { + return false; + } + } + + // Done. + this->ops = std::move(sorted); + return this->sorted = true; +} + void GraphObj::dataMalloc() { for (auto &tensor : tensors) { tensor->dataMalloc(); diff --git a/test/core/test_graph.cc b/test/core/test_graph.cc index d208c21f..65bcf68a 100644 --- a/test/core/test_graph.cc +++ b/test/core/test_graph.cc @@ -1,6 +1,7 @@ #include "core/blob.h" #include "core/graph.h" #include "core/runtime.h" +#include "operators/element_wise.h" #include "operators/matmul.h" #include "operators/unary.h" #include "test.h" @@ -36,6 +37,45 @@ TEST(Graph, build_and_run) { EXPECT_TRUE(o0->equalData(ans)); } +TEST(Graph, topological) { + Runtime runtime = CpuRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + Tensor a = g->addTensor({1, 2, 3}, DataType::UInt32); + Tensor b = g->addTensor({1, 2, 3}, DataType::UInt32); + Tensor ab = g->addTensor({1, 2, 3}, DataType::UInt32); + Tensor c = g->addTensor({1, 2, 3}, DataType::UInt32); + Tensor abc = g->addTensor({1, 2, 3}, DataType::UInt32); + Tensor d = g->addTensor({1, 2, 3}, DataType::UInt32); + Tensor abcd = g->addTensor({1, 2, 3}, DataType::UInt32); + Tensor e = g->addTensor({1, 2, 3}, DataType::UInt32); + Tensor abcde = g->addTensor({1, 2, 3}, DataType::UInt32); + + auto ops = std::vector{ + g->addOpWithOutputs(abcd, e, abcde), + g->addOpWithOutputs(abc, d, abcd), + g->addOpWithOutputs(ab, c, abc), + g->addOpWithOutputs(a, b, ab), + }; + + { + auto p = ops.begin(); + auto q = g->getOperators().begin(); + while (p != ops.end()) { + EXPECT_EQ(*p++, *q++); + } + } + + EXPECT_TRUE(g->topo_sort()); + + { + auto p = ops.rbegin(); + auto q = g->getOperators().begin(); + while (p != ops.rend()) { + EXPECT_EQ(*p++, *q++); + } + } +} // namespace infini + TEST(Graph, perf_engine) { Runtime runtime = CpuRuntimeObj::getInstance(); Graph g = make_ref(runtime);