#pragma once #include "common.h" #include "expr.h" #include namespace nnet { // enum class DLTType { Split, Merge, Reorder }; struct DLTOperation { // DLTType type; virtual ~DLTOperation() {} }; struct DLTSplit : DLTOperation { int dim, factor; DLTSplit(int _dim, int _factor) : dim(_dim), factor(_factor) {} }; struct DLTMerge : DLTOperation { int dim0, dim1; DLTMerge(int _dim0, int _dim1) : dim0(_dim0), dim1(_dim1) {} }; struct DLTReorder : DLTOperation { vector dims; DLTReorder(vector _dims) : dims(_dims) {} }; class DLT { vector> ops; public: /** * @brief dim -> (dim/factor, factor) */ void split(int dim, int factor); /** * @brief Merge dim1 into dim0 -> (dim0, dim1) */ void merge(int dim0, int dim1); /** * @brief * * @param dims dims[new_dim]=old_dim */ void reorder(vector dims); optional apply(const RangeOp &rangeOp, const Subscript &subscript, string newTensorName); private: optional> splitIndex(Expr expr, int factor, RangeOp rangeOp); }; } // namespace nnet