forked from jiuyuan/InfiniTensor
fix: 前端需要载入所有连入图拓扑的边
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
eda206e2d0
commit
229993f081
|
@ -1 +1 @@
|
|||
Subproject commit 1da9f030692d6ac43a82a73778a14557793cd585
|
||||
Subproject commit c361f800793400a4bb6c425ebfa1ffa834e467be
|
|
@ -16,7 +16,7 @@ class Handler {
|
|||
Graph _g;
|
||||
|
||||
public:
|
||||
explicit Handler(Graph &&g) : _g(std::forward<Graph>(g)) {}
|
||||
explicit Handler(Graph g) : _g(std::move(g)) {}
|
||||
std::unordered_set<Name> fillEdgeInfo() { return _g.fillEdgeInfo(); }
|
||||
void setInput(size_t index, std::shared_ptr<Tensor> tensor) {
|
||||
ASSERT(_g.setInput(index, std::move(tensor)),
|
||||
|
@ -63,18 +63,26 @@ graph(std::unordered_map<Name, std::pair<std::vector<Name>, std::vector<Name>>>
|
|||
std::vector<Name> inputs, std::vector<Name> outputs) {
|
||||
auto builder = graph_topo::Builder<Name, Node, Name, Edge>{};
|
||||
for (auto &[node, rels] : topology) {
|
||||
auto &[inputs, outputs] = rels;
|
||||
for (auto const &input : inputs) {
|
||||
builder.edges.insert({input, {nullptr, input}});
|
||||
}
|
||||
for (auto const &output : outputs) {
|
||||
builder.edges.insert({output, {nullptr, output}});
|
||||
}
|
||||
builder.topology.insert(
|
||||
{std::move(node), {std::move(rels.first), std::move(rels.second)}});
|
||||
{std::move(node), {std::move(inputs), std::move(outputs)}});
|
||||
}
|
||||
builder.nodes.reserve(nodes.size());
|
||||
builder.edges.reserve(edges.size());
|
||||
for (auto &[name, operator_] : nodes) {
|
||||
auto node = Node{std::move(operator_), name};
|
||||
builder.nodes.insert({std::move(name), std::move(node)});
|
||||
}
|
||||
for (auto &[name, tensor] : edges) {
|
||||
auto edge = Edge{std::move(tensor), name};
|
||||
builder.edges.insert({std::move(name), std::move(edge)});
|
||||
auto it = builder.edges.find(name);
|
||||
ASSERT(it != builder.edges.end(), "Edge not connected");
|
||||
it->second.tensor = std::move(edge.tensor);
|
||||
}
|
||||
builder.globalInputs = std::move(inputs);
|
||||
builder.globalOutputs = std::move(outputs);
|
||||
|
|
Loading…
Reference in New Issue