#pragma once #include "nnet/Visitor/StrideVisitor.h" #include "nnet/visitor.h" namespace nnet { class AsTVMVisitor : public Functor { private: int nStage = 0, curStage = -1; std::unordered_map offset; std::vector inputs; std::string output; std::vector pythonVars; std::vector> inputShapes; std::vector outputShape; std::string stmts; public: std::string getStmts() const; const std::vector &getInputs() const { return inputs; } const std::string &getOutput() const { return output; } const std::vector> &getInputShapes() const { return inputShapes; } const std::vector &getOutputShape() const { return outputShape; } std::string visit_(const Constant &c) override; std::string visit_(const BinaryOp &c) override; std::string visit_(const Func &c) override; std::string visit_(const RangeOp &c) override; std::string visit_(const Subscript &c) override; std::string visit_(const Var &c) override; std::string visit_(const Tensor &c) override; }; } // namespace nnet