forked from jiuyuan/InfiniTensor
fix(graph.cc): fix topo_sort
This commit is contained in:
parent
c859e655d3
commit
9d9e996713
|
@ -87,48 +87,33 @@ string GraphObj::toString() const {
|
|||
}
|
||||
|
||||
bool GraphObj::topo_sort() {
|
||||
if (this->sorted)
|
||||
if (this->sorted) {
|
||||
return true;
|
||||
|
||||
// std::unordered_set<Tensor> inputs;
|
||||
std::unordered_set<Operator> waiting(this->ops.begin(), this->ops.end());
|
||||
}
|
||||
std::vector<Operator> sorted;
|
||||
|
||||
while (!waiting.empty()) {
|
||||
std::unordered_set<OperatorObj *> flags;
|
||||
sorted.reserve(ops.size());
|
||||
flags.reserve(ops.size());
|
||||
while (sorted.size() < ops.size()) {
|
||||
// Any node is move to sorted in this loop.
|
||||
auto modified = false;
|
||||
// Find head nodes.
|
||||
for (auto it = waiting.begin(); it != waiting.end();) {
|
||||
const auto &this_inputs = (*it)->getInputs();
|
||||
// If none of the input tensors is in waiting list,
|
||||
// this node is a head node.
|
||||
const auto is_head = std::all_of(
|
||||
this_inputs.begin(), this_inputs.end(), [&](const auto &input) {
|
||||
auto src = input->getSource();
|
||||
return src // If the source node is in the waiting
|
||||
// list, means that this node is not the
|
||||
// head node.
|
||||
? waiting.find(src) == waiting.end()
|
||||
// This tensor has no source node,
|
||||
// it must be a input tensor.
|
||||
: (/*inputs.insert(input),*/ true);
|
||||
});
|
||||
// Moves head node to sorted.
|
||||
if (is_head) {
|
||||
for (auto const &op : ops) {
|
||||
if (auto const &inputs = op->getInputs();
|
||||
flags.find(op.get()) == flags.end() &&
|
||||
std::all_of(inputs.begin(), inputs.end(),
|
||||
[&flags](auto const &input) {
|
||||
auto ptr = input->getSource().get();
|
||||
return !ptr || flags.find(ptr) != flags.end();
|
||||
})) {
|
||||
modified = true;
|
||||
sorted.emplace_back(std::move(*it));
|
||||
it = waiting.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
sorted.emplace_back(op);
|
||||
flags.insert(op.get());
|
||||
}
|
||||
}
|
||||
// Waiting list never modifies during a pass,
|
||||
// sorting fails.
|
||||
if (!modified) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// Done.
|
||||
this->ops = std::move(sorted);
|
||||
return this->sorted = true;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue