forked from jiuyuan/InfiniTensor
feat: 支持 runCpu
fix: 正确实现获取折叠的输出 Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
166e0c0a08
commit
054e8b4990
|
@ -1 +1 @@
|
|||
Subproject commit 96750a2a92e1e66a97a3d78374e612b0fcebfcd7
|
||||
Subproject commit 2c272d4cea8124b11aebe62ce6cc0c0fd6b3090b
|
|
@ -45,6 +45,13 @@ class Handler {
|
|||
}
|
||||
auto const &graph() const { return _g.internal(); }
|
||||
|
||||
void runCpu() {
|
||||
using namespace infini;
|
||||
auto rt = make_ref<NativeCpuRuntimeObj>();
|
||||
_lastBackend = make_ref<GraphObj>(rt);
|
||||
_outputs = _lastBackend->transformFromGraphTopo(_g, rt);
|
||||
// _lastBackend->getRuntime()->run(_lastBackend);
|
||||
}
|
||||
void runCuda() {
|
||||
using namespace infini;
|
||||
#ifdef USE_CUDA
|
||||
|
@ -54,15 +61,20 @@ class Handler {
|
|||
_lastBackend->getRuntime()->run(_lastBackend);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class T> std::vector<T> copyout(size_t i) {
|
||||
if (auto ptr = _outputs.at(i); ptr) {
|
||||
return ptr->copyout<T>();
|
||||
} else {
|
||||
i += _g.internal().topology.globalOutputIndex();
|
||||
auto tensor = _g.internal().edges[i].tensor;
|
||||
auto ptr_ = reinterpret_cast<T *>(tensor->data->ptr);
|
||||
return std::vector<T>(ptr_, ptr_ + tensor->elementsSize());
|
||||
}
|
||||
|
||||
auto outputs = _g.internal().topology.globalOutputs();
|
||||
ASSERT(i < outputs.size(), "Index out of range");
|
||||
|
||||
auto const &tensor = *_g.internal().edges[outputs[i]].tensor;
|
||||
ASSERT(common::dataType<T>() == tensor.dataType, "Data type mismatch");
|
||||
|
||||
auto ptr = reinterpret_cast<T *>(tensor.data->ptr);
|
||||
return std::vector<T>(ptr, ptr + tensor.elementsSize());
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -232,6 +244,7 @@ void register_refactor(py::module &m) {
|
|||
py::class_<Handler, std::shared_ptr<Handler>>(m, "Graph")
|
||||
.def("fill_edge_info", &Handler::fillEdgeInfo)
|
||||
.def("substitute", &Handler::substitute)
|
||||
.def("run_cpu", &Handler::runCpu)
|
||||
.def("run_cuda", &Handler::runCuda)
|
||||
.def("copy_out_float", &Handler::copyout<float>)
|
||||
.def("set_input", &Handler::setInput);
|
||||
|
|
Loading…
Reference in New Issue