feat: 支持从前端取得折叠的输出

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-09-20 09:49:17 +08:00
parent 8086f3d34b
commit 166e0c0a08
3 changed files with 15 additions and 5 deletions

@ -1 +1 @@
Subproject commit 520f88d812136092735acc410dd399f5141b5be4
Subproject commit 96750a2a92e1e66a97a3d78374e612b0fcebfcd7

View File

@ -394,9 +394,12 @@ GraphObj::transformFromGraphTopo(refactor::computation::Graph &graph,
std::vector<Tensor> ans;
for (auto edgeIdx : it.globalOutputs()) {
auto t = edgeToTensor.at(edgeIdx);
t->setOutput();
ans.push_back(std::move(t));
if (auto it = edgeToTensor.find(edgeIdx); it != edgeToTensor.end()) {
it->second->setOutput();
ans.push_back(it->second);
} else {
ans.push_back(nullptr);
}
}
dataMalloc();

View File

@ -55,7 +55,14 @@ class Handler {
#endif
}
template <class T> std::vector<T> copyout(size_t i) {
return _outputs.at(i)->copyout<T>();
if (auto ptr = _outputs.at(i); ptr) {
return ptr->copyout<T>();
} else {
i += _g.internal().topology.globalOutputIndex();
auto tensor = _g.internal().edges[i].tensor;
auto ptr_ = reinterpret_cast<T *>(tensor->data->ptr);
return std::vector<T>(ptr_, ptr_ + tensor->elementsSize());
}
}
};