feat: GraphObj 增加一个拓扑排序方法及其测试

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-17 13:21:10 +08:00
parent f20e791cf5
commit 45a3cdfa30
3 changed files with 105 additions and 2 deletions

View File

@ -11,7 +11,7 @@ class GraphObj : public Object {
OpVec ops;
public:
explicit GraphObj(Runtime runtime) : runtime(runtime){};
explicit GraphObj(Runtime runtime) : runtime(runtime), sorted(false){};
GraphObj(Runtime runtime, OpVec ops_in);
string toString() const override;
Runtime getRuntime() const { return runtime; }
@ -27,6 +27,14 @@ class GraphObj : public Object {
const OpVec &getOperators() const { return ops; }
OpVec getComputeOps() const;
/**
* Sort the nodes in topological order.
* It returns true if the sorting is successful.
* Otherwise false is returned, means that there are rings in the graph,
* so the topological sorting fails.
*/
bool topo_sort();
void dataMalloc();
/**
@ -76,6 +84,11 @@ class GraphObj : public Object {
* @brief Add reverse connections and Op relationship in ctor.
*/
void addOperatorAndConnect(const Operator &op);
/**
* @brief If the nodes is sorted in topological order.
*/
bool sorted;
};
} // namespace infini

View File

@ -1,9 +1,11 @@
#include "core/graph.h"
#include <algorithm>
#include <queue>
namespace infini {
GraphObj::GraphObj(Runtime runtime, OpVec ops_in) : runtime(runtime) {
GraphObj::GraphObj(Runtime runtime, OpVec ops_in)
: runtime(runtime), sorted(false) {
map<UidBaseType, Tensor> tensorPool;
// Clone tensors
for (const auto &op : ops_in) {
@ -28,6 +30,7 @@ GraphObj::GraphObj(Runtime runtime, OpVec ops_in) : runtime(runtime) {
}
void GraphObj::addOperatorAndConnect(const Operator &op) {
sorted = false;
ops.push_back(op);
for (auto &input : op->getInputs()) {
input->addInputOf(op);
@ -66,6 +69,53 @@ string GraphObj::toString() const {
return oss.str();
}
bool GraphObj::topo_sort() {
if (this->sorted)
return true;
// std::unordered_set<Tensor> inputs;
std::unordered_set<Operator> waiting(this->ops.begin(), this->ops.end());
std::vector<Operator> sorted;
while (!waiting.empty()) {
// Any node is move to sorted in this loop.
auto modified = false;
// Find head nodes.
for (auto it = waiting.begin(); it != waiting.end();) {
const auto &this_inputs = (*it)->getInputs();
// If none of the input tensors is in waiting list,
// this node is a head node.
const auto is_head = std::all_of(
this_inputs.begin(), this_inputs.end(), [&](const auto &input) {
auto src = input->getOutputOf();
return src // If the source node is in the waiting list,
// means that this node is not the head node.
? waiting.find(src) == waiting.end()
// This tensor has no source node,
// it must be a input tensor.
: (/*inputs.insert(input),*/ true);
});
// Moves head node to sorted.
if (is_head) {
modified = true;
sorted.emplace_back(std::move(*it));
it = waiting.erase(it);
} else {
++it;
}
}
// Waiting list never modifies during a pass,
// sorting fails.
if (!modified) {
return false;
}
}
// Done.
this->ops = std::move(sorted);
return this->sorted = true;
}
void GraphObj::dataMalloc() {
for (auto &tensor : tensors) {
tensor->dataMalloc();

View File

@ -1,6 +1,7 @@
#include "core/blob.h"
#include "core/graph.h"
#include "core/runtime.h"
#include "operators/element_wise.h"
#include "operators/matmul.h"
#include "operators/unary.h"
#include "test.h"
@ -36,6 +37,45 @@ TEST(Graph, build_and_run) {
EXPECT_TRUE(o0->equalData(ans));
}
TEST(Graph, topological) {
Runtime runtime = CpuRuntimeObj::getInstance();
Graph g = make_ref<GraphObj>(runtime);
Tensor a = g->addTensor({1, 2, 3}, DataType::UInt32);
Tensor b = g->addTensor({1, 2, 3}, DataType::UInt32);
Tensor ab = g->addTensor({1, 2, 3}, DataType::UInt32);
Tensor c = g->addTensor({1, 2, 3}, DataType::UInt32);
Tensor abc = g->addTensor({1, 2, 3}, DataType::UInt32);
Tensor d = g->addTensor({1, 2, 3}, DataType::UInt32);
Tensor abcd = g->addTensor({1, 2, 3}, DataType::UInt32);
Tensor e = g->addTensor({1, 2, 3}, DataType::UInt32);
Tensor abcde = g->addTensor({1, 2, 3}, DataType::UInt32);
auto ops = std::vector{
g->addOpWithOutputs<AddObj>(abcd, e, abcde),
g->addOpWithOutputs<AddObj>(abc, d, abcd),
g->addOpWithOutputs<AddObj>(ab, c, abc),
g->addOpWithOutputs<AddObj>(a, b, ab),
};
{
auto p = ops.begin();
auto q = g->getOperators().begin();
while (p != ops.end()) {
EXPECT_EQ(*p++, *q++);
}
}
EXPECT_TRUE(g->topo_sort());
{
auto p = ops.rbegin();
auto q = g->getOperators().begin();
while (p != ops.rend()) {
EXPECT_EQ(*p++, *q++);
}
}
} // namespace infini
TEST(Graph, perf_engine) {
Runtime runtime = CpuRuntimeObj::getInstance();
Graph g = make_ref<GraphObj>(runtime);