feat: 支持代入值到变量

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-09-12 13:14:05 +08:00
parent fd60190b1a
commit fd2d6b1926
2 changed files with 7 additions and 9 deletions

@ -1 +1 @@
Subproject commit de510b50216f7132d7d4149b259417208c0bbb68
Subproject commit fa7b32b18f25908fa61017583a3ad772b2ca549c

View File

@ -16,14 +16,11 @@ class Handler {
public:
explicit Handler(Graph &&g) : _g(std::forward<Graph>(g)) {}
std::vector<std::string> fillEdgeInfo() {
std::vector<std::string> ans;
auto variables = _g.fillEdgeInfo();
std::transform(variables.begin(), variables.end(),
std::back_inserter(ans),
[](auto &&v) { return std::move(v); });
return ans;
void substitute(const char *name, int64_t value) {
ASSERT(_g.substitute(name, value),
fmt::format("Variable {} not exist", name));
}
std::unordered_set<std::string> fillEdgeInfo() { return _g.fillEdgeInfo(); }
};
Edge edge(int dataType, std::vector<DimExpr> shape,
@ -76,10 +73,11 @@ void register_refactor(py::module &m) {
py::class_<DimExpr>(m, "DimExpr")
.def(py::init<int64_t>())
.def(py::init<std::string &&>());
.def(py::init<std::string>());
py::class_<Operator, Node>(m, "Node");
py::class_<Tensor, Edge>(m, "Edge");
py::class_<Handler, std::shared_ptr<Handler>>(m, "Graph")
.def("substitute", &Handler::substitute)
.def("fill_edge_info", &Handler::fillEdgeInfo);
m.def("refactor_tensor", edge)
.def("refactor_operator", node)