// Forked from jiuyuan/InfiniTensor (49 lines, 1.7 KiB, C++).
#pragma once
|
|
#include "nnet/Pass/Pass.h"
|
|
#include "nnet/ReplaceKit.h"
|
|
|
|
namespace nnet {
|
|
|
|
class Rule8GuidedDLT : public Pass {
|
|
public:
|
|
Rule8GuidedDLT(Derivator &derivator) : Pass(derivator, "Rule8GuidedDLT") {}
|
|
VecExpr guidedDLT(Formula &origin, int depth, Expr &rCur,
|
|
bool debug = false);
|
|
|
|
private:
|
|
virtual void transform(Formula &origin, int dfsDepth, Expr &rCur) override;
|
|
/**
|
|
* @brief If only one row miss match (more iterators mismatch), directly do
|
|
* data layout construction according to the IT.
|
|
*
|
|
* @return Expr Return nullptr if failed.
|
|
*/
|
|
Expr guidedDLTMoreVar2(const RangeOp &cur, const Mismatch &mismatch,
|
|
const IteratorTable &exprIT, const Pattern &pattern);
|
|
/**
|
|
* @brief Check whether two iterators overlap each other. If overlapping, we
|
|
* cannot simply reconstruct the tensor into a new one by seperate all
|
|
* iterators into different dimensions.
|
|
*/
|
|
bool checkElementsHaveOnlyOneAccessIteratorSet(const IteratorTable &exprIT,
|
|
int tensorID);
|
|
/**
|
|
* @brief Only product of two tensors can be guided DLTed.
|
|
*
|
|
* @param cur
|
|
* @return true
|
|
* @return false
|
|
*/
|
|
bool statisfyGuidedDLT(RangeOp cur) const;
|
|
/**
|
|
* @brief Deal with output DLT mismatch only.
|
|
*/
|
|
Expr guidedDLTDLMismatch(const RangeOp &cur, const Mismatch &mismatch,
|
|
const IteratorTable &exprIT,
|
|
const Pattern &pattern);
|
|
Expr buildGuidedDLTSource(const Subscript &originalSub, Replace replace,
|
|
vector<Var> tensorDimAxes, vector<int> newShape);
|
|
};
|
|
|
|
} // namespace nnet
|