forked from jiuyuan/InfiniTensor
feat: GraphObj 增加一个拓扑排序方法及其测试
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
f20e791cf5
commit
45a3cdfa30
|
@ -11,7 +11,7 @@ class GraphObj : public Object {
|
||||||
OpVec ops;
|
OpVec ops;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit GraphObj(Runtime runtime) : runtime(runtime){};
|
explicit GraphObj(Runtime runtime) : runtime(runtime), sorted(false){};
|
||||||
GraphObj(Runtime runtime, OpVec ops_in);
|
GraphObj(Runtime runtime, OpVec ops_in);
|
||||||
string toString() const override;
|
string toString() const override;
|
||||||
Runtime getRuntime() const { return runtime; }
|
Runtime getRuntime() const { return runtime; }
|
||||||
|
@ -27,6 +27,14 @@ class GraphObj : public Object {
|
||||||
const OpVec &getOperators() const { return ops; }
|
const OpVec &getOperators() const { return ops; }
|
||||||
OpVec getComputeOps() const;
|
OpVec getComputeOps() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sort the nodes in topological order.
|
||||||
|
* It returns true if the sorting is successful.
|
||||||
|
* Otherwise false is returned, means that there are rings in the graph,
|
||||||
|
* so the topological sorting fails.
|
||||||
|
*/
|
||||||
|
bool topo_sort();
|
||||||
|
|
||||||
void dataMalloc();
|
void dataMalloc();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -76,6 +84,11 @@ class GraphObj : public Object {
|
||||||
* @brief Add reverse connections and Op relationship in ctor.
|
* @brief Add reverse connections and Op relationship in ctor.
|
||||||
*/
|
*/
|
||||||
void addOperatorAndConnect(const Operator &op);
|
void addOperatorAndConnect(const Operator &op);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief If the nodes is sorted in topological order.
|
||||||
|
*/
|
||||||
|
bool sorted;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
#include "core/graph.h"
|
#include "core/graph.h"
|
||||||
|
#include <algorithm>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
GraphObj::GraphObj(Runtime runtime, OpVec ops_in) : runtime(runtime) {
|
GraphObj::GraphObj(Runtime runtime, OpVec ops_in)
|
||||||
|
: runtime(runtime), sorted(false) {
|
||||||
map<UidBaseType, Tensor> tensorPool;
|
map<UidBaseType, Tensor> tensorPool;
|
||||||
// Clone tensors
|
// Clone tensors
|
||||||
for (const auto &op : ops_in) {
|
for (const auto &op : ops_in) {
|
||||||
|
@ -28,6 +30,7 @@ GraphObj::GraphObj(Runtime runtime, OpVec ops_in) : runtime(runtime) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void GraphObj::addOperatorAndConnect(const Operator &op) {
|
void GraphObj::addOperatorAndConnect(const Operator &op) {
|
||||||
|
sorted = false;
|
||||||
ops.push_back(op);
|
ops.push_back(op);
|
||||||
for (auto &input : op->getInputs()) {
|
for (auto &input : op->getInputs()) {
|
||||||
input->addInputOf(op);
|
input->addInputOf(op);
|
||||||
|
@ -66,6 +69,53 @@ string GraphObj::toString() const {
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool GraphObj::topo_sort() {
|
||||||
|
if (this->sorted)
|
||||||
|
return true;
|
||||||
|
|
||||||
|
// std::unordered_set<Tensor> inputs;
|
||||||
|
std::unordered_set<Operator> waiting(this->ops.begin(), this->ops.end());
|
||||||
|
std::vector<Operator> sorted;
|
||||||
|
|
||||||
|
while (!waiting.empty()) {
|
||||||
|
// Any node is move to sorted in this loop.
|
||||||
|
auto modified = false;
|
||||||
|
// Find head nodes.
|
||||||
|
for (auto it = waiting.begin(); it != waiting.end();) {
|
||||||
|
const auto &this_inputs = (*it)->getInputs();
|
||||||
|
// If none of the input tensors is in waiting list,
|
||||||
|
// this node is a head node.
|
||||||
|
const auto is_head = std::all_of(
|
||||||
|
this_inputs.begin(), this_inputs.end(), [&](const auto &input) {
|
||||||
|
auto src = input->getOutputOf();
|
||||||
|
return src // If the source node is in the waiting list,
|
||||||
|
// means that this node is not the head node.
|
||||||
|
? waiting.find(src) == waiting.end()
|
||||||
|
// This tensor has no source node,
|
||||||
|
// it must be a input tensor.
|
||||||
|
: (/*inputs.insert(input),*/ true);
|
||||||
|
});
|
||||||
|
// Moves head node to sorted.
|
||||||
|
if (is_head) {
|
||||||
|
modified = true;
|
||||||
|
sorted.emplace_back(std::move(*it));
|
||||||
|
it = waiting.erase(it);
|
||||||
|
} else {
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Waiting list never modifies during a pass,
|
||||||
|
// sorting fails.
|
||||||
|
if (!modified) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Done.
|
||||||
|
this->ops = std::move(sorted);
|
||||||
|
return this->sorted = true;
|
||||||
|
}
|
||||||
|
|
||||||
void GraphObj::dataMalloc() {
|
void GraphObj::dataMalloc() {
|
||||||
for (auto &tensor : tensors) {
|
for (auto &tensor : tensors) {
|
||||||
tensor->dataMalloc();
|
tensor->dataMalloc();
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
#include "core/blob.h"
|
#include "core/blob.h"
|
||||||
#include "core/graph.h"
|
#include "core/graph.h"
|
||||||
#include "core/runtime.h"
|
#include "core/runtime.h"
|
||||||
|
#include "operators/element_wise.h"
|
||||||
#include "operators/matmul.h"
|
#include "operators/matmul.h"
|
||||||
#include "operators/unary.h"
|
#include "operators/unary.h"
|
||||||
#include "test.h"
|
#include "test.h"
|
||||||
|
@ -36,6 +37,45 @@ TEST(Graph, build_and_run) {
|
||||||
EXPECT_TRUE(o0->equalData(ans));
|
EXPECT_TRUE(o0->equalData(ans));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Graph, topological) {
|
||||||
|
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||||
|
Graph g = make_ref<GraphObj>(runtime);
|
||||||
|
Tensor a = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||||
|
Tensor b = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||||
|
Tensor ab = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||||
|
Tensor c = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||||
|
Tensor abc = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||||
|
Tensor d = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||||
|
Tensor abcd = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||||
|
Tensor e = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||||
|
Tensor abcde = g->addTensor({1, 2, 3}, DataType::UInt32);
|
||||||
|
|
||||||
|
auto ops = std::vector{
|
||||||
|
g->addOpWithOutputs<AddObj>(abcd, e, abcde),
|
||||||
|
g->addOpWithOutputs<AddObj>(abc, d, abcd),
|
||||||
|
g->addOpWithOutputs<AddObj>(ab, c, abc),
|
||||||
|
g->addOpWithOutputs<AddObj>(a, b, ab),
|
||||||
|
};
|
||||||
|
|
||||||
|
{
|
||||||
|
auto p = ops.begin();
|
||||||
|
auto q = g->getOperators().begin();
|
||||||
|
while (p != ops.end()) {
|
||||||
|
EXPECT_EQ(*p++, *q++);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_TRUE(g->topo_sort());
|
||||||
|
|
||||||
|
{
|
||||||
|
auto p = ops.rbegin();
|
||||||
|
auto q = g->getOperators().begin();
|
||||||
|
while (p != ops.rend()) {
|
||||||
|
EXPECT_EQ(*p++, *q++);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace infini
|
||||||
|
|
||||||
TEST(Graph, perf_engine) {
|
TEST(Graph, perf_engine) {
|
||||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||||
Graph g = make_ref<GraphObj>(runtime);
|
Graph g = make_ref<GraphObj>(runtime);
|
||||||
|
|
Loading…
Reference in New Issue