forked from jiuyuan/InfiniTensor
feat: 支持从前端取得折叠的输出
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
8086f3d34b
commit
166e0c0a08
|
@ -1 +1 @@
|
|||
Subproject commit 520f88d812136092735acc410dd399f5141b5be4
|
||||
Subproject commit 96750a2a92e1e66a97a3d78374e612b0fcebfcd7
|
|
@ -394,9 +394,12 @@ GraphObj::transformFromGraphTopo(refactor::computation::Graph &graph,
|
|||
|
||||
std::vector<Tensor> ans;
|
||||
for (auto edgeIdx : it.globalOutputs()) {
|
||||
auto t = edgeToTensor.at(edgeIdx);
|
||||
t->setOutput();
|
||||
ans.push_back(std::move(t));
|
||||
if (auto it = edgeToTensor.find(edgeIdx); it != edgeToTensor.end()) {
|
||||
it->second->setOutput();
|
||||
ans.push_back(it->second);
|
||||
} else {
|
||||
ans.push_back(nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
dataMalloc();
|
||||
|
|
|
@ -55,7 +55,14 @@ class Handler {
|
|||
#endif
|
||||
}
|
||||
template <class T> std::vector<T> copyout(size_t i) {
|
||||
return _outputs.at(i)->copyout<T>();
|
||||
if (auto ptr = _outputs.at(i); ptr) {
|
||||
return ptr->copyout<T>();
|
||||
} else {
|
||||
i += _g.internal().topology.globalOutputIndex();
|
||||
auto tensor = _g.internal().edges[i].tensor;
|
||||
auto ptr_ = reinterpret_cast<T *>(tensor->data->ptr);
|
||||
return std::vector<T>(ptr_, ptr_ + tensor->elementsSize());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in New Issue