Add: RangeOpNode::getFullExpression()

This commit is contained in:
Liyan Zheng 2023-04-17 11:41:13 +08:00
parent b2c53458d9
commit 7a1d271c79
2 changed files with 9 additions and 1 deletions

View File

@ -206,7 +206,8 @@ struct IterationType {
enum { Loop, Sum };
constexpr static int NumIterationType = 2;
};
class RangeOpNode : public OperatorNode {
class RangeOpNode : public OperatorNode,
public std::enable_shared_from_this<RangeOpNode> {
public:
enum { Summand, END_POS };
constexpr static int Loop = IterationType::Loop;
@ -230,6 +231,7 @@ class RangeOpNode : public OperatorNode {
return 0;
};
string toReadable() const override;
string getFullExpression();
const Expr &getSummand() const { return subExprs[Summand]; }
const vector<VarRangePair> &getVarRanges(int _index) const {
return vars[_index];

View File

@ -1,4 +1,5 @@
#include "nnet/expr.h"
#include "nnet/Visitor/FullPrinterVisitor.h"
#include "nnet/Visitor/GetTensorsVisitor.h"
namespace nnet {
@ -463,4 +464,9 @@ void FuncNode::setObject(Expr e) {
object = e;
}
string RangeOpNode::getFullExpression() {
FullPrinterVisitor printer;
return printer.print(this->shared_from_this());
}
} // namespace nnet