forked from jiuyuan/InfiniTensor
parent
04c2725312
commit
8e4e392a49
|
@ -52,6 +52,9 @@ class Handler {
|
|||
graph->getRuntime()->run(graph);
|
||||
#endif
|
||||
}
|
||||
template <class T> std::vector<T> copyout(size_t i) {
|
||||
return _outputs[i]->copyout<T>();
|
||||
}
|
||||
};
|
||||
|
||||
using TExport = std::tuple<Name, int, std::vector<std::variant<Name, int>>>;
|
||||
|
@ -220,6 +223,7 @@ void register_refactor(py::module &m) {
|
|||
.def("fill_edge_info", &Handler::fillEdgeInfo)
|
||||
.def("substitute", &Handler::substitute)
|
||||
.def("run_cuda", &Handler::runCuda)
|
||||
.def("copy_out_float", &Handler::copyout<float>)
|
||||
.def("set_input", &Handler::setInput);
|
||||
py::class_<NodeExport>(m, "NodeExport")
|
||||
.def(py::init<std::shared_ptr<Handler>>())
|
||||
|
|
Loading…
Reference in New Issue