#pragma once #include "common.h" #include "derivator.h" #include "expr.h" #include "routine.h" #include #include namespace nnet { template class Functor; template class Functor { protected: int verbose; // FIXME: scope should be protected public: Functor(int _verobse = 0) : verbose(_verobse) {} virtual ~Functor() = default; #define DISPATCH(CLASS) \ case NodeType::CLASS##Type: \ return this->visit_(as(c), std::forward(args)...); \ break #define FUNCTOR_DEFAULT \ { return visitDefault(c, std::forward(args)...); } virtual R dispatch(const Expr &c, Args... args) { switch (c->getType()) { DISPATCH(ConstantNode); DISPATCH(BinaryOpNode); DISPATCH(RangeOpNode); DISPATCH(SubscriptNode); DISPATCH(TensorNode); DISPATCH(VarNode); DISPATCH(FuncNode); default: nnet_assert(0, "Unknown type"); return R(); } } virtual R visit_(const Constant &c, Args... args) FUNCTOR_DEFAULT; virtual R visit_(const BinaryOp &c, Args... args) FUNCTOR_DEFAULT; virtual R visit_(const RangeOp &c, Args... args) FUNCTOR_DEFAULT; virtual R visit_(const Subscript &c, Args... args) FUNCTOR_DEFAULT; virtual R visit_(const Var &c, Args... args) FUNCTOR_DEFAULT; virtual R visit_(const Tensor &c, Args... args) FUNCTOR_DEFAULT; virtual R visit_(const Func &c, Args... args) FUNCTOR_DEFAULT; virtual R visitDefault(const Expr &c, [[maybe_unused]] Args... args) { dbg(*c); nnet_assert(0, "Reach unimplemented visit function."); return R(); }; [[deprecated("Define explicit methods for public access.")]] R operator()(const Expr &e, Args... args) { return dispatch(e, std::forward(args)...); } #undef FUNCTOR_DEFAULT #undef DISPATCH }; class Mutator : public Functor { public: Mutator(int _verobse = 0) : Functor(_verobse) {} Expr visit_(const Constant &c) override; Expr visit_(const BinaryOp &c) override; Expr visit_(const RangeOp &c) override; Expr visit_(const Subscript &c) override; Expr visit_(const Var &c) override; Expr visit_(const Tensor &c) override; Expr visit_(const Func &c) override; }; // template // class SingleStageVisitor : public Functor { // public: // SingleStageVisitor(int _verobse = 0) : Functor(_verobse) {} // // R visit(const Constant &c) override ; // R visit_(const BinaryOp &c) override { // if (verbose) // dbg(*c); // this->dispatch(c->getLhs()); // this->dispatch(c->getRhs()); // } // R visit_(const RangeOp &c) override { // if (verbose) // dbg(*c); // this->dispatch(ret->getSummand()); // // NOT visit iterators and its ranges // } // R visit_(const Subscript &c) override { // if (verbose) // dbg(*c); // this->dispatch(ret->getObject()); // for (size_t i = 0; i < ret->getDims(); ++i) // this->dispatch(ret->getIndex(i)); // } // // R visit(const Var &c) override; // // R visit(const Tensor &c) override; // }; // } // namespace nnet // #include "nnet/Visitor/ReplaceVariable.h" // #include "nnet/Visitor/StrideVisitor.h" // namespace nnet { class ExprTreeVisitor : public Functor { private: bool inBinary, inRange, inSub, inTensor; public: ExprTreeVisitor(bool _inBinary = 1, bool _inRange = 1, bool _inSub = 1, bool _inTensor = 1, int _verobse = 0) : Functor(_verobse), inBinary(_inBinary), inRange(_inRange), inSub(_inSub), inTensor(_inTensor) {} void visit_(const Constant &c) override; void visit_(const BinaryOp &c) override; void visit_(const RangeOp &c) override; void visit_(const Subscript &c) override; void visit_(const Var &c) override; void visit_(const Tensor &c) override; void visit_(const Func &c) override; }; } // namespace nnet