fix(graph.cc): fix topo_sort

This commit is contained in:
kilinchange 2023-12-15 10:13:18 +08:00
parent c859e655d3
commit 9d9e996713
1 changed files with 16 additions and 31 deletions

View File

@ -87,48 +87,33 @@ string GraphObj::toString() const {
}
bool GraphObj::topo_sort() {
if (this->sorted)
if (this->sorted) {
return true;
// std::unordered_set<Tensor> inputs;
std::unordered_set<Operator> waiting(this->ops.begin(), this->ops.end());
}
std::vector<Operator> sorted;
while (!waiting.empty()) {
std::unordered_set<OperatorObj *> flags;
sorted.reserve(ops.size());
flags.reserve(ops.size());
while (sorted.size() < ops.size()) {
// Any node is move to sorted in this loop.
auto modified = false;
// Find head nodes.
for (auto it = waiting.begin(); it != waiting.end();) {
const auto &this_inputs = (*it)->getInputs();
// If none of the input tensors is in waiting list,
// this node is a head node.
const auto is_head = std::all_of(
this_inputs.begin(), this_inputs.end(), [&](const auto &input) {
auto src = input->getSource();
return src // If the source node is in the waiting
// list, means that this node is not the
// head node.
? waiting.find(src) == waiting.end()
// This tensor has no source node,
// it must be a input tensor.
: (/*inputs.insert(input),*/ true);
});
// Moves head node to sorted.
if (is_head) {
for (auto const &op : ops) {
if (auto const &inputs = op->getInputs();
flags.find(op.get()) == flags.end() &&
std::all_of(inputs.begin(), inputs.end(),
[&flags](auto const &input) {
auto ptr = input->getSource().get();
return !ptr || flags.find(ptr) != flags.end();
})) {
modified = true;
sorted.emplace_back(std::move(*it));
it = waiting.erase(it);
} else {
++it;
sorted.emplace_back(op);
flags.insert(op.get());
}
}
// Waiting list never modifies during a pass,
// sorting fails.
if (!modified) {
return false;
}
}
// Done.
this->ops = std::move(sorted);
return this->sorted = true;
}