feat: 支持 runCpu

fix: 正确实现获取折叠的输出
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-09-20 11:02:04 +08:00
parent 166e0c0a08
commit 054e8b4990
2 changed files with 19 additions and 6 deletions

@ -1 +1 @@
Subproject commit 96750a2a92e1e66a97a3d78374e612b0fcebfcd7
Subproject commit 2c272d4cea8124b11aebe62ce6cc0c0fd6b3090b

View File

@ -45,6 +45,13 @@ class Handler {
}
auto const &graph() const { return _g.internal(); }
void runCpu() {
using namespace infini;
auto rt = make_ref<NativeCpuRuntimeObj>();
_lastBackend = make_ref<GraphObj>(rt);
_outputs = _lastBackend->transformFromGraphTopo(_g, rt);
// _lastBackend->getRuntime()->run(_lastBackend);
}
void runCuda() {
using namespace infini;
#ifdef USE_CUDA
@ -54,15 +61,20 @@ class Handler {
_lastBackend->getRuntime()->run(_lastBackend);
#endif
}
template <class T> std::vector<T> copyout(size_t i) {
if (auto ptr = _outputs.at(i); ptr) {
return ptr->copyout<T>();
} else {
i += _g.internal().topology.globalOutputIndex();
auto tensor = _g.internal().edges[i].tensor;
auto ptr_ = reinterpret_cast<T *>(tensor->data->ptr);
return std::vector<T>(ptr_, ptr_ + tensor->elementsSize());
}
auto outputs = _g.internal().topology.globalOutputs();
ASSERT(i < outputs.size(), "Index out of range");
auto const &tensor = *_g.internal().edges[outputs[i]].tensor;
ASSERT(common::dataType<T>() == tensor.dataType, "Data type mismatch");
auto ptr = reinterpret_cast<T *>(tensor.data->ptr);
return std::vector<T>(ptr, ptr + tensor.elementsSize());
}
};
@ -232,6 +244,7 @@ void register_refactor(py::module &m) {
py::class_<Handler, std::shared_ptr<Handler>>(m, "Graph")
.def("fill_edge_info", &Handler::fillEdgeInfo)
.def("substitute", &Handler::substitute)
.def("run_cpu", &Handler::runCpu)
.def("run_cuda", &Handler::runCuda)
.def("copy_out_float", &Handler::copyout<float>)
.def("set_input", &Handler::setInput);