From 9d9e996713e0fc9ff052b36168bbc5df36583b3e Mon Sep 17 00:00:00 2001 From: kilinchange Date: Fri, 15 Dec 2023 10:13:18 +0800 Subject: [PATCH] fix(graph.cc): fix topo_sort --- src/core/graph.cc | 47 ++++++++++++++++------------------------------- 1 file changed, 16 insertions(+), 31 deletions(-) diff --git a/src/core/graph.cc b/src/core/graph.cc index 5eb67402..19831d65 100644 --- a/src/core/graph.cc +++ b/src/core/graph.cc @@ -87,48 +87,33 @@ string GraphObj::toString() const { } bool GraphObj::topo_sort() { - if (this->sorted) + if (this->sorted) { return true; - - // std::unordered_set inputs; - std::unordered_set waiting(this->ops.begin(), this->ops.end()); + } std::vector sorted; - - while (!waiting.empty()) { + std::unordered_set flags; + sorted.reserve(ops.size()); + flags.reserve(ops.size()); + while (sorted.size() < ops.size()) { // Any node is move to sorted in this loop. auto modified = false; - // Find head nodes. - for (auto it = waiting.begin(); it != waiting.end();) { - const auto &this_inputs = (*it)->getInputs(); - // If none of the input tensors is in waiting list, - // this node is a head node. - const auto is_head = std::all_of( - this_inputs.begin(), this_inputs.end(), [&](const auto &input) { - auto src = input->getSource(); - return src // If the source node is in the waiting - // list, means that this node is not the - // head node. - ? waiting.find(src) == waiting.end() - // This tensor has no source node, - // it must be a input tensor. - : (/*inputs.insert(input),*/ true); - }); - // Moves head node to sorted. - if (is_head) { + for (auto const &op : ops) { + if (auto const &inputs = op->getInputs(); + flags.find(op.get()) == flags.end() && + std::all_of(inputs.begin(), inputs.end(), + [&flags](auto const &input) { + auto ptr = input->getSource().get(); + return !ptr || flags.find(ptr) != flags.end(); + })) { modified = true; - sorted.emplace_back(std::move(*it)); - it = waiting.erase(it); - } else { - ++it; + sorted.emplace_back(op); + flags.insert(op.get()); } } - // Waiting list never modifies during a pass, - // sorting fails. if (!modified) { return false; } } - // Done. this->ops = std::move(sorted); return this->sorted = true; }