feat: 支持值传入图

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-09-14 14:15:58 +08:00
parent 6830f02c88
commit 20a0c07992
2 changed files with 11 additions and 4 deletions

@ -1 +1 @@
Subproject commit c2aae62103d6c1ec3f1069be8b2ee6387928fc0e
Subproject commit 2790dc5ce195494870cb99ff04648663aac5a256

View File

@ -17,12 +17,17 @@ class Handler {
public:
explicit Handler(Graph &&g) : _g(std::forward<Graph>(g)) {}
std::unordered_set<Name> fillEdgeInfo() { return _g.fillEdgeInfo(); }
void setInput(size_t index, std::shared_ptr<Tensor> tensor) {
ASSERT(_g.setInput(index, std::move(tensor)),
fmt::format("set input {} failed", index));
}
void substitute(const char *name, int64_t value) {
ASSERT(_g.substitute(name, value),
fmt::format("Variable {} not exist", name));
}
std::unordered_set<Name> fillEdgeInfo() { return _g.fillEdgeInfo(); }
void runCuda() { TODO("Not implemented"); }
};
std::shared_ptr<Tensor> edge(int dataType, std::vector<DimExpr> shape,
@ -85,8 +90,10 @@ void register_refactor(py::module &m) {
py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator");
py::class_<Tensor, std::shared_ptr<Tensor>>(m, "Tensor");
py::class_<Handler, std::shared_ptr<Handler>>(m, "Graph")
.def("fill_edge_info", &Handler::fillEdgeInfo)
.def("substitute", &Handler::substitute)
.def("fill_edge_info", &Handler::fillEdgeInfo);
.def("set_input", &Handler::setInput)
.def("run_cuda", &Handler::runCuda);
m.def("refactor_tensor", edge)
.def("refactor_operator", node)
.def("refactor_graph", graph);