forked from jiuyuan/InfiniTensor
feat: 支持代入值到变量
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
fd60190b1a
commit
fd2d6b1926
|
@ -1 +1 @@
|
|||
Subproject commit de510b50216f7132d7d4149b259417208c0bbb68
|
||||
Subproject commit fa7b32b18f25908fa61017583a3ad772b2ca549c
|
|
@ -16,14 +16,11 @@ class Handler {
|
|||
public:
|
||||
explicit Handler(Graph &&g) : _g(std::forward<Graph>(g)) {}
|
||||
|
||||
std::vector<std::string> fillEdgeInfo() {
|
||||
std::vector<std::string> ans;
|
||||
auto variables = _g.fillEdgeInfo();
|
||||
std::transform(variables.begin(), variables.end(),
|
||||
std::back_inserter(ans),
|
||||
[](auto &&v) { return std::move(v); });
|
||||
return ans;
|
||||
void substitute(const char *name, int64_t value) {
|
||||
ASSERT(_g.substitute(name, value),
|
||||
fmt::format("Variable {} not exist", name));
|
||||
}
|
||||
std::unordered_set<std::string> fillEdgeInfo() { return _g.fillEdgeInfo(); }
|
||||
};
|
||||
|
||||
Edge edge(int dataType, std::vector<DimExpr> shape,
|
||||
|
@ -76,10 +73,11 @@ void register_refactor(py::module &m) {
|
|||
|
||||
py::class_<DimExpr>(m, "DimExpr")
|
||||
.def(py::init<int64_t>())
|
||||
.def(py::init<std::string &&>());
|
||||
.def(py::init<std::string>());
|
||||
py::class_<Operator, Node>(m, "Node");
|
||||
py::class_<Tensor, Edge>(m, "Edge");
|
||||
py::class_<Handler, std::shared_ptr<Handler>>(m, "Graph")
|
||||
.def("substitute", &Handler::substitute)
|
||||
.def("fill_edge_info", &Handler::fillEdgeInfo);
|
||||
m.def("refactor_tensor", edge)
|
||||
.def("refactor_operator", node)
|
||||
|
|
Loading…
Reference in New Issue