forked from jiuyuan/InfiniTensor
feat: GraphObj 增加一个拓扑排序方法及其测试
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
f20e791cf5
commit
45a3cdfa30
|
@ -11,7 +11,7 @@ class GraphObj : public Object {
|
|||
OpVec ops;
|
||||
|
||||
public:
|
||||
explicit GraphObj(Runtime runtime) : runtime(runtime){};
|
||||
explicit GraphObj(Runtime runtime) : runtime(runtime), sorted(false){};
|
||||
GraphObj(Runtime runtime, OpVec ops_in);
|
||||
string toString() const override;
|
||||
Runtime getRuntime() const { return runtime; }
|
||||
|
@ -27,6 +27,14 @@ class GraphObj : public Object {
|
|||
const OpVec &getOperators() const { return ops; }
|
||||
OpVec getComputeOps() const;
|
||||
|
||||
/**
|
||||
* Sort the nodes in topological order.
|
||||
* It returns true if the sorting is successful.
|
||||
* Otherwise false is returned, means that there are rings in the graph,
|
||||
* so the topological sorting fails.
|
||||
*/
|
||||
bool topo_sort();
|
||||
|
||||
void dataMalloc();
|
||||
|
||||
/**
|
||||
|
@ -76,6 +84,11 @@ class GraphObj : public Object {
|
|||
* @brief Add reverse connections and Op relationship in ctor.
|
||||
*/
|
||||
void addOperatorAndConnect(const Operator &op);
|
||||
|
||||
/**
|
||||
* @brief If the nodes is sorted in topological order.
|
||||
*/
|
||||
bool sorted;
|
||||
};
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
#include "core/graph.h"
|
||||
#include <algorithm>
|
||||
#include <queue>
|
||||
|
||||
namespace infini {
|
||||
|
||||
GraphObj::GraphObj(Runtime runtime, OpVec ops_in) : runtime(runtime) {
|
||||
GraphObj::GraphObj(Runtime runtime, OpVec ops_in)
|
||||
: runtime(runtime), sorted(false) {
|
||||
map<UidBaseType, Tensor> tensorPool;
|
||||
// Clone tensors
|
||||
for (const auto &op : ops_in) {
|
||||
|
@ -28,6 +30,7 @@ GraphObj::GraphObj(Runtime runtime, OpVec ops_in) : runtime(runtime) {
|
|||
}
|
||||
|
||||
void GraphObj::addOperatorAndConnect(const Operator &op) {
|
||||
sorted = false;
|
||||
ops.push_back(op);
|
||||
for (auto &input : op->getInputs()) {
|
||||
input->addInputOf(op);
|
||||
|
@ -66,6 +69,53 @@ string GraphObj::toString() const {
|
|||
return oss.str();
|
||||
}
|
||||
|
||||
bool GraphObj::topo_sort() {
|
||||
if (this->sorted)
|
||||
return true;
|
||||
|
||||
// std::unordered_set<Tensor> inputs;
|
||||
std::unordered_set<Operator> waiting(this->ops.begin(), this->ops.end());
|
||||
std::vector<Operator> sorted;
|
||||
|
||||
while (!waiting.empty()) {
|
||||
// Any node is move to sorted in this loop.
|
||||
auto modified = false;
|
||||
// Find head nodes.
|
||||
for (auto it = waiting.begin(); it != waiting.end();) {
|
||||
const auto &this_inputs = (*it)->getInputs();
|
||||
// If none of the input tensors is in waiting list,
|
||||
// this node is a head node.
|
||||
const auto is_head = std::all_of(
|
||||
this_inputs.begin(), this_inputs.end(), [&](const auto &input) {
|
||||
auto src = input->getOutputOf();
|
||||
return src // If the source node is in the waiting list,
|
||||
// means that this node is not the head node.
|
||||
? waiting.find(src) == waiting.end()
|
||||
// This tensor has no source node,
|
||||
// it must be a input tensor.
|
||||
: (/*inputs.insert(input),*/ true);
|
||||
});
|
||||
// Moves head node to sorted.
|
||||
if (is_head) {
|
||||
modified = true;
|
||||
sorted.emplace_back(std::move(*it));
|
||||
it = waiting.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
// Waiting list never modifies during a pass,
|
||||
// sorting fails.
|
||||
if (!modified) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Done.
|
||||
this->ops = std::move(sorted);
|
||||
return this->sorted = true;
|
||||
}
|
||||
|
||||
void GraphObj::dataMalloc() {
|
||||
for (auto &tensor : tensors) {
|
||||
tensor->dataMalloc();
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#include "core/blob.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/element_wise.h"
|
||||
#include "operators/matmul.h"
|
||||
#include "operators/unary.h"
|
||||
#include "test.h"
|
||||
|
@ -36,6 +37,45 @@ TEST(Graph, build_and_run) {
|
|||
EXPECT_TRUE(o0->equalData(ans));
|
||||
}
|
||||
|
||||
TEST(Graph, topological) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor a = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||
Tensor b = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||
Tensor ab = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||
Tensor c = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||
Tensor abc = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||
Tensor d = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||
Tensor abcd = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||
Tensor e = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||
Tensor abcde = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||
|
||||
auto ops = std::vector{
|
||||
g->addOpWithOutputs<AddObj>(abcd, e, abcde),
|
||||
g->addOpWithOutputs<AddObj>(abc, d, abcd),
|
||||
g->addOpWithOutputs<AddObj>(ab, c, abc),
|
||||
g->addOpWithOutputs<AddObj>(a, b, ab),
|
||||
};
|
||||
|
||||
{
|
||||
auto p = ops.begin();
|
||||
auto q = g->getOperators().begin();
|
||||
while (p != ops.end()) {
|
||||
EXPECT_EQ(*p++, *q++);
|
||||
}
|
||||
}
|
||||
|
||||
EXPECT_TRUE(g->topo_sort());
|
||||
|
||||
{
|
||||
auto p = ops.rbegin();
|
||||
auto q = g->getOperators().begin();
|
||||
while (p != ops.rend()) {
|
||||
EXPECT_EQ(*p++, *q++);
|
||||
}
|
||||
}
|
||||
} // namespace infini
|
||||
|
||||
TEST(Graph, perf_engine) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
|
|
Loading…
Reference in New Issue