#pragma once #include "common.h" #include "ref.h" #include #include #include namespace nnet { class ExprNode; class VarNode; class TensorNode; class OperatorNode; class RangeOpNode; class SubscriptNode; class BinaryOpNode; class ConstantNode; class FuncNode; using Expr = Ref; using Var = Ref; using Tensor = Ref; using Operator = Ref; using RangeOp = Ref; using Subscript = Ref; using BinaryOp = Ref; using Constant = Ref; using Func = Ref; class RoutineNode; using Routine = Ref; enum class RoutineType { NoneType = 100, MatmulNodeType, ConvNodeType, G2bmmNodeType, GbmmNodeType, ElementWiseNodeType // unmatchable }; constexpr inline int MatchableRoutineTypeCnt = 4; constexpr inline int RoutineTypeCnt = MatchableRoutineTypeCnt + 1; inline RoutineType idToRoutineType(int i) { return static_cast(i + 1 + static_cast(RoutineType::NoneType)); } inline int routineTypeToId(const RoutineType &routineType) { return static_cast(routineType) - static_cast(RoutineType::NoneType) - 1; } using VecExpr = vector; // common data structure using Iterator = Var; // RE: remove this alias template using PtrMap = std::map>; template // When keys are pointers, compare keys according to its value instead of // address Specially, the name of Var are compared due to the overload of op= // and hash. using PtrUmap = std::unordered_map, ptr_equal>; template using PtrUset = std::unordered_set, ptr_equal>; using Appearance = PtrMap>>; using StrideTable = PtrMap>>; // Tensor, dim, stride // AST node opeartor bool operator==(const Var &lhs, const string &rhs); bool operator==(const string &lhs, const Var &rhs); Expr operator+(const Expr &lhs, const Expr &rhs); BinaryOp operator-(const Expr &lhs, const Expr &rhs); BinaryOp operator*(const Expr &lhs, const Expr &rhs); BinaryOp operator/(const Expr &lhs, const Expr &rhs); BinaryOp operator%(const Expr &lhs, const Expr &rhs); Expr operator+(const Expr &lhs, const int &rhs); Expr operator+(const int &lhs, const Expr &rhs); Expr operator-(const Expr &lhs, const int &rhs); Expr operator-(const int &lhs, const Expr &rhs); Expr operator*(const Expr &lhs, const int &rhs); Expr operator*(const int &lhs, const Expr &rhs); Expr operator%(const Expr &lhs, const int rhs); Expr operator/(const Expr &lhs, const int rhs); string serializeVec(vector v); string serializeVec(vector v); template inline string serializeVec(vector v) { if (v.empty()) return "[]"; return "[" + std::accumulate( v.begin() + 1, v.end(), to_string(v[0]), [](const string &a, int b) { return a + ',' + to_string(b); }) + "]"; } // For RTTI and visitor pattern enum class NodeType { ConstantNodeType, BinaryOpNodeType, RangeOpNodeType, SubscriptNodeType, TensorNodeType, VarNodeType, FuncNodeType }; enum class FuncType { Relu, Tanh, PRelu }; #define DEFINE_GETTYPE(CLASS, isScalar_v) \ NodeType getType() const override { return NodeType::CLASS##Type; } \ bool isScalar() const override { return isScalar_v; } class ExprNode { public: virtual ~ExprNode() {} ExprNode &operator=(const ExprNode &rhs) = delete; virtual HashType hash() const = 0; // RE: remove? virtual string toReadable() const = 0; friend std::ostream &operator<<(std::ostream &ios, const ExprNode &expr); virtual NodeType getType() const = 0; virtual bool isScalar() const = 0; }; class VarNode : public ExprNode { std::string name; public: VarNode(std::string _name) : name(_name){}; virtual ~VarNode() {} DEFINE_GETTYPE(VarNode, true); const std::string &getName() const { return name; } HashType hash() const override { return genhash(name); }; string toReadable() const override { return name; }; bool equal(const Var &rhs) const { return name == rhs->getName(); } bool neq(const Var &rhs) const { return !equal(rhs); } bool less(const Var &rhs) const { return name < rhs->getName(); } bool equal(const string &rhs) const { return name == rhs; } bool operator==(const VarNode &rhs) const { return name == rhs.getName(); } bool operator<(const VarNode &rhs) const { return name < rhs.getName(); } }; enum class TensorType { Input, Weight, Intermediate }; class TensorNode : public ExprNode { string name; vector shape, paddings; TensorType type; Routine source; // if NO source, then this is a input/weight tensor public: TensorNode(string _name, vector _shape, vector _paddings = {}, Routine _source = nullptr); virtual ~TensorNode() {} DEFINE_GETTYPE(TensorNode, false); bool operator==(const string &rhs) { return name == rhs; } friend bool operator==(const string &lhs, const TensorNode &rhs) { return lhs == rhs.name; } HashType hash() const override { return genhash(name); } string toReadable() const override; string toOutputShape() const; const std::string &getName() const { return name; } std::vector &getPadding() { return paddings; } int getPadding(int i) const { return paddings[i]; } const vector &getPaddings() const { return paddings; } void setPadding(int i, int p) { paddings[i] = p; } const vector &getShape() const { return shape; } int getShape(int i) const { return shape[i]; } int64_t getSize() const; int getDims() const { return shape.size(); } const Routine &getSource() const { return source; } int getData(const Ref> &data, const vector &idx); size_t getOffset(const vector &idx); bool hasPadding(); }; enum class OpType { Range, Add, Mul, Div, Mod, Sub }; const char opSymbols[] = "#+*/%-"; class OperatorNode : public ExprNode { protected: const OpType opType; VecExpr subExprs; public: OperatorNode(OpType _opType) : opType(_opType){}; OperatorNode(OpType _opType, VecExpr _subExprs) : opType(_opType), subExprs(_subExprs){}; int getSubExprsNum() { return subExprs.size(); }; const VecExpr &getSubExprs() { return subExprs; } const Expr &getSubExprs(int i) const { return subExprs[i]; } OpType getOpType() const { return opType; }; void setOperands(int i, Expr e) { subExprs[i] = e; } }; using Range = pair; using VarRangePair = pair; inline int getLength(const Range &range) { return range.second - range.first; } struct IterationType { enum { Loop, Sum }; constexpr static int NumIterationType = 2; }; class RangeOpNode : public OperatorNode { public: enum { Summand, END_POS }; constexpr static int Loop = IterationType::Loop; constexpr static int Sum = IterationType::Sum; private: vector vars[IterationType::NumIterationType]; vector paddings; public: RangeOpNode(Expr _summand) : OperatorNode(OpType::Range, {_summand}){}; RangeOpNode(const vector &_loopIters, const vector &_sumIters, Expr _summand, const vector &paddings) : OperatorNode(OpType::Range, {_summand}), vars{_loopIters, _sumIters}, paddings(paddings){}; DEFINE_GETTYPE(RangeOpNode, false); virtual HashType hash() const override { nnet_unimplemented_halt(); return 0; }; string toReadable() const override; const Expr &getSummand() const { return subExprs[Summand]; } const vector &getVarRanges(int _index) const { return vars[_index]; } const vector &getLoopVarRanges() const { return vars[IterationType::Loop]; } const vector &getSumVarRanges() const { return vars[IterationType::Sum]; } int getNumOutputDims() const; bool hasVar(int index, Var name) const; bool hasLoopVar(Var name) const { return hasVar(Loop, name); } bool hasSumVar(Var name) const { return hasVar(Sum, name); } bool hasLoopVar(string name) const { return hasVar(Loop, make_ref(name)); } bool hasSumVar(string name) const { return hasVar(Sum, make_ref(name)); } int getVarIndex(int type, string name); void setSummand(Expr e) { subExprs[Summand] = e; } void setLoopIterator(const vector &vecExpr) { vars[Loop] = vecExpr; } void setSumIterator(const vector &vecExpr) { vars[Sum] = vecExpr; } void setIterator(const vector &loop, const vector &sum) { setLoopIterator(loop); setSumIterator(sum); } const VarRangePair &getVarRange(int _index, int i) const { return vars[_index][i]; } const Var &getLoopVar(int i) const { return vars[Loop][i].first; } Range getRange(const Var &var) const; VarRangePair getVarRange(const Var &var) const; bool hasPaddings() const; int getPaddings(int dim) const; vector getPaddings() const; void setPaddings(vector _paddings); void setVarRange(int _index, int i, VarRangePair pair) { vars[_index][i] = pair; } int64_t getFlops() const; int64_t getInputSize(const RangeOp &self) const; int64_t getOutputSize() const; vector getOutputShape() const; // Including paddings vector getOutputRanges() const; }; class BinaryOpNode : public OperatorNode { enum { LHS, RHS, END_POS }; public: BinaryOpNode(OpType _opType, Expr _lhs, Expr _rhs) : OperatorNode(_opType, {_lhs, _rhs}){}; virtual ~BinaryOpNode() {} DEFINE_GETTYPE(BinaryOpNode, true); virtual HashType hash() const override { return genhash((HashType)opType, genhash(subExprs[LHS]->hash(), subExprs[RHS]->hash())); }; virtual string toReadable() const override; const Expr &getLhs() const { return getSubExprs(LHS); }; const Expr &getRhs() const { return getSubExprs(RHS); }; void setLhs(Expr e) { setOperands(LHS, e); }; void setRhs(Expr e) { setOperands(RHS, e); }; // If Var/constant, use this one optional> getModDivParameter() const; // If (Var+constant)/constant, use this one pair getModDivExpr() const; bool isSwapable() const; }; class ConstantNode : public ExprNode { int val; public: ConstantNode(int _val) : val(_val){}; ConstantNode(const ConstantNode &rhs) : ExprNode(rhs), val(rhs.val){}; virtual ~ConstantNode() {} DEFINE_GETTYPE(ConstantNode, true); int getValue() const { return val; } virtual HashType hash() const override { return genhash(val, 6214587); }; virtual string toReadable() const override { string ret; ret += std::to_string(val); return ret; }; }; class SubscriptNode : public ExprNode { protected: Expr indexed; VecExpr subExprs; public: SubscriptNode(Expr _indexed, vector _subExprs) : subExprs(_subExprs) { setObject(_indexed); }; DEFINE_GETTYPE(SubscriptNode, true); virtual HashType hash() const override { nnet_unimplemented_continue(); return -1; }; virtual string toReadable() const override; size_t getDims() const { return subExprs.size(); } const VecExpr &getIndex() const { return subExprs; } const Expr &getIndex(size_t i) const { return subExprs[i]; } void setIndex(size_t i, Expr e) { subExprs[i] = e; } Expr *getObjectPtr() { return &indexed; } Expr getObject() const { return indexed; } void setObject(Expr e); bool isRangeOpSubscripted() const; bool isTensorSubscripted() const { return !isRangeOpSubscripted(); } // Get the ranges of objects including paddings vector getObjectRangesWithPaddings() const; vector getObjectRangesWithoutPaddings() const; }; class FuncNode : public ExprNode { protected: Expr object; FuncType funcType; public: FuncNode(Expr object, FuncType funcType) : object(object), funcType(funcType) { nnet_assert(object->isScalar(), "FuncNode operates on a scalar"); } DEFINE_GETTYPE(FuncNode, true); virtual HashType hash() const override { nnet_unimplemented_continue(); return -1; }; virtual string toReadable() const override; const Expr &getObject() const { return object; } void setObject(Expr e); FuncType getFuncType() const { return funcType; } }; // Wrappers for type deduction Subscript makeSubscript(const Expr &tensor, const VecExpr &subscripts); RangeOp makeRangeOperator(const vector &_loopIters, const vector &_sumIters, Expr _summand, const vector &paddings = {}); Tensor makeTensor(const string &name, const vector &shape, const vector &paddings = {}, const Routine &source = nullptr); // Pretty output for dbg with shared_ptr template > *_ = nullptr> std::ostream &operator<<(std::ostream &os, const Ref &a) { os << ((!a) ? string("nullptr") : a->toReadable()); return os; } #undef DEFINE_GETTYPE } // namespace nnet namespace std { template <> struct hash { size_t operator()(const nnet::VarNode &t) const { return std::hash()(t.getName()); } }; } // namespace std