forked from jiuyuan/InfiniTensor
feat: 支持值传入图
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
6830f02c88
commit
20a0c07992
|
@ -1 +1 @@
|
|||
Subproject commit c2aae62103d6c1ec3f1069be8b2ee6387928fc0e
|
||||
Subproject commit 2790dc5ce195494870cb99ff04648663aac5a256
|
|
@ -17,12 +17,17 @@ class Handler {
|
|||
|
||||
public:
|
||||
explicit Handler(Graph &&g) : _g(std::forward<Graph>(g)) {}
|
||||
|
||||
std::unordered_set<Name> fillEdgeInfo() { return _g.fillEdgeInfo(); }
|
||||
void setInput(size_t index, std::shared_ptr<Tensor> tensor) {
|
||||
ASSERT(_g.setInput(index, std::move(tensor)),
|
||||
fmt::format("set input {} failed", index));
|
||||
}
|
||||
void substitute(const char *name, int64_t value) {
|
||||
ASSERT(_g.substitute(name, value),
|
||||
fmt::format("Variable {} not exist", name));
|
||||
}
|
||||
std::unordered_set<Name> fillEdgeInfo() { return _g.fillEdgeInfo(); }
|
||||
|
||||
void runCuda() { TODO("Not implemented"); }
|
||||
};
|
||||
|
||||
std::shared_ptr<Tensor> edge(int dataType, std::vector<DimExpr> shape,
|
||||
|
@ -85,8 +90,10 @@ void register_refactor(py::module &m) {
|
|||
py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator");
|
||||
py::class_<Tensor, std::shared_ptr<Tensor>>(m, "Tensor");
|
||||
py::class_<Handler, std::shared_ptr<Handler>>(m, "Graph")
|
||||
.def("fill_edge_info", &Handler::fillEdgeInfo)
|
||||
.def("substitute", &Handler::substitute)
|
||||
.def("fill_edge_info", &Handler::fillEdgeInfo);
|
||||
.def("set_input", &Handler::setInput)
|
||||
.def("run_cuda", &Handler::runCuda);
|
||||
m.def("refactor_tensor", edge)
|
||||
.def("refactor_operator", node)
|
||||
.def("refactor_graph", graph);
|
||||
|
|
Loading…
Reference in New Issue